
mantra-cvpr20's Introduction

MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction

Official PyTorch code for "MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction" (CVPR 2020) by Francesco Marchetti, Federico Becattini, Lorenzo Seidenari, and Alberto Del Bimbo.

Multiple trajectory prediction. Blue: past, red: futures.

Installation

To install the required packages, run the following in a Python 3.6 environment:

pip install -r requirements.txt

Dataset

We provide a dataloader for the KITTI dataset in dataset_invariance.py. The dataloader yields samples of (past, future) trajectories paired with a semantic map of the surrounding scene.
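
As a quick orientation, iterating over the data might look like the sketch below (the dataset class name, its constructor arguments, and the sample layout are assumptions; check dataset_invariance.py for the actual signature):

    # Hypothetical usage sketch: TrackDataset and its arguments are assumptions,
    # as is the (past, future, scene) sample layout; see dataset_invariance.py.
    from torch.utils.data import DataLoader
    import dataset_invariance

    train_dataset = dataset_invariance.TrackDataset(train=True)           # assumed constructor
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    for past, future, scene in train_loader:                              # assumed sample layout
        print(past.shape, future.shape, scene.shape)
        break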

Training

To train MANTRA, first train the autoencoder, then the writing controller, and finally the Iterative Refinement Module (IRM). Training can be monitored with TensorBoard; logs are stored under the runs/ folder (runs-pretrain, runs-createMem, runs-IRM). The pretrained_models folder contains pretrained checkpoints of the different components (autoencoder, writing controller, full MANTRA).
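
For example, assuming the log subfolders sit under runs/, the autoencoder pretraining curves can be followed with:

    tensorboard --logdir runs/runs-pretrain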

Training encoder-decoder model (autoencoder)

python train_ae.py

The autoencoder can be trained with the train_ae.py script, which calls trainer_ae.py. The model is saved in the folder test/[current_date]. A pretrained model can be found in pretrained_models/model_AE/.

Training writing controller

python train_controllerMem.py --model pretrained_autoencoder_model_path

The memory writing controller can be trained on top of the autoencoder with train_controllerMem.py, which calls trainer_controllerMem.py. The path of a pretrained autoencoder model has to be passed to the script (it defaults to the pretrained model we provide). A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/.

Training Iterative Refinement Module (IRM)

python train_IRM.py --model pretrained_autoencoder+controller_model_path

train_IRM.py calls trainer_IRM.py. The script trains the IRM module, which generates the final prediction from the decoded trajectory and the context map. The paths of a pretrained autoencoder-plus-writing-controller model and of populated memories have to be passed to the script (they default to the pretrained models we provide). A pretrained MANTRA model can be found in pretrained_models/model_complete/.

Test

python test.py --model pretrained_complete_model_path --withIRM True/False --saved_memory True/False

test.py calls evaluate_MemNet.py. This script computes metrics on the KITTI dataset using a trained model. We report Average Displacement Error (ADE) and Final Displacement Error (FDE, also referred to as Error@K or Horizon Error).
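
For reference, best-of-K ADE and FDE can be computed roughly as in the sketch below (a minimal NumPy version, not the repository's evaluation code; the array shapes are assumptions):

    # Minimal best-of-K ADE/FDE sketch. Assumed shapes: preds [K, T, 2], gt [T, 2].
    import numpy as np

    def ade_fde_best_of_k(preds, gt):
        errors = np.linalg.norm(preds - gt[None], axis=-1)  # [K, T] pointwise L2 distances
        ade = errors.mean(axis=1).min()                      # best average displacement over the K predictions
        fde = errors[:, -1].min()                            # best final displacement over the K predictions
        return ade, fde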

Command line arguments

    --cuda                         Enable/Disable GPU device (default=True).
    --batch_size                   Number of samples that will be fed to MANTRA in one iteration (default=32).
    --past_len                     Past length (default=20).
    --future_len                   Future length (default=40).
    --preds                        Number of predictions generated by the MANTRA model (default=5).
    --model                        Path of the pretrained model used for evaluation (default='pretrained_models/MANTRA/model_MANTRA').
    --visualize_dataset            If set, the system saves all dataset examples (in *folder_test/dataset_train* and
                                   *folder_test/dataset_test*).
    --saved_memory                 Chooses which memories are used in evaluation.
                                   If True, memories are loaded from the 'memories_path' folder.
                                   If False, new memories are generated: the past-future pairs to store are decided by the
                                   model's writing controller.
    --memories_path                Path to saved memories. Used only if the saved_memory flag is True.
    --withIRM                      Whether the model generates predictions with or without the Iterative Refinement Module.
    --saveImages                   Saves, in the test folder, dataset examples with the predictions generated by MANTRA.
                                   If None, no qualitative examples are saved, only quantitative results.
                                   If 'All', all examples are saved.
                                   If 'Subset', only the examples defined in index_qualitative.py are saved (hand-picked,
                                   most significant samples).
                                   (default=None)
    --dataset_file                 Name of the json file containing the dataset (default='kitti_dataset.json').
    --info                         Name of the evaluation. It is used as the name of the test folder (default='').
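
A full evaluation call with these flags could look like the line below (the checkpoint path is illustrative; substitute your own):

    python test.py --model pretrained_models/model_complete/model_MANTRA --withIRM True --saved_memory False --preds 5 --info kitti_eval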

Citation

If you use our code or find it useful in your research, please cite the following paper:

@inproceedings{cvpr_2020,
 author = {Marchetti, Francesco and  Becattini, Federico and Seidenari, Lorenzo and Del Bimbo, Alberto},
 booktitle = {International Conference on Computer Vision and Pattern Recognition (CVPR)},
 publisher = {IEEE},
 title = {MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction},
 year = {2020}
}
@ARTICLE{Geiger2013IJRR,
  author = {Andreas Geiger and Philip Lenz and Christoph Stiller and Raquel Urtasun},
  title = {Vision meets Robotics: The KITTI Dataset},
  journal = {International Journal of Robotics Research (IJRR)},
  year = {2013}
}

License


This source code is shared under the CC-BY-NC-SA license; please refer to the LICENSE file for more information.

This source code is shared only for R&D or for evaluation of this model on the user's database.

Any commercial utilization is strictly forbidden.

For any utilization with a commercial goal, please contact contact_cs or bendahan

mantra-cvpr20's People

Contributors

fedebecat, lseidenari, marchetz, seide


mantra-cvpr20's Issues

Error in running test module

Hi,
I get the following error when running the test.py module:

RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR

I use Python 3.6 with all requirements in requirements.txt.

Can't run

This bug occurs when training the controller: AttributeError: 'GRU' object has no attribute '_flat_weights_names'. How can I solve it?

python version

Hi
Thank you very much for sharing your work.
Which Python version is used?

How to extract object trajectories from KITTI dataset?

I am trying to project a point cloud into a global coordinate system and follow the (x, y) coordinates in that global frame as the trajectory of the object; however, I am getting a different trajectory than the one in the provided data.
May I know how to extract object trajectories from the KITTI dataset?
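
One plausible way to place per-frame object positions in a single global frame, offered here only as a sketch and not necessarily the authors' procedure, uses the OXTS poses of the raw KITTI recordings (assumes pykitti is installed; the object center below is a made-up example):

    # Sketch: express an object center in the frame of the first IMU pose (all values illustrative).
    import numpy as np
    import pykitti

    data = pykitti.raw("/path/to/kitti_raw", "2011_09_26", "0001")   # illustrative drive
    T_w0_inv = np.linalg.inv(data.oxts[0].T_w_imu)                   # take the first IMU pose as the world origin
    poses = [T_w0_inv @ o.T_w_imu for o in data.oxts]                # per-frame IMU-to-global transforms

    t = 0
    center_imu = np.array([10.0, 0.5, 0.0, 1.0])   # hypothetical object center in the IMU frame at frame t
    x, y = (poses[t] @ center_imu)[:2]             # one (x, y) sample of the object's global trajectory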

Get different result

Hi
Thanks for sharing your implementation. I have tried to reproduce the results reported in Table 1 (KITTI dataset). I used your pretrained models but got different results. I also trained three models myself and got worse results.
What must I do to get the same results?
The published results, the results of the pretrained models, and my own results are in the following table (KITTI dataset, Table 1 in the paper).


controller loss

Hi,

Thank you very much for sharing your work. I came across the implementation of controller loss below:

loss = prob * sim + (1 - prob) * (1 - sim)

It seems to contradict equation (2) in your paper: L = e(1 - P(w)) + (1 - e)P(w).

Could you please help me understand this? I guess I am missing something here.

Thank you very much,
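
A possible reconciliation, stated here only as an assumption: if the code's sim plays the role of 1 - e (high similarity means low reconstruction error) and prob corresponds to P(w), the two expressions are algebraically the same:

    prob * sim + (1 - prob) * (1 - sim)
      = P(w) * (1 - e) + (1 - P(w)) * e
      = e * (1 - P(w)) + (1 - e) * P(w)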

GRU error

Hi, I'm trying to run your code following your guide. First I ran train_ae.py and the script ran successfully. Then I ran the train_controllerMem.py script both with and without CUDA. With CUDA I got an error at this line

self.mem_n2n = self.mem_n2n.cuda()

with the message AttributeError: 'GRU' object has no attribute '_flat_weights_names', and when I ran the script without CUDA, I got an error at this line

output_past, state_past = self.encoder_past(story_embed)

but it shows a different message: AttributeError: 'GRU' object has no attribute '_flat_weights'.

I installed nvidia-driver-470, CUDA 11.4, torch+cu111, and libcudnn8_8.2.4.15, which is compatible with this CUDA version.

Can you guide me to solve this problem please?
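
A common cause of this error, offered here only as an assumption, is that the pretrained checkpoints were saved as whole pickled modules, so unpickling them under a different PyTorch version leaves the GRU with internal attribute names the installed version does not expect. One generic workaround is to rebuild the model with the installed PyTorch and copy only the weights, roughly as below (the checkpoint path, class name, and settings are placeholders):

    # Hypothetical workaround sketch: rebuild the model, then load only the tensors.
    import torch
    from models.model_encdec import model_encdec   # assumed class name inside models/model_encdec.py

    old_model = torch.load("pretrained_models/model_AE/<checkpoint>", map_location="cpu")
    new_model = model_encdec(settings)              # placeholder: same settings used for training
    new_model.load_state_dict(old_model.state_dict())
    torch.save(new_model, "model_ae_rebuilt")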

Question about evaluate input

It seems that even in the evaluation phase, the future information is passed to the network.

file: trainer/trainer_ae.py
line: 213
phase: evaluate
pred = self.mem_n2n(past, future).data

file: models/model_encdec.py
line: 94
state_conc = torch.cat((state_past, state_fut), 2)
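
For context, the quoted lines are consistent with a plain past-and-future encoder-decoder, which takes the future as input even at evaluation time because it reconstructs the future from its own encoding; a rough sketch of that step (every layer name other than the quoted ones is an assumption):

    # Rough sketch of the autoencoder forward implied by the quoted lines (names mostly assumed).
    def forward(self, past, future):
        _, state_past = self.encoder_past(self.embed_past(past))   # encoder_past as in the quoted trainer line
        _, state_fut = self.encoder_fut(self.embed_fut(future))
        state_conc = torch.cat((state_past, state_fut), 2)          # the quoted line from model_encdec.py
        return self.decoder(state_conc)                             # the future is rebuilt from its own encoding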

train IRM

Training the IRM module keeps raising an error. Asking for help, thank you!
