Ego agent
- The agent we are most interested in.
- In this scenario, it is a left-turn vehicle, and there is one ego agent.
Social agents
- Agents that exist around the ego agent.
- In this scenario, they are horizontally moving vehicles, and there are multiple social agents.
Before training the ego agent
- Note that you need an 'encoder' network before training the ego agent.
- For example, before training the ego agent with social RL, you need an encoder network trained on the social RL vehicle's offline data.
pretext_collect_data.py
: collects offline data for encoder network training.
- It doesn't take long (< 1 hour).
pretext_train.py
: trains the encoder network.
- On a GPU, it takes less than 3 hours.
Model training code
train_social_with_RLEgo.py
: trains the guide model for the social vehicle with an RL ego agent.
- You need a (rational) RL ego agent.
- However, you don't need to train the RL ego agent yourself; just use the pretrained model provided in Liu's GitHub repo.
- That is, this code is the first step.
- You can train the guide model for the social vehicle with this code.
- Note that this trains the guide model, not a meta model, for the social vehicle.
train_social_with_RLEgo.py
: trains the meta model for the social vehicle with an RL ego agent.
- If you want to train guided meta-RL, you need the guide policy produced in the previous step.
- If you want to train meta-RL without a guide, just run this code without a guide model.
train_ego_with_trained_social.py
: trains the ego agent with a trained social agent.
- You can train the ego agent with the trained social RL model.
- You have to prepare the trained social agent in advance.
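Taken together, the notes above imply an ordering among the scripts. The sketch below encodes that ordering as a dependency dict; this dict is our reading of the notes, not something defined in the repo, and the guide-before-meta edge only applies when training *guided* meta-RL (plain meta-RL skips it).

```python
# Pipeline order implied by the notes above (illustrative only).
steps = {
    'pretext_collect_data.py': [],
    'pretext_train.py': ['pretext_collect_data.py'],
    'train_social_with_RLEgo.py (guide)': ['pretext_train.py'],
    'train_social_with_RLEgo.py (meta)': ['train_social_with_RLEgo.py (guide)'],
    'train_ego_with_trained_social.py': ['train_social_with_RLEgo.py (meta)'],
    'ego_social_test.py': ['train_ego_with_trained_social.py'],
}

def run_order(steps):
    # Topological sort: repeatedly pick steps whose prerequisites are done.
    done, order = set(), []
    while len(done) < len(steps):
        for s, deps in steps.items():
            if s not in done and all(d in done for d in deps):
                done.add(s)
                order.append(s)
    return order

print(run_order(steps))
```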
Model testing code
ego_social_test.py
: tests the ego agent with a social agent. The social agent can be selected from:
- IDM
- meta RL
- guided meta RL
- guiding RL
How to use the test scripts?
- You have to check the arguments in the `test_*.py` script:
  - `--model_dir` indicates the folder which contains the `configs` and `checkpoints` folders.
  - `--visualize`: if this value is set to `True`, the simulation is rendered.
  - `--test_model` indicates the checkpoint of the model. (You have to change this value manually.)
parser.add_argument('--model_dir', type=str, default='data/new_rl')    # folder with configs/ and checkpoints/
parser.add_argument('--visualize', default=True, action='store_true')  # whether to render the simulation
parser.add_argument('--test_model', type=str, default='00020.pt')      # checkpoint file to load
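One pitfall worth noting in the snippet above: because `--visualize` combines `action='store_true'` with `default=True`, passing the flag changes nothing, and there is no flag to turn rendering off; you must edit the default in the script. A minimal check:

```python
import argparse

# Reconstruction of the test-script arguments quoted above.
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default='data/new_rl')
parser.add_argument('--visualize', default=True, action='store_true')
parser.add_argument('--test_model', type=str, default='00020.pt')

# With default=True, visualize is True whether or not the flag is passed.
print(parser.parse_args([]).visualize)               # True
print(parser.parse_args(['--visualize']).visualize)  # True
```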
This repository contains the code for our paper titled "Learning to Navigate Intersections with Unsupervised Driver Trait Inference", published at ICRA 2022. For more details, please refer to the project website and arXiv preprint. For experiment demonstrations, please refer to the YouTube video.
Navigation through uncontrolled intersections is one of the key challenges for autonomous vehicles. Identifying the subtle differences in hidden traits of other drivers can bring significant benefits when navigating in such environments. We propose an unsupervised method for inferring driver traits such as driving styles from observed vehicle trajectories. We use a variational autoencoder with recurrent neural networks to learn a latent representation of traits without any ground truth trait labels. Then, we use this trait representation to learn a policy for an autonomous vehicle to navigate through a T-intersection with deep reinforcement learning. Our pipeline enables the autonomous vehicle to adjust its actions when dealing with drivers of different traits to ensure safety and efficiency. Our method demonstrates promising performance and outperforms state-of-the-art baselines in the T-intersection scenario.
1. Install Python 3.6.
2. Install the required python packages using pip or conda. For pip, use the following command:
pip install -r requirements.txt
For conda, please install each package in requirements.txt into your conda environment manually, following the instructions on the Anaconda website.
3. Install OpenAI Baselines:
git clone https://github.com/openai/baselines.git
cd baselines
pip install -e .
This repository is organized in five parts:
- `configs/` contains configurations for training and neural networks.
- `driving_sim/` contains the simulation environment and the wrapper for inferring the traits during RL training (in `driving_sim/vec_env/`).
- `pretext/` contains the code for the VAE trait inference task, including the networks, collecting and loading trajectory data, as well as loss functions for VAE training.
- `rl/` contains the code for the RL policy networks and the PPO algorithm.
- `trained_models/` contains some pretrained models provided by us.
Below are the instructions for training and testing.
1. Data collection
- In `configs/config.py`, modify the number of data points to collect, the saving directory, and the trajectory length (lines 76-79).
- Then run
python collect_data.py
Alternatively, we provide a downloadable dataset here.
2. Training
- Modify the pretext configs in `configs/config.py`. Especially:
  - Set `pretext.data_load_dir` to the directory of the dataset obtained from Step 1.
  - If our method is used, set `pretext.cvae_decoder = 'lstm'`; if the baseline by Morton and Kochenderfer is used, set `pretext.cvae_decoder = 'mlp'`.
  - Set `pretext.model_save_dir` to a new folder where you want to save the model.
- Then run
python train_pretext.py
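For reference, the pretext settings above might look like the following in `configs/config.py`. This is a hypothetical sketch: the field names come from the text, but the values and the class layout are illustrative, not the repo's actual defaults.

```python
# Illustrative values only; point these at your own paths.
class pretext:
    data_load_dir = 'data/pretext_dataset'             # dataset from Step 1
    cvae_decoder = 'lstm'                              # 'lstm' = ours, 'mlp' = Morton & Kochenderfer baseline
    model_save_dir = 'trained_models/pretext/my_run'   # new folder for checkpoints

print(pretext.cvae_decoder)
```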
3. Testing
- Modify the test arguments at the beginning of `test_pretext.py`, and run
python test_pretext.py
- This script will generate a visualization of the learned representation and a testing log in the folder of the tested model.
We provide two trained example weights for each method:
- Ours: `trained_models/pretext/public_ours/checkpoints/995.pt`
- Baseline: `trained_models/pretext/public_morton/checkpoints/995.pt`
1. Training
- Modify the training and PPO configs in `configs/config.py`. Especially:
  - Set `training.output_dir` to a new folder where you want to save the model.
  - Set `training.pretext_model_path` to the path of the trait inference model that you wish to use in RL training.
  - If our method is used, set `pretext.cvae_decoder = 'lstm'`; if the baseline by Morton and Kochenderfer is used, set `pretext.cvae_decoder = 'mlp'`.
- Modify the environment configs in `configs/driving_config.py`. Especially:
  - If our method is used, set `env.env_name = 'TIntersectionPredictFront-v0'`; if the baseline by Morton and Kochenderfer is used, set `env.env_name = 'TIntersectionPredictFrontAct-v0'`.
  - Set `env.con_prob` to the portion of conservative cars in the environment (note: `env.con_prob` is NOT equal to P(conservative) in the paper; please check the comments in `configs/driving_config.py` for reference).
- Then run
python train_rl.py
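The RL-side settings above might look like the following. Again a hypothetical sketch: the field and environment names come from the text, while the paths, the `con_prob` value, and the class layout are placeholders, not the repo's actual configuration.

```python
# Illustrative values only; the class layout mirrors how the configs are
# referenced above, not the repo's actual file structure.
class training:
    output_dir = 'trained_models/rl/my_rl_run'   # new folder for RL checkpoints
    pretext_model_path = 'trained_models/pretext/public_ours/checkpoints/995.pt'

class env:
    env_name = 'TIntersectionPredictFront-v0'    # ours; 'TIntersectionPredictFrontAct-v0' for the baseline
    con_prob = 0.7                               # NOT P(conservative); see configs/driving_config.py

print(env.env_name)
```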
2. Testing
- Modify the test arguments at the beginning of `test_rl.py`, and run
python test_rl.py
- The testing results are logged in the same folder as the checkpoint model.
- If the `visualize` argument is `True` in `test_rl.py`, you can visualize the ego car's policy in different episodes.
We provide trained example weights for each method when P(conservative) = 0.4:
- Ours: `trained_models/rl/con40/public_ours_rl/checkpoints/25200.pt`
- Baseline: `trained_models/rl/con40/public_morton_rl/checkpoints/26800.pt`
- We only tested our code on Ubuntu 16.04 with Python 3.6. The code may work with other versions of Python, but we cannot guarantee it.
- The performance of our code can vary depending on the choice of hyperparameters and random seeds (see this Reddit post). Unfortunately, we do not have the time or resources for a thorough hyperparameter search. To achieve the best performance, we recommend some manual hyperparameter tuning.
Optionally, you can plot the training curves by running the following:
- `python plot_pretext.py` for the VAE pretext task
- `python plot_rl.py` for the RL policy learning
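If you want a quick look at a curve without the plot scripts, smoothing the logged episode returns with a trailing moving average is usually enough. A minimal sketch, where the reward list is dummy data (the real values would come from whatever log format the repo's scripts write):

```python
def moving_average(xs, k=10):
    """Smooth a reward curve with a trailing window of up to k values."""
    out = []
    for i in range(len(xs)):
        window = xs[max(0, i - k + 1): i + 1]
        out.append(sum(window) / len(window))
    return out

rewards = [1.0, 2.0, 6.0, 4.0]  # dummy episode returns
print(moving_average(rewards, k=2))  # [1.0, 1.5, 4.0, 5.0]
```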
If you find the code or the paper useful for your research, please cite our paper:
@inproceedings{liu2021learning,
title={Learning to Navigate Intersections with Unsupervised Driver Trait Inference},
author={Liu, Shuijing and Chang, Peixin and Chen, Haonan and Chakraborty, Neeloy and Driggs-Campbell, Katherine},
booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
year={2022}
}
Other contributors:
Xiaobai Ma (developed the T-intersection gym environment)
Neeloy Chakraborty
Part of the code is based on the following repositories:
[1] S. Liu, P. Chang, W. Liang, N. Chakraborty, and K. Driggs-Campbell, "Decentralized Structural-RNN for Robot Crowd Navigation with Deep Reinforcement Learning," in IEEE International Conference on Robotics and Automation (ICRA), 2021, pp. 3517-3524. (GitHub: https://github.com/Shuijing725/CrowdNav_DSRNN)
[2] I. Kostrikov, “Pytorch implementations of reinforcement learning algorithms,” https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail, 2018.
If you have any questions or find any bugs, please feel free to open an issue or pull request.