contextwm's Introduction

Pre-training Contextualized World Models with In-the-wild Videos for Reinforcement Learning (NeurIPS 2023)

Official implementation of the Contextualized World Models (ContextWM) with In-the-wild Pre-training from Videos (IPV) in PyTorch. Unified implementations of DreamerV2 and APV in PyTorch are also included.

If you find our codebase useful for your research, please cite our paper as:

@inproceedings{wu2023pre,
  title={Pre-training Contextualized World Models with In-the-wild Videos for Reinforcement Learning},
  author={Jialong Wu and Haoyu Ma and Chaoyi Deng and Mingsheng Long},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}

Method

TL;DR: We introduce Contextualized World Models (ContextWM), which utilize pre-training on in-the-wild videos to enable sample-efficient model-based RL of visual control tasks in various domains.

Dependencies

The conda environment can be created by the following command:

conda env create -f environment.yaml
conda activate wmlib

Meta-world

Meta-world depends on MuJoCo200. You may need to install it manually. Meta-world itself can be installed using the following command:

pip install git+https://github.com/rlworkgroup/metaworld.git@a0009ed9a208ff9864a5c1368c04c273bb20dd06#egg=metaworld

DMCR

We adopt the original DMCR implementation provided by QData and integrate it into our codebase. You additionally need to download the DMCR assets from here and move them to the wmlib/envs/dmcr/assets folder.

CARLA

We use CARLA 0.9.11 for our experiments. Please follow the official instructions to install and run CARLA. Note that we use the CARLA 0.9.8 version map of Town04, which is included in the wmlib/envs/carla_api folder. You should move this map to the CARLA 0.9.11 map folder (e.g., CARLA_0.9.11/CarlaUE4/Content/Carla/Maps/OpenDrive) to run the CARLA experiments.
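The map copy described above could be scripted roughly as follows. This is only a sketch: the helper name copy_town04_map and the Town04* file-name pattern are assumptions made for illustration, not taken from the repository, so check the actual file names in wmlib/envs/carla_api first. Note that copying overwrites any file of the same name already in the destination folder.

```python
import shutil
from pathlib import Path


def copy_town04_map(repo_root, carla_root, pattern="Town04*"):
    """Copy the bundled Town04 map file(s) into CARLA's OpenDrive folder.

    `pattern` is a hypothetical file-name pattern; adjust it to the real file.
    Returns the list of destination paths that were written.
    """
    src_dir = Path(repo_root) / "wmlib" / "envs" / "carla_api"
    dst_dir = Path(carla_root) / "CarlaUE4" / "Content" / "Carla" / "Maps" / "OpenDrive"
    dst_dir.mkdir(parents=True, exist_ok=True)

    copied = []
    for src in sorted(src_dir.glob(pattern)):
        if src.is_file():
            # copy2 keeps file metadata and returns the destination path
            copied.append(Path(shutil.copy2(src, dst_dir)))
    return copied
```

For example, copy_town04_map(".", "CARLA_0.9.11") would be called from the repository root, with the second argument replaced by your own CARLA install path.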

An example of running a CARLA server:

./CarlaUE4.sh -fps=20 -carla-rpc-port={port to use} -carla-streaming-port=0

Before running the training script, you may need to add the CARLA Python API path and the CARLA egg file path to your PYTHONPATH environment variable:

export PYTHONPATH=$PYTHONPATH:{path to CARLA}/PythonAPI/
export PYTHONPATH=$PYTHONPATH:{path to CARLA}/PythonAPI/carla
export PYTHONPATH=$PYTHONPATH:{path to CARLA}/PythonAPI/carla/dist/carla-0.9.11-py3.7-linux-x86_64.egg
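If you prefer to set these paths from Python instead of the shell, a minimal sketch is shown below. The helper names carla_python_paths and add_carla_to_sys_path are hypothetical, and the egg file name is simply the one from the export line above; it will differ for other CARLA or Python versions.

```python
import os
import sys


def carla_python_paths(carla_root, egg_name="carla-0.9.11-py3.7-linux-x86_64.egg"):
    """Return the three entries that the export lines above add to PYTHONPATH."""
    api = os.path.join(carla_root, "PythonAPI")
    return [
        api,
        os.path.join(api, "carla"),
        os.path.join(api, "carla", "dist", egg_name),
    ]


def add_carla_to_sys_path(carla_root):
    """Append any missing CARLA paths to sys.path for the current process only."""
    for path in carla_python_paths(carla_root):
        if path not in sys.path:
            sys.path.append(path)
```

Unlike the export lines, this only affects the running interpreter, so it must be called before the first import of carla.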

Datasets

Something-Something-V2 (Recommended)

Download the Something-Something-V2 dataset and extract frames of videos (Note that you should properly specify data paths in this script):

cd data/somethingv2
python extract_frames.py

Then you can generate data lists by the following command (also, properly specify data paths):

python process_somethingv2.py

We have already included the generated lists in this repo (see data/somethingv2/*.txt).

Human3.6M

Download the processed Human3.6M dataset by Pavlakos et al. using this script and clip the frames to 64x64 using the following command (also, properly specify data paths):

cd data/human36m
python build_clip_dataset.py

Then you can generate data lists by the following command (also, properly specify data paths):

python make_list.py

We have already included the generated lists in this repo (see data/human36m/*.txt).

YouTubeDriving

Download the YouTubeDriving dataset and preprocess the dataset (Note that you should properly specify data paths in this script):

cd data/ytb_driving
python make_list.py

We have already included the generated data lists in this repo (see data/ytb_driving/*.txt).

Pre-trained Models

We provide our pre-trained world models:

  • ContextWM pre-trained on each of the following: Something-Something-V2 (Recommended), Human3.6M, YouTubeDriving, the combination of these three datasets, and the RLBench dataset from APV
  • Vanilla WM pre-trained on Something-Something-V2

You can obtain them from [Google Drive] or [Tsinghua Cloud].

Experiments

Pre-training from In-the-wild Videos

Run the following command to pre-train world models.

Something-Something-V2

python examples/train_apv_pretraining.py --logdir {save path} --configs something_pretrain contextualized --video_list train_video_folder --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}

Human3.6M

python examples/train_apv_pretraining.py --logdir {save path} --configs human_pretrain contextualized --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}

YouTubeDriving

python examples/train_apv_pretraining.py --logdir {save path} --configs ytb_pretrain contextualized --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}

Fine-tuning with Model-based RL

Run the following commands to start model-based RL with pre-trained world models.

Meta-world

python examples/train_apv_finetuning.py --logdir {save path} --configs metaworld contextualized --task metaworld_{task, e.g. drawer_open} --seed 0 --loss_scales.reward 1.0 --loss_scales.aux_reward 1.0 --encoder_ctx.ctx_aug erasing --load_logdir {path to the pre-trained models}

Note that for the drawer open task, we find that removing --encoder_ctx.ctx_aug erasing slightly improves performance, so we disable this option for the reported results on this task. All other Meta-world tasks enable this option.
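When launching many Meta-world runs, the rule above can be encoded in a small launcher. The sketch below only assembles the argument list from the command shown above; the function name metaworld_finetune_args and the example paths are hypothetical.

```python
import shlex


def metaworld_finetune_args(task, logdir, load_logdir, seed=0):
    """Build the argument list for the Meta-world fine-tuning command."""
    args = [
        "python", "examples/train_apv_finetuning.py",
        "--logdir", logdir,
        "--configs", "metaworld", "contextualized",
        "--task", f"metaworld_{task}",
        "--seed", str(seed),
        "--loss_scales.reward", "1.0",
        "--loss_scales.aux_reward", "1.0",
    ]
    # Per the note above, every task except drawer_open uses the erasing augmentation.
    if task != "drawer_open":
        args += ["--encoder_ctx.ctx_aug", "erasing"]
    args += ["--load_logdir", load_logdir]
    return args


if __name__ == "__main__":
    # Placeholder paths; print the command rather than running it.
    print(shlex.join(metaworld_finetune_args("drawer_open", "logs/run0", "pretrained/")))
```

The returned list can be passed directly to subprocess.run if you want to start the run from Python.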

DMC Remastered

python examples/train_apv_finetuning.py --logdir {save path} --configs dmc_remastered contextualized --task dmcr_{task, e.g. walker_run} --seed 0 --loss_scales.reward 1.0 --loss_scales.aux_reward 1.0 --load_logdir {path to the pre-trained models}

Note that you need to add dmcr_hopper to --configs when running the DMCR Hopper Stand task. The dmcr_hopper option fixes the camera's position, as we find it is too difficult for the agent to learn when the camera is randomly positioned and rotated in this task.
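The config selection for DMCR can be sketched in the same way. Two things here are assumptions: the task identifier hopper_stand (inferred from the walker_run naming pattern) and the placement of dmcr_hopper after the other config names.

```python
def dmcr_configs(task):
    """Return the values to pass to --configs for a DMCR task."""
    configs = ["dmc_remastered", "contextualized"]
    # Hopper Stand additionally needs dmcr_hopper, which fixes the camera position.
    # "hopper_stand" is an assumed task identifier, not confirmed by the README.
    if task == "hopper_stand":
        configs.append("dmcr_hopper")
    return configs
```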

CARLA

python examples/train_apv_finetuning.py --logdir {save path} --configs carla contextualized --task carla_{task, e.g. ClearNoon} --seed 0 --loss_scales.reward 1.0 --loss_scales.aux_reward 1.0 --load_logdir {path to the pre-trained models} --carla_port {port number}

For each individual run, you need to start two CARLA servers with an interval of 10 for port numbers (e.g. 2030 and 2040) and pass the first port number to --carla_port. The two servers are used for collecting data and evaluating the agent, respectively.
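To keep the port pairing straight across several runs, the two server commands can be generated from the first port. This is a sketch built from the server command shown earlier; the helper name carla_server_commands is hypothetical, and reading "respectively" as "first port collects data, second port evaluates" is an assumption.

```python
def carla_server_commands(first_port, interval=10, fps=20):
    """Return (collect_cmd, eval_cmd) for the two CARLA servers of one run."""
    ports = (first_port, first_port + interval)
    return tuple(
        f"./CarlaUE4.sh -fps={fps} -carla-rpc-port={port} -carla-streaming-port=0"
        for port in ports
    )


if __name__ == "__main__":
    # Start both printed commands, then pass 2030 to --carla_port.
    for cmd in carla_server_commands(2030):
        print(cmd)
```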

DreamerV2 and APV Baselines

We include unified implementations for our baseline methods DreamerV2 and APV, which can also be easily run by the following commands.

DreamerV2

python examples/train_dreamerv2.py --logdir {save path} --configs metaworld --task metaworld_{task, e.g. drawer_open} --seed 0

APV

python examples/train_apv_pretraining.py --logdir {save path} --configs something_pretrain plainresnet --video_list train_video_folder --steps 1200000 --save_all_models True --video_dir {path to extracted video frames}
python examples/train_apv_finetuning.py --logdir {save path} --configs metaworld plainresnet --task metaworld_{task, e.g. drawer_open} --seed 0 --load_logdir {path to the pre-trained models}

Tips

Mixed precision is enabled by default, which is faster but can cause numerical instabilities. It is normal to encounter infinite gradient norms, and training may be interrupted by NaN values. You can pass --precision 32 to disable mixed precision.
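To see why 16-bit arithmetic produces infinities more readily, the sketch below round-trips a few values through IEEE 754 half precision using only the standard library. It assumes the 16-bit mode here means float16 rather than bfloat16, and it is independent of this codebase; to_float16 is a helper written for this illustration.

```python
import struct


def to_float16(value):
    """Round-trip a Python float through IEEE 754 half precision.

    Returns +/-inf when the value is too large to represent in float16.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # struct refuses to pack out-of-range values; model that as overflow to inf.
        return float("inf") if value > 0 else float("-inf")


if __name__ == "__main__":
    print(to_float16(65504.0))  # largest finite float16 value, kept exactly
    print(to_float16(1.0e5))    # beyond the float16 range
    print(to_float16(1.0e-8))   # below the smallest float16 subnormal
```

Gradients or scaled losses that exceed the float16 range behave like the second case, which is one way an infinite gradient norm can appear.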

See also the tips available in the DreamerV2 repository.

Contact

If you have any questions, please contact [email protected].

Acknowledgement

We sincerely appreciate the following GitHub repositories, whose code bases we build upon:
