Open Source · MIT Licensed

Mastering Atari with Discrete World Models

DreamerV2 is the first world model agent that achieves human-level performance on the Atari benchmark. It learns a model of the environment directly from high-dimensional input images and trains entirely from imagined trajectories.

Hafner, Lillicrap, Norouzi, Ba (2020) "Mastering Atari with Discrete World Models" arXiv:2010.02193
Framework: TensorFlow 2
License: MIT
Benchmark: 55 Atari games
Performance: Human-level

Quick Start

Install and train in minutes

DreamerV2 is available as a pip package. Install it, point it at an environment, and start training. The code automatically detects discrete vs. continuous actions.

pip install

One command to get started:

pip3 install dreamerv2

Supported environments: Atari (55 games), DM Control, MiniGrid, and custom Gym environments.
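
Once installed, training can also be launched from Python. The snippet below follows the MiniGrid example from the repository's README; the config values are illustrative starting points, not tuned settings.

import gym
import gym_minigrid
import dreamerv2.api as dv2

# Override a few defaults; these values mirror the repository's
# MiniGrid example and should be adjusted for your environment.
config = dv2.defaults.update({
    'logdir': '~/logdir/minigrid',
    'log_every': 1e3,
    'train_every': 10,
    'prefill': 1e5,
    'actor_ent': 3e-3,
    'loss_scales.kl': 1.0,
    'discount': 0.99,
}).parse_flags()

# Any Gym environment with image observations works.
env = gym.make('MiniGrid-DoorKey-6x6-v0')
env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(env)
dv2.train(env, config)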

How It Works

The DreamerV2 Method

DreamerV2 learns a world model from images, then trains actor and critic networks entirely from imagined trajectories; no real environment interaction is needed during policy learning.

World Model. DreamerV2 learns compact latent representations of environment observations. Each latent state consists of a deterministic component and 32 categorical variables with 32 classes each, trained via a KL divergence loss with straight-through gradients.
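
As a sketch of the straight-through trick (TensorFlow; the shapes are illustrative, matching the 32 × 32 latent layout described later):

import tensorflow as tf
import tensorflow_probability as tfp

def sample_straight_through(logits):
    # logits: [batch, 32, 32] -- 32 categorical variables, 32 classes each.
    dist = tfp.distributions.OneHotCategorical(logits=logits)
    sample = tf.cast(dist.sample(), tf.float32)
    probs = tf.nn.softmax(logits, axis=-1)
    # Forward pass uses the discrete one-hot sample; gradients flow
    # through the class probabilities instead of the sampling step.
    return sample + probs - tf.stop_gradient(probs)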

Imagination. Starting from encoded observations of real sequences, the model predicts ahead using the learned prior and selected actions. This produces imagined trajectories entirely within the latent space.
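
A minimal sketch of such a rollout; `rssm.img_step` and `actor` are illustrative stand-ins for the model's transition prior and policy, and the horizon of 15 matches the paper:

def imagine_ahead(rssm, actor, start_state, horizon=15):
    # Roll the learned prior forward in latent space; no decoder or
    # real environment is touched. Names are illustrative, not the repo's API.
    state, states, actions = start_state, [], []
    for _ in range(horizon):
        action = actor(state).sample()        # policy acts on the latent state
        state = rssm.img_step(state, action)  # prior predicts the next latent
        states.append(state)
        actions.append(action)
    return states, actions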

Actor-Critic. The critic is trained via temporal difference learning on imagined trajectories. The actor maximizes the learned value function using both REINFORCE and straight-through gradient estimators.
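
The critic targets are lambda-returns over the imagined trajectory. A minimal sketch, assuming T rewards and T+1 value estimates with the last one serving as bootstrap:

def lambda_returns(rewards, values, discount=0.995, lam=0.95):
    # Recursively mix n-step returns, as in TD(lambda). The hyperparameters
    # shown match the paper's Atari settings.
    returns = [values[-1]]  # bootstrap from the final value estimate
    for t in reversed(range(len(rewards))):
        mixed = (1 - lam) * values[t + 1] + lam * returns[0]
        returns.insert(0, rewards[t] + discount * mixed)
    return returns[:-1]  # one target per imagined state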

Training Loop

Observe (image input) → Encode (latent state) → Imagine (predict ahead) → Learn (actor-critic)

Getting Started

Manual Setup

For those who want to modify DreamerV2, clone the repository and follow these steps. Docker instructions are also available.

Install dependencies

Get TensorFlow 2, tensorflow_probability, ruamel.yaml, and your environment of choice (Atari, DM Control, or a custom Gym).

pip3 install tensorflow==2.6.0 tensorflow_probability ruamel.yaml 'gym[atari]' dm_control

Train on Atari

Point the training script at any of the 55 Atari games. Logs and checkpoints are saved to your logdir.

python3 dreamerv2/train.py --logdir ~/logdir/atari_pong/dreamerv2/1 \
  --configs atari --task atari_pong

Train on DM Control

Switch to continuous control with a single flag change. Same codebase, same algorithm.

python3 dreamerv2/train.py --logdir ~/logdir/dmc_walker_walk/dreamerv2/1 \
  --configs dmc_vision --task dmc_walker_walk

Monitor results

Metrics are logged in both TensorBoard and JSON lines format, so you can visualize training curves in real time.

tensorboard --logdir ~/logdir

Results

Benchmark Performance

DreamerV2 sets a new standard for world model agents. Training curves for all 55 Atari games are included in the repository.

Human-Level Atari

DreamerV2 is the first world model agent to achieve human-level performance across the Atari benchmark, outperforming Rainbow and IQN with the same compute.

Evaluated on 55 Atari games

Sample Efficient

By learning in imagination, DreamerV2 requires far fewer real environment interactions than model-free baselines to reach equivalent performance.

200M environment frames

Single GPU Training

The implementation alternates between training the world model, training the policy, and collecting experience, all on a single GPU.

No distributed setup required

Discrete Latent Space

Uses categorical latent variables instead of Gaussian, enabling more expressive representations and sharper predictions via straight-through gradients.

32 categoricals × 32 classes = 1024-dim

Tips

Practical Advice

Common questions and debugging tips from the DreamerV2 documentation.

Efficient debugging

Use the debug config to reduce batch size, increase evaluation frequency, and disable tf.function graph compilation for line-by-line debugging.

--configs atari debug

Infinite gradient norms

Occasional infinite gradient norms are expected when training with mixed precision; the behavior is described under loss scaling in the TensorFlow mixed precision guide. Disable mixed precision if needed.

--precision 32

Accessing logged metrics

Metrics are stored in both TensorBoard and JSON lines format. Load them directly with pandas for custom analysis and plotting.

pandas.read_json()
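
For example (the metrics filename below is an assumption; check your logdir for the exact name):

import pandas as pd

# Each line of the JSON-lines file is one logged metrics record.
df = pd.read_json('~/logdir/atari_pong/dreamerv2/1/metrics.jsonl', lines=True)
print(df.columns)  # inspect the available metrics before plotting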