
predrnn-pytorch's Introduction

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning (TPAMI 2022)

The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.

Initial version at NeurIPS 2017

This repo contains a PyTorch implementation of PredRNN (2017) [paper], a recurrent network with a pair of memory cells that operate in nearly independent transition manners and finally form unified representations of the complex environment.

Concretely, besides the original memory cell of LSTM, this network features a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the visual dynamics learned at different levels of the RNN to communicate.

New in PredRNN-V2 at TPAMI 2022

This repo also includes the implementation of PredRNN-V2 [paper], which improves PredRNN in the following three aspects.

1. Memory-Decoupled ST-LSTM

We find that the pair of memory cells in PredRNN contain undesirable, redundant features, and thus present a memory decoupling loss to encourage them to learn modular structures of visual dynamics.
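
As a rough sketch of the idea (not the repository's exact code), such a loss can penalize the absolute cosine similarity between the increments of the two memory cells; the names delta_c and delta_m follow the forward() signature visible in a traceback further down this page:

    import torch
    import torch.nn.functional as F

    def decoupling_loss(delta_c, delta_m, eps=1e-8):
        # Flatten the per-cell increments to [batch, features].
        dc = delta_c.flatten(start_dim=1)
        dm = delta_m.flatten(start_dim=1)
        # Penalizing the absolute cosine similarity pushes the two
        # memory cells toward orthogonal, i.e. decoupled, updates.
        cos = F.cosine_similarity(dc, dm, dim=1, eps=eps)
        return cos.abs().mean()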

[figure: memory decoupling]

2. Reverse Scheduled Sampling

Reverse scheduled sampling is a new curriculum learning strategy for sequence-to-sequence RNNs. As opposed to scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground-truth frame. Benefit: it forces the model to learn long-term dynamics from the context frames.
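
For intuition, here is a minimal sketch of such a schedule, assuming a simple linear ramp (the actual code may use a different curve); the sampled mask decides, per time step, whether the encoder sees the true frame or its own prediction:

    import numpy as np

    def rss_prob(itr, total_itr, p_start=0.5, p_end=1.0):
        # Probability that the encoder is fed the true previous frame.
        # It increases over training: the reverse of scheduled sampling,
        # which decays the decoder's teacher-forcing probability.
        return p_start + (p_end - p_start) * min(itr / total_itr, 1.0)

    def encoder_input_mask(itr, total_itr, batch, steps):
        # 1 -> feed the ground-truth frame, 0 -> feed the model's output.
        p = rss_prob(itr, total_itr)
        return (np.random.uniform(size=(batch, steps)) < p).astype(np.float32)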

3. Action-Conditioned Video Prediction

We further extend PredRNN to action-conditioned video prediction. By fusing actions with the hidden states, PredRNN and PredRNN-V2 show highly competitive performance in long-term forecasting, and they can potentially serve as base dynamics models for model-based visual control.
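
A minimal sketch of one way to perform this fusion, assuming low-dimensional actions that are grid-repeated into spatial maps and concatenated channel-wise with the frames (the concat scheme is also described in an issue below); the function name is hypothetical:

    import torch

    def fuse_action(frame, action):
        # frame:  [batch, img_channels, height, width]
        # action: [batch, action_dim] control vector
        b, _, h, w = frame.shape
        # Grid-repeat the action vector into per-pixel action channels.
        action_map = action.view(b, -1, 1, 1).expand(b, action.size(1), h, w)
        # Channel-wise concatenation: the RNN cell then receives
        # img_channels + action_dim input channels.
        return torch.cat([frame, action_map], dim=1)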

We show quantitative results on the BAIR robot pushing dataset for predicting 28 future frames from 2 observations.

[figure: action-conditioned BAIR results]

Showcases

Moving MNIST

[figure: Moving MNIST showcase]

KTH

[figure: KTH showcase]

BAIR (We zoom in on the area in the red box)

[figure: BAIR showcase]

Traffic4Cast

[figure: Traffic4Cast showcase]

Radar echoes

[figure: radar echo showcase]

Quantitative results on Moving MNIST and KTH in LPIPS

LPIPS correlates more closely with human perceptual judgments; the lower the better.

Model        Moving MNIST   KTH action
PredRNN      0.109          0.204
PredRNN-V2   0.071          0.139

Quantitative results on Traffic4Cast (Berlin)

Model                MSE (10^{-3})
U-Net                6.992
CrevNet              6.789
U-Net + PredRNN-V2   5.135

Get Started

  1. Install Python 3.6 and PyTorch 1.9.0 for the main code. Also install TensorFlow 2.1.0 for the BAIR dataloader.

  2. Download the data. This repo contains code for three datasets: the Moving MNIST dataset, the KTH action dataset, and the BAIR dataset (30.1 GB), which can be obtained by:

    wget http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar
    
  3. Train the model. You can use the following bash scripts to train the models. The learned model will be saved in the --save_dir folder, and the generated future frames will be saved in the --gen_frm_dir folder.

  4. You can get pretrained models from Tsinghua Cloud or Google Drive.

# Moving MNIST
cd mnist_script/
sh predrnn_mnist_train.sh
sh predrnn_v2_mnist_train.sh

# KTH action
cd kth_script/
sh predrnn_kth_train.sh
sh predrnn_v2_kth_train.sh

# BAIR robot pushing
cd bair_script/
sh predrnn_bair_train.sh
sh predrnn_v2_bair_train.sh
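
These scripts wrap run.py with dataset-specific arguments. As an illustration only (the flag names appear elsewhere on this page; the paths are hypothetical), a manual invocation might look like:

    python run.py \
        --save_dir checkpoints/mnist_predrnn_v2 \
        --gen_frm_dir results/mnist_predrnn_v2 \
        --pretrained_model checkpoints/mnist_predrnn_v2/model.ckpt

See the shell scripts for the full set of training flags.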

Citation

If you find this repo useful, please cite the following papers.

@inproceedings{wang2017predrnn,
  title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
  author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
  booktitle={Advances in Neural Information Processing Systems},
  pages={879--888},
  year={2017}
}

@misc{wang2021predrnn,
  title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning},
  author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
  year={2021},
  eprint={2103.09504},
  archivePrefix={arXiv},
}

predrnn-pytorch's People

Contributors

gtziolas, junyaohu, lkoelman, wangyb15, wuhaixu2016, wyb15


predrnn-pytorch's Issues

Question about reshape_patch function

What is the role of reshape_patch in core/utils/preprocess?
I found that setting the patch_size parameter significantly speeds up training and reduces CUDA memory usage.

I tested the reshape_patch function with the following code:

    import cv2
    import numpy as np

    img = cv2.imread('cat.jpeg', 0)  # read as a grayscale image
    img = img[np.newaxis, np.newaxis, :, :, np.newaxis]  # -> [batch, seq, H, W, C]
    img_patched = reshape_patch(img, 3)

Here are img (a grayscale image) and img_patched:
[attached figure]

So is this function used to reduce the spatial resolution of the image?
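
For reference, a space-to-depth transform of this kind can be sketched as follows; this is a minimal NumPy version, not necessarily identical to the repo's implementation:

    import numpy as np

    def reshape_patch_sketch(img_tensor, patch_size):
        # img_tensor: [batch, seq_len, height, width, channels]
        b, s, h, w, c = img_tensor.shape
        assert h % patch_size == 0 and w % patch_size == 0
        # Split each patch_size x patch_size block off the spatial dims...
        x = img_tensor.reshape(b, s, h // patch_size, patch_size,
                               w // patch_size, patch_size, c)
        # ...and move it into the channel dim: spatial resolution drops by
        # patch_size along each axis while channels grow by patch_size ** 2,
        # so no pixel information is lost.
        x = x.transpose(0, 1, 2, 4, 3, 5, 6)
        return x.reshape(b, s, h // patch_size, w // patch_size,
                         patch_size * patch_size * c)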

And thanks for your good work!

CIKM_predrnn.py fails when I run it

$ python experiment/CIKM_predrnn.py

Traceback (most recent call last):
File "experiment/CIKM_predrnn.py", line 294, in
wrapper_train(model)
File "experiment/CIKM_predrnn.py", line 248, in wrapper_train
cost = trainer.train(model, ims, real_input_flag, args, itr)
File "/home/t42/zhangwei/w外推预测/RAP-Net/core/trainer.py", line 4, in train
cost = model.train(ims, real_input_flag)
File "/home/t42/zhangwei/w外推预测/RAP-Net/core/models/model_factory.py", line 57, in train
next_frames = self.network(frames_tensor, mask_tensor)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/t42/zhangwei/w外推预测/RAP-Net/core/models/predict.py", line 56, in forward
h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/t42/zhangwei/w外推预测/RAP-Net/core/layers/STLSTMCell.py", line 35, in forward
x_concat = self.conv_x(x_t)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 443, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/t42/anaconda3/envs/predrnn/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 440, in _conv_forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [448, 16, 5, 5], expected input[4, 32, 32, 16] to have 16 channels, but got 32 channels instead

Inconsistent Moving MNIST dataset between the paper and the code?

Dear authors, thanks for sharing your code and datasets. The dataset size described in the paper is fixed: 10000 sequences for the training set, 3000 for the validation set, and 5000 for the test set.
However, the dataset provided on GitHub is different: 10000 sequences for the training set, 2000 for the validation set, and 3000 for the test set.
So, could you please explain the difference? Thanks!

How to run visualization using existing model

After I load the two pretrained models, how do I run them on different Moving MNIST images and see the predictions? (I downloaded the ckpt files, but I am not sure how to run predictions using the checkpoints.)

Question about the decoupling loss

When I print the decoupling loss in core/train.py, I find that its requires_grad is False! Does this mean that the decoupling loss plays no role in backpropagation?

Problem with hidden layers

Dear developers, I'm trying to use predrnn-pytorch, but I get the following error whenever I declare hidden layers mimicking a U-Net architecture ("64,32,64", for example). No error occurs if the hidden layers are uniform ("64,64,64").

Do you have a hint on how to solve this?

Best regards.

Traceback (most recent call last):
File "run.py", line 223, in
train_wrapper(model)
File "run.py", line 191, in train_wrapper
trainer.train(model, ims, real_input_flag, args, itr)
File "/notebooks/predrnn-pytorch/core/trainer.py", line 15, in train
cost = model.train(ims, real_input_flag)
File "/notebooks/predrnn-pytorch/core/models/model_factory.py", line 42, in train
next_frames, loss = self.network(frames_tensor, mask_tensor)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/notebooks/predrnn-pytorch/core/models/predrnn_v2.py", line 92, in forward
h_t[i], c_t[i], memory, delta_c, delta_m = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/notebooks/predrnn-pytorch/core/layers/SpatioTemporalLSTMCell_v2.py", line 49, in forward
m_concat = self.conv_m(m_t)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 204, in forward
input = module(input)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [96, 32, 3, 3], expected input[8, 64, 32, 32] to have 32 channels, but got 64 channels instead

A list of dependencies

Could you list the dependencies in a requirements.txt file? Sorry if they are mentioned somewhere, but I couldn't find them.

Prediction images don't change.

Hi,

I used PredRNN and your training strategy (i.e., combining reverse scheduled sampling and scheduled sampling) for soil moisture forecasting: 7 days of soil moisture are used to predict the following 7 days. However, I found that the predicted images fail to capture the evolution of soil moisture over the forecasting steps and keep repeating the soil moisture pattern of step 8 (see attached figure).

Can you give me some suggestions?
Thanks a lot !

Lu

[attached figure]

Emergency

When I run 'sh predrnn_mnist_train.sh' from the terminal, I get the error 'No such file or directory: '/workspace/wuhaixu/predrnn/data/moving-mnist-example/moving-mnist-valid.npz''. Why don't the 'args' in 'run.py' take effect? And if I just use 'mnist_model.ckpt' and run 'python run.py', I get the error 'RuntimeError: Error(s) in loading state_dict for RNN:
Missing key(s) in state_dict: "cell_list.0.conv_x.1.weight", "cell_list.0.conv_x.1.bias", "cell_list.0.conv_h.1.weight", "cell_list.0.conv_h.1.bias", "cell_list.0.conv_m.1.weight", "cell_list.0.conv_m.1.bias", "cell_list.0.conv_o.1.weight", "cell_list.0.conv_o.1.bias", "cell_list.1.conv_x.1.weight", "cell_list.1.conv_x.1.bias", "cell_list.1.conv_h.1.weight", "cell_list.1.conv_h.1.bias", "cell_list.1.conv_m.1.weight", "cell_list.1.conv_m.1.bias", "cell_list.1.conv_o.1.weight", "cell_list.1.conv_o.1.bias", "cell_list.2.conv_x.1.weight", "cell_list.2.conv_x.1.bias", "cell_list.2.conv_h.1.weight", "cell_list.2.conv_h.1.bias", "cell_list.2.conv_m.1.weight", "cell_list.2.conv_m.1.bias", "cell_list.2.conv_o.1.weight", "cell_list.2.conv_o.1.bias", "cell_list.3.conv_x.1.weight", "cell_list.3.conv_x.1.bias", "cell_list.3.conv_h.1.weight", "cell_list.3.conv_h.1.bias", "cell_list.3.conv_m.1.weight", "cell_list.3.conv_m.1.bias", "cell_list.3.conv_o.1.weight", "cell_list.3.conv_o.1.bias".
Unexpected key(s) in state_dict: "adapter.weight".
size mismatch for cell_list.0.conv_x.0.weight: copying a param with shape torch.Size([896, 16, 5, 5]) from checkpoint, the shape in current model is torch.Size([448, 16, 5, 5]).
size mismatch for cell_list.0.conv_h.0.weight: copying a param with shape torch.Size([512, 128, 5, 5]) from checkpoint, the shape in current model is torch.Size([256, 64, 5, 5]).
size mismatch for cell_list.0.conv_m.0.weight: copying a param with shape torch.Size([384, 128, 5, 5]) from checkpoint, the shape in current model is torch.Size([192, 64, 5, 5])....'. What should I do? @wangyb15

Generate new Mnist dataset

Hi, as mentioned in a previous issue here, may I know how you included items like "clips" and "dims" in the MNIST dataset?

Thank you!

Calculation of MSE per frame

First of all, congrats on your code. I don't understand why you don't use the gx value from line 56 as-is in line 59. Why are lines 57 and 58 necessary? (link)

            line 56: gx = img_out[:, i, :, :, :]
            line 57: gx = np.maximum(gx, 0)  # clip predictions to the valid
            line 58: gx = np.minimum(gx, 1)  # pixel range [0, 1]
            line 59: mse = np.square(x - gx).sum()

Thanks in advance!

Dataset sequence length

Hi
I ran both the Moving MNIST and KTH datasets, but it seems that Moving MNIST has 2000 test sequences instead of 5000, and for KTH it prints 'there are 8488/5041 sequences'. The number of sequences in the provided datasets seems smaller than reported in the paper. Does this matter?

Train/test split of Traffic4Cast dataset

Hi, I am hoping to run some tests using the Traffic4Cast dataset evaluated in the paper, but I need the train/test split that was used. Can this be provided? After downloading the most recent version of the competition data, the unzipped Berlin data contains several files:

"BERLIN_map_high_res.h5 BERLIN_static.h5 BERLIN_test_additional_temporal.h5 BERLIN_test_temporal.h5 training"

Within the training folder there are 180 h5 files containing the frames for each day. Can you please let me know how you assign the train and test sets so that I can accurately repeat your workflow?

EDIT: Also, are you using the 9th channel (car accident report)? The 2021 data does not provide this channel.

Predicted results

@wuhaixu2016
Hello, thank you for sharing your work.
My test results on the KTH dataset are shown below.
[attached figure]
My understanding is that we input 20 frames and predict 19 frames.
However, why are these two sequences almost identical, instead of predicting future actions?

Looking forward to your reply, thank you so much!!!

Moving MNIST 3 digits

Hi,
Thanks for the repo. I could get good results with your pretrained model.
May I know how we can use the pretrained model for the 3-digit Moving MNIST?

Radar echo dataset

Hi,
May I kindly know how to get the radar echo dataset and its preprocessing code?
Can you provide a reference?

Thank you!

Is the Moving MNIST dataset volume in PredRNN++ the same as in this code repo?

'Hi, this code repo is correct. We will rephrase the paper soon.

Originally posted by @wuhaixu2016 in #26 (comment)'

Thanks for your reply. I still couldn't reproduce the Moving MNIST performance reported in PredRNN++. So I would like to know whether the Moving MNIST dataset volume used in PredRNN++ is the same as in this code repo. If not, which volume setting are the PredRNN++ experiments based on?
Looking forward to your reply. Thanks very much!

No prediction

It turns out that there is no actual prediction in the results: the final output is exactly the same as the input sequence. Maybe there is some problem.

Guidance for custom dataset

Hi, I'm wondering whether I can use my own data. Are there any guidelines for building a custom dataset to train and test the model?

Thanks in advance.

Generate a new Moving-MNIST dataset

Hi,
May I know where I can find the code to generate a new Moving MNIST dataset? I need to create subsets of different sizes.
Thank you,
Mareeta

Do you have the kth_action dataset?

I used the kth_action dataset from the official website and preprocessed the videos into frames,
but I cannot reproduce the training set of 108,717 sequences and the test set of 4,086 sequences that the paper mentions.
I get the following:

training set:
there are 183861 pictures
there are 30379 sequences
test set:
there are 105855 pictures
there are 17145 sequences

Can you give me the preprocessing method or the preprocessed dataset?

FileNotFoundError: [Errno 2] No such file or directory

When I try to execute the Moving MNIST script with a pretrained model (passing the "--save_dir" and "--pretrained_model" arguments), the error in the title appears. I believe the lines of code below cause it.

predrnn-pytorch/run.py

Lines 210 to 212 in 36ba2b6

if os.path.exists(args.save_dir):
    shutil.rmtree(args.save_dir)
os.makedirs(args.save_dir)

The Traffic4Cast model

In your article, in the Traffic4Cast section, you wrote:

To cope with high-dimensional input frames, we apply the autoencoder architecture of U-Net [88] to the network backbone of PredRNN. Specifically, the decoder of U-Net contains four ST-LSTM layers, and the CNN encoder takes both traffic flow maps and spatiotemporal memory states as inputs.

The code for this model is not available in this GitHub repo; could you please make it available?

Thanks a lot!

Kind regards,
Sébastien de Blois

Calculate model size

Hi,
May I kindly know how you calculate the model sizes/memory/FLOPs reported in Tables 2 and 3?
Can you please share the code snippet?

Thanks!

Misconfigured parameter `num_action_ch` for `action_cond_predrnn_v2`

Hi there,

I am having issues using the action-conditional PredRNNV2 for inference.

The way it seems to work (action_injection=concat): load the actions, grid-repeat them, and concatenate the actual video data with the resulting action tensor channel-wise. Then apply reshape_patch() and pass the input to the model, resulting in a tensor of shape [batch, seq_length, height // patch_size, width // patch_size, (img_ch + action_ch) * patch_size ** 2].

For the action-conditional PredRNN-V2 model, however, the parameter num_action_ch is used directly as the input channel count of the conv layers instead of num_action_ch * patch_size ** 2. For me, this leads to runtime shape mismatches in forward(). Is this an error, or did I get something wrong?
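
A quick sanity check of the channel counts described above, with hypothetical numbers:

    patch_size = 4
    img_ch, action_ch = 3, 4

    # After reshape_patch, every channel is multiplied by patch_size ** 2,
    # so the first conv layer should expect this many input channels:
    expected_in_ch = (img_ch + action_ch) * patch_size ** 2  # 112

    # Deriving the count from num_action_ch without the patch factor
    # undercounts the channels and triggers the shape mismatch above.
    mismatched_in_ch = img_ch * patch_size ** 2 + action_ch  # 52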

Pretrained model: unexpected key

Hi,
I tried testing the pretrained PredRNN-V2 model on the Moving MNIST dataset, but I get the following error. Can you please check the reason?

RuntimeError: Error(s) in loading state_dict for RNN:
Unexpected key(s) in state_dict: "adapter.weight".

Thanks!!

Random output images

Hi, I'm using PredRNN to predict future frames of experimentally acquired data. To augment the data, I cut each frame into overlapping tiles and trained PredRNN on those tiles.

Currently I'm testing how well PredRNN predicts. To do so, I'm running the trained model on data (test.npz) that it had not previously seen. This data is also composed of overlapping tiles, whose collection represents different frames.

The problem is this: once predictions are made, the predicted tiles are written to the corresponding folders, but in a somewhat random order; at least, I'm sure they are not in the same order as in the test.npz file fed to the trained model. Is it possible to preserve the original order? I want to reconstruct full frames from the predicted tiles, and if the tile order changes, a proper reconstruction is very difficult.

Thanks
Miguel

Moving MNIST dataset

Could you provide a Baidu Cloud or Google Drive link for the Moving MNIST dataset?

Traffic4Cast dataset

Hi,
Could you please share the Traffic4Cast dataset you used? I couldn't get the 2019 data from the website.
Thanks in advance!

Why is MSE calculated with "sum" instead of "mean"?

Why is MSE calculated with "sum" instead of "mean"?

Also, is the number of parameters of CrevNet really so small? It seems to equal just the parameters of a single 256x256x3x3 convolution stored in float32.
