
Robustness MARL

Terminology

Ego agent

  • The agent we are most interested in.
  • In this scenario, it is the left-turning vehicle.
  • In this scenario, there is a single ego agent.

Social agent

  • Agents that exist around the ego agent.
  • In this scenario, they are the horizontally moving vehicles.
  • In this scenario, there are multiple social agents.

Script description

Before training ego agent

  • Note that you need an encoder network before training the ego agent (see the example commands after this list).
  • For example, before training the ego agent with social RL, you need an encoder network trained on the social RL vehicles' offline data.
  • pretext_collect_data.py: collects the offline data for encoder network training.
    • This does not take long (< 1 hour).
  • pretext_train.py: trains the encoder network.
    • On a GPU, this takes < 3 hours.
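Assuming both scripts read their settings from the config files and need no required command-line arguments (check the argparse block at the top of each script to be sure), the encoder preparation looks like this:

    # collect offline social-vehicle data for the encoder (< 1 hour)
    python pretext_collect_data.py
    # train the encoder network on that data (< 3 hours on a GPU)
    python pretext_train.py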

Model training code

  • train_social_with_RLEgo.py: trains the guide model for the social vehicle with an RL ego agent (see the sketch after this list).
    • You need a (rational) RL ego agent.
    • However, you do not need to train the RL ego agent yourself; just use the pretrained model provided in Liu's GitHub repo.
    • That is, this code is the first step.
    • With this code you train the guide model for the social vehicle, not a meta model.
  • train_social_with_RLEgo.py: trains the meta model for the social vehicle with an RL ego agent.
    • If you want to train guided meta-RL, you need the guide policy produced in the previous step.
    • If you want to train meta-RL without a guide, just run this code without a guide model.
  • train_ego_with_trained_social.py: trains the ego agent with a trained social agent.
    • You can train the ego agent against the trained social RL model.
    • You have to prepare the trained social agent first.
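Putting the steps together, a minimal sketch of the training order (the arguments that switch between the guide and meta variants are not shown here; check each script's argparse block):

    # step 1: train the guide model for the social vehicle (uses the pretrained RL ego agent)
    python train_social_with_RLEgo.py
    # step 2: train the (guided) meta model for the social vehicle
    python train_social_with_RLEgo.py
    # step 3: train the ego agent against the trained social agent
    python train_ego_with_trained_social.py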

Model testing code

  • ego_social_test.py: tests the ego agent with a social agent. The following can be selected:
    • IDM
    • meta RL
    • guided meta RL
    • guiding RL

How to use the test scripts?

  • Check the arguments in each test_*.py script.
  • --model_dir indicates the folder that contains the configs and checkpoints folders.
  • --visualize: if this value is True, the simulation is rendered.
  • --test_model indicates the checkpoint of the model (you have to change this value manually).
    parser.add_argument('--model_dir', type=str, default='data/new_rl')
    parser.add_argument('--visualize', default=True, action='store_true')
    parser.add_argument('--test_model', type=str, default='00020.pt')
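For example, to test a checkpoint with rendering enabled (the values below match the defaults shown above; replace them with your own run's folder and checkpoint):

    python ego_social_test.py --model_dir data/new_rl --test_model 00020.pt --visualize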



VAE Trait Inference

This repository contains the code for our paper "Learning to Navigate Intersections with Unsupervised Driver Trait Inference", published at ICRA 2022. For more details, please refer to the project website and arXiv preprint. For experiment demonstrations, please refer to the YouTube video.

Abstract

Navigation through uncontrolled intersections is one of the key challenges for autonomous vehicles. Identifying the subtle differences in hidden traits of other drivers can bring significant benefits when navigating in such environments. We propose an unsupervised method for inferring driver traits such as driving styles from observed vehicle trajectories. We use a variational autoencoder with recurrent neural networks to learn a latent representation of traits without any ground truth trait labels. Then, we use this trait representation to learn a policy for an autonomous vehicle to navigate through a T-intersection with deep reinforcement learning. Our pipeline enables the autonomous vehicle to adjust its actions when dealing with drivers of different traits to ensure safety and efficiency. Our method demonstrates promising performance and outperforms state-of-the-art baselines in the T-intersection scenario.

Setup

  1. Install Python 3.6.

  2. Install the required Python packages using pip or conda. For pip, use the following command:

    pip install -r requirements.txt
    

    For conda, please install each package in requirements.txt into your conda environment manually and follow the instructions on the anaconda website.

  3. Install OpenAI Baselines.

    git clone https://github.com/openai/baselines.git
    cd baselines
    pip install -e .
    

Getting started

This repository is organized into five parts:

  • configs/ folder contains configurations for training and neural networks.
  • driving_sim/ folder contains the simulation environment and the wrapper for inferring the traits during RL training (in driving_sim/vec_env/).
  • pretext/ folder contains the code for the VAE trait inference task, including the networks, collecting and loading trajectory data, as well as the loss functions for VAE training.
  • rl/ contains the code for the RL policy networks and the PPO algorithm.
  • trained_models/ contains some pretrained models provided by us.

Below are the instructions for training and testing.

Run the code

Trait inference (pretext task)

  1. Data collection
  • In configs/config.py, modify the amount of data to collect, the saving directory, and the trajectory length (lines 76-79).
  • Then run
    python collect_data.py 
    

Alternatively, we provide a downloadable dataset here.
  2. Training

  • Modify the pretext configs in configs/config.py (a sketch of these entries follows the run command below). In particular:
    • Set pretext.data_load_dir to the directory of the dataset obtained from Step 1.
    • If our method is used, set pretext.cvae_decoder = 'lstm'; if the baseline by Morton and Kochenderfer is used, set pretext.cvae_decoder = 'mlp'.
    • Set pretext.model_save_dir to a new folder that you want to save the model in.
  • Then run
    python train_pretext.py 
    
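For reference, a hypothetical excerpt of the entries described above in configs/config.py (the actual attribute layout may differ, and the paths are placeholders):

    # pretext (trait inference) settings -- illustrative values only
    pretext.data_load_dir = 'data/pretext_dataset'            # dataset from Step 1 (placeholder path)
    pretext.cvae_decoder = 'lstm'                             # 'lstm' for ours, 'mlp' for the Morton baseline
    pretext.model_save_dir = 'trained_models/pretext/my_run'  # new folder for checkpoints (placeholder path)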
  3. Testing
    Modify the test arguments at the beginning of test_pretext.py, and run
    python test_pretext.py 
    

This script will generate a visualization of the learned representation and a testing log in the folder of the tested model.

We provide trained example weights, one for each method:
- Ours: trained_models/pretext/public_ours/checkpoints/995.pt
- Baseline: trained_models/pretext/public_morton/checkpoints/995.pt

Navigation policy learning using RL

  1. Training.
  • Modify the training and PPO configs in configs/config.py (a sketch of these entries follows the run command below). In particular:
    • Set training.output_dir to a new folder that you want to save the model in.
    • Set training.pretext_model_path to the path of the trait inference model that you wish to use in RL training.
    • If our method is used, set pretext.cvae_decoder = 'lstm'; if the baseline by Morton and Kochenderfer is used, set pretext.cvae_decoder = 'mlp'.
  • Modify the environment configs in configs/driving_config.py. In particular:
    • If our method is used, set env.env_name = 'TIntersectionPredictFront-v0'. If the baseline by Morton and Kochenderfer is used, set env.env_name = 'TIntersectionPredictFrontAct-v0'.
    • Set env.con_prob to the proportion of conservative cars in the environment (note: env.con_prob is NOT equal to P(conservative) in the paper; see the comments in configs/driving_config.py for reference).
  • Then, run
    python train_rl.py 
    
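For reference, a hypothetical excerpt of the settings described above (the actual attribute layout in the config files may differ, and the paths and numbers are placeholders):

    # configs/config.py -- illustrative values only
    training.output_dir = 'trained_models/rl/my_rl_run'
    training.pretext_model_path = 'trained_models/pretext/public_ours/checkpoints/995.pt'
    pretext.cvae_decoder = 'lstm'                   # 'lstm' for ours, 'mlp' for the Morton baseline

    # configs/driving_config.py -- illustrative values only
    env.env_name = 'TIntersectionPredictFront-v0'   # 'TIntersectionPredictFrontAct-v0' for the baseline
    env.con_prob = 0.5                              # NOT P(conservative); see comments in driving_config.py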
  2. Testing
    Please modify the test arguments at the beginning of test_rl.py, and run
    python test_rl.py 
    
    The testing results are logged in the same folder as the checkpoint model.
    If the "visualize" argument is True in test_rl.py, you can visualize the ego car's policy in different episodes.

    We provide trained example weights for each method when P(conservative) = 0.4:
  • Ours: trained_models/rl/con40/public_ours_rl/checkpoints/25200.pt
  • Baseline: trained_models/rl/con40/public_morton_rl/checkpoints/26800.pt
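For example, to evaluate the provided checkpoint of our method, the test arguments at the beginning of test_rl.py would be set along these lines (the argument names follow the pattern of the other test scripts in this repo and are illustrative; use whatever test_rl.py actually defines):

    parser.add_argument('--model_dir', type=str, default='trained_models/rl/con40/public_ours_rl')
    parser.add_argument('--test_model', type=str, default='25200.pt')
    parser.add_argument('--visualize', default=True, action='store_true')

and then run python test_rl.py as above.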

Author notes

  1. We only tested our code in Ubuntu 16.04 with Python 3.6. The code may work with other versions of Python, but we do not have any guarantee.

  2. The performance of our code can vary depending on the choice of hyperparameters and random seeds (see this reddit post). Unfortunately, we do not have time or resources for a thorough hyperparameter search. To achieve the best performance, we recommend some manual hyperparameter tuning.

Learning curves

Optionally, you can plot the training curves by running the following:

  • for the VAE pretext task
    python plot_pretext.py
    
  • for the RL policy learning
    python plot_rl.py
    

Citation

If you find the code or the paper useful for your research, please cite our paper:

@inproceedings{liu2021learning,
  title={Learning to Navigate Intersections with Unsupervised Driver Trait Inference},
  author={Liu, Shuijing and Chang, Peixin and Chen, Haonan and Chakraborty, Neeloy and Driggs-Campbell, Katherine},
  booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
  year={2022}
}

Credits

Other contributors:
Xiaobai Ma (developed the T-intersection gym environment)
Neeloy Chakraborty

Part of the code is based on the following repositories:

[1] S. Liu, P. Chang, W. Liang, N. Chakraborty, and K. Driggs-Campbell, "Decentralized Structural-RNN for Robot Crowd Navigation with Deep Reinforcement Learning," in IEEE International Conference on Robotics and Automation (ICRA), 2021, pp. 3517-3524. (GitHub: https://github.com/Shuijing725/CrowdNav_DSRNN)

[2] I. Kostrikov, “Pytorch implementations of reinforcement learning algorithms,” https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail, 2018.

Contact

If you have any questions or find any bugs, please feel free to open an issue or pull request.
