Pre-training Contextualized World Models with In-the-wild Videos for Reinforcement Learning (NeurIPS 2023)
Official implementation of the Contextualized World Models (ContextWM) with In-the-wild Pre-training from Videos (IPV) in PyTorch. Unified implementations of DreamerV2 and APV in PyTorch are also included.
If you find our codebase useful for your research, please cite our paper as:
@inproceedings{wu2023pre,
title={Pre-training Contextualized World Models with In-the-wild Videos for Reinforcement Learning},
author={Jialong Wu and Haoyu Ma and Chaoyi Deng and Mingsheng Long},
booktitle={Advances in Neural Information Processing Systems},
year={2023}
}
TL;DR: We introduce Contextualized World Models (ContextWM), which utilize pre-training on in-the-wild videos to enable sample-efficient model-based RL of visual control tasks in various domains.
The conda environment can be created and activated with the following commands:
conda env create -f environment.yaml
conda activate wmlib
Meta-world
Meta-world depends on MuJoCo200. You may need to install it manually. Meta-world itself can be installed using the following command:
pip install git+https://github.com/rlworkgroup/metaworld.git@a0009ed9a208ff9864a5c1368c04c273bb20dd06#egg=metaworld
DMCR
We adopt the original DMCR implementation provided by QData and integrate it into our codebase.
You also need to download the DMCR assets from here and move them to the wmlib/envs/dmcr/assets folder.
CARLA
We use CARLA 0.9.11 for our experiments. Please follow the official instructions to install and run CARLA.
Note that we use the CARLA 0.9.8 version of the Town04 map, which is included in the wmlib/envs/carla_api folder. You should move this map to the CARLA 0.9.11 map folder (e.g., CARLA_0.9.11/CarlaUE4/Content/Carla/Maps/OpenDrive) to run the CARLA experiments.
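If you prefer to script this step, a minimal Python sketch follows. It assumes that the bundled map file names start with Town04, since the exact names are not listed here. The helper name install_town04_map is hypothetical. The sketch copies the files instead of moving them, so the repository copy stays intact.

```python
import shutil
from pathlib import Path


def install_town04_map(repo_root, carla_root, pattern="Town04*"):
    """Copy the bundled Town04 map file(s) into CARLA's OpenDrive map folder.

    Assumption: the map ships as file(s) whose names match `pattern` inside
    wmlib/envs/carla_api; adjust `pattern` if the actual names differ.
    """
    src_dir = Path(repo_root) / "wmlib" / "envs" / "carla_api"
    dst_dir = Path(carla_root) / "CarlaUE4" / "Content" / "Carla" / "Maps" / "OpenDrive"
    dst_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for src in sorted(src_dir.glob(pattern)):
        if src.is_file():
            dst = dst_dir / src.name
            shutil.copy2(src, dst)  # copies contents and metadata, overwriting any existing file
            copied.append(dst)
    if not copied:
        raise FileNotFoundError(f"no files matching {pattern!r} in {src_dir}")
    return copied
```

For example, install_town04_map(".", "/opt/CARLA_0.9.11") run from the repository root would place the map under /opt/CARLA_0.9.11/CarlaUE4/Content/Carla/Maps/OpenDrive.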
An example of running a CARLA server:
./CarlaUE4.sh -fps=20 -carla-rpc-port={port to use} -carla-streaming-port=0
Before running the training script, you may need to add the CARLA Python API path and the CARLA egg file path to your PYTHONPATH environment variable:
export PYTHONPATH=$PYTHONPATH:{path to CARLA}/PythonAPI/
export PYTHONPATH=$PYTHONPATH:{path to CARLA}/PythonAPI/carla
export PYTHONPATH=$PYTHONPATH:{path to CARLA}/PythonAPI/carla/dist/carla-0.9.11-py3.7-linux-x86_64.egg
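The three export lines above can also be expressed in Python, for example in a launcher script. This is a minimal sketch with hypothetical helper names. The egg file name is copied from the last export line and depends on your Python version and platform. Empty PYTHONPATH segments are dropped, so an unset variable does not add a stray leading separator.

```python
import os


def carla_python_paths(carla_root, egg_name="carla-0.9.11-py3.7-linux-x86_64.egg"):
    """Return the three CARLA entries from the export lines, in the same order."""
    api = os.path.join(carla_root, "PythonAPI")
    return [
        api,
        os.path.join(api, "carla"),
        os.path.join(api, "carla", "dist", egg_name),
    ]


def pythonpath_with_carla(carla_root, current=None):
    """Append the CARLA entries to a PYTHONPATH string, as the export lines do."""
    current = os.environ.get("PYTHONPATH", "") if current is None else current
    parts = [p for p in current.split(os.pathsep) if p]  # drop empty segments
    parts.extend(carla_python_paths(carla_root))
    return os.pathsep.join(parts)
```

Passing env=dict(os.environ, PYTHONPATH=pythonpath_with_carla(root)) to subprocess.run affects a child training process. Calling sys.path.extend(carla_python_paths(root)) gives the same effect inside the current interpreter only.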
Something-Something-V2 (Recommended)
Download the Something-Something-V2 dataset and extract the video frames (note that you should specify the data paths in this script):
cd data/somethingv2
python extract_frames.py
Then generate the data lists with the following command (again, specify the data paths in the script):
python process_somethingv2.py
We have already included the generated lists in this repo (see data/somethingv2/*.txt).
Human3.6M
Download the processed Human3.6M dataset from Pavlakos et al. using this script, then clip the frames to 64x64 with the following command (again, specify the data paths in the script):
cd data/human36m
python build_clip_dataset.py
Then generate the data lists with the following command (again, specify the data paths in the script):
python make_list.py
We have already included the generated lists in this repo (see data/human36m/*.txt).
YouTubeDriving
Download the YouTubeDriving dataset and preprocess it (note that you should specify the data paths in this script):
cd data/ytb_driving
python make_list.py
We have already included the generated data lists in this repo (see data/ytb_driving/*.txt).
We provide our pre-trained world models:
- ContextWM pre-trained on each of the following: Something-Something-V2 (Recommended), Human3.6M, YouTubeDriving, the three datasets assembled together, and the RLBench dataset from APV
- Plain WM pre-trained on Something-Something-V2
You can obtain them from [Google Drive] or [Tsinghua Cloud].
Run the following commands to pre-train world models.
Something-Something-V2
python examples/train_apv_pretraining.py --logdir {save path} --configs something_pretrain contextualized --video_list train_video_folder --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}
Human3.6M
python examples/train_apv_pretraining.py --logdir {save path} --configs human_pretrain contextualized --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}
YouTubeDriving
python examples/train_apv_pretraining.py --logdir {save path} --configs ytb_pretrain contextualized --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}
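The three pre-training commands differ only in the dataset config and in whether --video_list is passed. The following is a minimal sketch for scripting them. The dataset keys (somethingv2, human36m, ytb_driving) and the helper name pretrain_command are hypothetical labels. The flags are copied from the commands above.

```python
# Dataset config per (hypothetical) dataset key; only the Something-Something-V2
# command above passes --video_list.
PRETRAIN_CONFIGS = {
    "somethingv2": ("something_pretrain", ["--video_list", "train_video_folder"]),
    "human36m": ("human_pretrain", []),
    "ytb_driving": ("ytb_pretrain", []),
}


def pretrain_command(dataset, logdir, video_dir, steps=1200000):
    """Return the ContextWM pre-training command as an argument list."""
    config, extra = PRETRAIN_CONFIGS[dataset]
    return [
        "python", "examples/train_apv_pretraining.py",
        "--logdir", logdir,
        "--configs", config, "contextualized",
        *extra,
        "--steps", str(steps),
        "--save_all_models", "True",
        "--video_dir", video_dir,
    ]
```

The list can be passed directly to subprocess.run(cmd, check=True). To print it for a shell, use " ".join(shlex.quote(a) for a in cmd), because shlex.join needs Python 3.8 or newer.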
Run the following commands to start model-based RL with pre-trained world models.
Meta-world
python examples/train_apv_finetuning.py --logdir {save path} --configs metaworld contextualized --task metaworld_{task, e.g. drawer_open} --seed 0 --loss_scales.reward 1.0 --loss_scales.aux_reward 1.0 --encoder_ctx.ctx_aug erasing --load_logdir {path to the pre-trained models}
Note that for the drawer open task, we find that removing --encoder_ctx.ctx_aug erasing slightly improves performance, so we disable this option for our reported results on this task. All other Meta-world tasks enable this option.
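The drawer open exception is easy to forget when launching many tasks. The following minimal sketch builds the Meta-world fine-tuning command and drops --encoder_ctx.ctx_aug erasing only for drawer_open, as the note says. The helper name is hypothetical. The task argument is assumed to be the suffix after metaworld_, as in metaworld_drawer_open.

```python
def metaworld_finetune_command(task, logdir, pretrained_dir, seed=0):
    """Return the Meta-world fine-tuning command as an argument list."""
    cmd = [
        "python", "examples/train_apv_finetuning.py",
        "--logdir", logdir,
        "--configs", "metaworld", "contextualized",
        "--task", f"metaworld_{task}",
        "--seed", str(seed),
        "--loss_scales.reward", "1.0",
        "--loss_scales.aux_reward", "1.0",
    ]
    # Erasing context augmentation is enabled for every task except drawer open.
    if task != "drawer_open":
        cmd += ["--encoder_ctx.ctx_aug", "erasing"]
    cmd += ["--load_logdir", pretrained_dir]
    return cmd
```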
DMC Remastered
python examples/train_apv_finetuning.py --logdir {save path} --configs dmc_remastered contextualized --task dmcr_{task, e.g. walker_run} --seed 0 --loss_scales.reward 1.0 --loss_scales.aux_reward 1.0 --load_logdir {path to the pre-trained models}
Note that you need to add dmcr_hopper to --configs when running the DMCR Hopper Stand task. The dmcr_hopper option fixes the camera position, because we find that the agent struggles to learn this task when the camera is randomly positioned and rotated.
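The following is a minimal sketch of the DMCR command with the Hopper Stand exception applied. It assumes the task suffix for Hopper Stand is hopper_stand, following the dmcr_{task} pattern above (e.g., dmcr_walker_run). That suffix and the helper name are assumptions, not taken from the codebase.

```python
def dmcr_finetune_command(task, logdir, pretrained_dir, seed=0):
    """Return the DMC Remastered fine-tuning command as an argument list."""
    configs = ["dmc_remastered", "contextualized"]
    if task == "hopper_stand":  # assumed task suffix for DMCR Hopper Stand
        configs.append("dmcr_hopper")  # fixes the camera position for this task
    return [
        "python", "examples/train_apv_finetuning.py",
        "--logdir", logdir,
        "--configs", *configs,
        "--task", f"dmcr_{task}",
        "--seed", str(seed),
        "--loss_scales.reward", "1.0",
        "--loss_scales.aux_reward", "1.0",
        "--load_logdir", pretrained_dir,
    ]
```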
CARLA
python examples/train_apv_finetuning.py --logdir {save path} --configs carla contextualized --task carla_{task, e.g. ClearNoon} --seed 0 --loss_scales.reward 1.0 --loss_scales.aux_reward 1.0 --load_logdir {path to the pre-trained models} --carla_port {port number}
For each individual run, you need to start two CARLA servers whose port numbers differ by 10 (e.g., 2030 and 2040) and pass the first port number to --carla_port. The two servers are used for collecting data and for evaluating the agent, respectively.
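The port arithmetic can be scripted with the following minimal sketch. It reuses the server command from the example above and offsets the second RPC port by 10. The helper names are hypothetical, and carla_root is assumed to be the directory containing CarlaUE4.sh.

```python
import subprocess


def carla_server_commands(first_port, carla_sh="./CarlaUE4.sh", fps=20, interval=10):
    """Commands for the two servers of one run: data collection, then evaluation."""
    ports = [first_port, first_port + interval]
    return [
        [carla_sh, f"-fps={fps}", f"-carla-rpc-port={port}", "-carla-streaming-port=0"]
        for port in ports
    ]


def launch_carla_servers(first_port, carla_root):
    """Start both servers in the background; the caller should terminate() them later."""
    # With cwd set, a relative executable path is resolved relative to carla_root (POSIX).
    return [subprocess.Popen(cmd, cwd=carla_root) for cmd in carla_server_commands(first_port)]
```

With launch_carla_servers(2030, ...), the training command would then receive --carla_port 2030.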
We include unified implementations of our baseline methods, DreamerV2 and APV, which can be run with the following commands.
DreamerV2
python examples/train_dreamerv2.py --logdir {save path} --configs metaworld --task metaworld_{task, e.g. drawer_open} --seed 0
APV
python examples/train_apv_pretraining.py --logdir {save path} --configs something_pretrain plainresnet --video_list train_video_folder --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}
python examples/train_apv_finetuning.py --logdir {save path} --configs metaworld plainresnet --task metaworld_{task, e.g. drawer_open} --seed 0 --load_logdir {path to the pre-trained models}
Mixed precision is enabled by default; it is faster but may cause numerical instability. Infinite gradient norms are normal in this mode, and training may be interrupted by nan values. You can pass --precision 32 to disable mixed precision.
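One common reason that infinite gradient norms can be harmless is dynamic loss scaling, as done by PyTorch's torch.cuda.amp.GradScaler. Whether this codebase relies on it is an assumption, not something stated here. Under that scheme, float16 overflows (values above 65504) produce inf or nan gradients, the step is skipped, and the loss scale is lowered. The toy, framework-free class below sketches only that bookkeeping and is not the repository's training code. Its defaults mirror GradScaler's documented defaults.

```python
import math


class DynamicLossScaler:
    """Toy dynamic loss scaler in the style of PyTorch's GradScaler (illustrative only)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, scaled_grad_norm):
        """Return True if the optimizer step should be applied.

        A non-finite (inf/nan) norm means the scaled fp16 gradients overflowed:
        the step is skipped and the scale shrinks, so training continues.
        """
        if not math.isfinite(scaled_grad_norm):
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # After enough consecutive finite steps, try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True
```

A skipped step only costs one update. Persistent nan values in the losses themselves are a different failure, and --precision 32 avoids both.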
See also the tips in the DreamerV2 repository.
If you have any questions, please contact [email protected].
We sincerely thank the authors of the following GitHub repositories, whose valuable code bases we build upon: