Comments (14)
Hi @nnsriram97,
The main idea from the paper is to overcome the mode collapse that usually occurs when you train end-to-end. In our work, we first train the sampling network to generate diverse hypotheses, and then train the fitting network on top of the sampling network. Note that we keep the sampling network fixed while training the fitting network. In other words, the fitting network has to find the parameters of a GMM with different modes to fit the diverse set of hypotheses. Only at the end do we train the two networks jointly.
So to summarize, we have three stages:
- train the sampling network with EWTA.
- train the fitting network with NLL while keeping the weights of the sampling network fixed.
- train jointly both of them with NLL.
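The three stages above can be sketched in PyTorch roughly as follows. This is a toy illustration, not the repository's code: `sampler`, `fitter`, `wta_loss`, and `nll_loss` are hypothetical names, the networks are stand-in linear layers, and the NLL is replaced by a simplified placeholder.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the two networks (hypothetical; the real ones are CNNs).
sampler = nn.Linear(8, 20 * 2)       # predicts K=20 two-dimensional hypotheses
fitter  = nn.Linear(20 * 2, 4 * 5)   # fits a 4-mode GMM on top of the hypotheses

def wta_loss(hyps, gt, k):
    """(E)WTA objective: penalize only the k hypotheses closest to the ground truth."""
    d = ((hyps - gt.unsqueeze(1)) ** 2).sum(-1)          # (B, K) squared distances
    return d.topk(k, largest=False).values.mean()

def nll_loss(params, gt):
    """Placeholder for the GMM negative log-likelihood (heavily simplified here)."""
    return ((params[:, :2] - gt) ** 2).sum(-1).mean()

x, gt = torch.randn(16, 8), torch.randn(16, 2)

# Stage 1: train the sampling network alone with the (E)WTA objective.
opt = torch.optim.Adam(sampler.parameters(), lr=1e-3)
for _ in range(5):
    hyps = sampler(x).view(-1, 20, 2)
    loss = wta_loss(hyps, gt, k=20)
    opt.zero_grad(); loss.backward(); opt.step()

# Stage 2: freeze the sampler, train the fitting network with NLL.
for p in sampler.parameters():
    p.requires_grad_(False)
opt = torch.optim.Adam(fitter.parameters(), lr=1e-3)
for _ in range(5):
    hyps = sampler(x).view(-1, 20, 2)
    loss = nll_loss(fitter(hyps.flatten(1)), gt)
    opt.zero_grad(); loss.backward(); opt.step()

# Stage 3: unfreeze and train both networks jointly with NLL.
for p in sampler.parameters():
    p.requires_grad_(True)
opt = torch.optim.Adam(list(sampler.parameters()) + list(fitter.parameters()), lr=1e-3)
for _ in range(5):
    hyps = sampler(x).view(-1, 20, 2)
    loss = nll_loss(fitter(hyps.flatten(1)), gt)
    opt.zero_grad(); loss.backward(); opt.step()
```

The key mechanic is the `requires_grad_(False)` freeze in stage 2, which forces the fitting network to spread its GMM modes over whatever diverse hypotheses the fixed sampler produces.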
from multimodal-future-prediction.
You can check Figure 3 in the supplementary material of our arXiv paper:
https://arxiv.org/pdf/1906.03631.pdf
All stages of the EWTA are trained for an equal number of iterations. This means that if we train for 150k iterations using EWTA and we have 5 stages, then every stage is trained for 30k iterations.
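The equal-length schedule described above can be computed as in the sketch below. Assumptions worth flagging: `ewta_schedule` is a hypothetical helper (not from the repo), and it assumes the number of penalized hypotheses k halves each stage down to 1, so K=16 hypotheses yields 5 stages.

```python
import math

def ewta_schedule(total_iters, num_hyps):
    """Equal-length EWTA stages; k halves each stage: K, K/2, ..., 1.
    A sketch based on the description above, not the repository's exact code."""
    num_stages = int(math.log2(num_hyps)) + 1   # e.g. K=16 -> 5 stages
    per_stage = total_iters // num_stages
    schedule = []
    k = num_hyps
    for _ in range(num_stages):
        schedule.append((per_stage, k))
        k = max(1, k // 2)
    return schedule

# 150k iterations with K=16 hypotheses -> 5 stages of 30k iterations each.
print(ewta_schedule(150_000, 16))
# -> [(30000, 16), (30000, 8), (30000, 4), (30000, 2), (30000, 1)]
```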
I cannot give exact dates for the torch implementation but we hope to do it within 5-7 weeks.
Best,
Hi,
Thanks for your interest in our work.
We are planning to release an updated version (using PyTorch), which will make it easier for people to train and test the framework.
Meanwhile, feel free to raise more questions if you want to implement it yourself.
I will keep this issue open until we have the updated version.
Best,
Thanks, looking forward to your PyTorch implementation.
May I know how you avoided mode collapse in the fitting network? All the soft assignments end up assigning the predicted hypotheses to a single mode of the GMM. I ran into this while implementing it myself; any insights on training the fitting network to avoid this issue?
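For concreteness, the soft assignments in question correspond to the mixture responsibilities of a GMM trained with NLL. Below is a generic, numerically stable diagonal-GMM negative log-likelihood using logsumexp over modes; this is a standard formulation for illustration, not the repository's exact loss, and all names are hypothetical.

```python
import math
import torch

def gmm_nll(mu, log_sigma, logit_pi, y):
    """Negative log-likelihood of points y under a diagonal GMM.
    mu, log_sigma: (B, M, D); logit_pi: (B, M); y: (B, D).
    Generic sketch, not the repo's exact loss."""
    log_pi = torch.log_softmax(logit_pi, dim=-1)          # mixture weights
    var = (2.0 * log_sigma).exp()
    # Per-mode diagonal Gaussian log-density, summed over dimensions.
    log_comp = -0.5 * (((y.unsqueeze(1) - mu) ** 2) / var
                       + 2.0 * log_sigma
                       + math.log(2.0 * math.pi)).sum(-1)  # (B, M)
    # logsumexp over modes realizes the soft assignment.
    return -torch.logsumexp(log_pi + log_comp, dim=-1).mean()
```

One way to diagnose collapse is to inspect the responsibilities `torch.softmax(log_pi + log_comp, dim=-1)`: collapse shows up when nearly all mass sits in one column for every input.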
Thanks for the great work! How long did each stage take to train (wall-clock time or steps)?
Also, can you let us know when the Torch implementation (training, testing) will be available? Really looking forward to it :)
I'm currently trying to train the model. If I have more questions I will leave them here. Thanks for the great work!
@os1a How is the progress going with the pytorch implementation?
Hi @rafalk342
Thanks for your interest in our code,
Unfortunately, we are still waiting for approval from our business partner, but we are working on it and hopefully we can publish it before the end of the year.
I will post here when we have it ready.
Is there any progress on the training script release? Or has somebody replicated the training process successfully?
Hi @Shaluols,
Thanks for your interest in our codebase.
Unfortunately, publishing the source code for training will take longer than expected. The main problem is the transfer between frameworks: the original code was trained in TensorFlow with two private libraries used for data reading and augmentation.
I will update you as soon as we have it ready.
Sorry for the delay,
Best,
@os1a Thanks for the update~
Hi all!
I currently work on a research project related to this work. We have adapted parts of the code from this repository to generate the CPI dataset files in a format identical to the SDD data, so that one can plug and play with it.
https://github.com/maciejzieba/regressionFlow/tree/master/cpi_generation
Thanks @mprzewie for sharing the code,
Your work is pretty good and I liked it.
Best,
@os1a Any update on releasing the training code? Thanks!