tats's Introduction

Long Video Generation with Time-Agnostic VQGAN and Time-Sensitive Transformer (ECCV 2022)

Project Website | Video | Paper

tl;dr We propose TATS, a long video generation framework that is trained on videos with tens of frames yet can generate videos with thousands of frames using a sliding window.

[New!] We analyze cases where FVD disagrees with human judgment. Check out our project webpage and paper for more information!

Setup

  conda create -n tats python=3.8
  conda activate tats
  conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
  pip install pytorch-lightning==1.5.4
  pip install einops ftfy h5py imageio imageio-ffmpeg regex scikit-video tqdm

Datasets and trained models

UCF-101: official data, VQGAN, TATS-base, TATS-base-uncond
Sky-Timelapse: official data, VQGAN, TATS-base
Taichi-HD: official data, VQGAN, TATS-base
MUGEN: official data, VQGAN, TATS-base
AudioSet-Drums: official data, Video-VQGAN, STFT-VQGAN, TATS-base

Synthesis

  1. Short videos: To sample videos of the same length as the training data, use the code under scripts/ with the following flags:
  • gpt_ckpt: path to the trained transformer checkpoint.
  • vqgan_ckpt: path to the trained VQGAN checkpoint.
  • save: path to save the generation results.
  • save_videos: indicates that videos will be saved.
  • class_cond: indicates that class labels are used as conditional information.

To compute the FVD, these flags are required:

  • compute_fvd: indicates that FVD will be calculated.
  • data_path: path to the dataset folder.
  • dataset: dataset name.
  • image_folder: should be used when the dataset contains frames instead of videos, e.g. Sky Time-lapse.
  • sample_every_n_frames: number of frames to skip in the real video data, e.g. set it to 4 when using the Taichi-HD dataset.
  • resolution: the resolution of real videos to compute FVD, e.g. 128 for UCF, Sky, and Taichi, and 256 for MUGEN.
python sample_vqgan_transformer_short_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} --class_cond \
    --save {SAVEPATH} --data_path {DATAPATH} --batch_size 16 --resolution 128 \
    --top_k 2048 --top_p 0.8 --dataset {DATANAME} --compute_fvd --save_videos
  2. Long videos: To sample videos longer than the training length with a sliding window, use the following script.
  • sample_length: number of latent frames to be generated.
  • temporal_sample_pos: position of the frame that the sliding window approach generates.
python sample_vqgan_transformer_long_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
    --dataset ucf101 --class_cond --sample_length 16 --temporal_sample_pos 1 --batch_size 5 --n_sample 5 --save_videos
  3. Text to video: To sample MUGEN videos conditioned on the text, check this colab notebook for an example!

  4. Audio to video: To sample drum videos conditioned on the audio, use the following script.

  • stft_vqgan_ckpt: path to the trained VQGAN checkpoint for STFT features.
python sample_vqgan_transformer_audio_cond.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} --stft_vqgan_ckpt {STFT-CKPT} \
    --dataset drum --n_sample 10
  5. Hierarchical sampling: To sample videos longer than the training length with the hierarchical models (first an AR transformer, then an interpolation transformer), use the following checkpoints and scripts.
python sample_vqgan_transformer_hierarchical.py \
    --ckpt1 {AR-CKPT} --ckpt2 {Interpolation-CKPT} --vqgan {VQGAN-CKPT} \
    --dataset sky --top_k_init 2048 --top_p_init 0.8 --top_k 2048 --top_p 0.8 --temporal_sample_pos 1
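
The sliding-window idea used for long-video sampling above can be illustrated with a toy sketch. This is not the repository's sampler: `sample_next_frame` is a hypothetical stand-in for the transformer, the window of 4 latent frames is an assumption (16 training frames with a temporal downsample of 4), and the way `pos` selects the window start is only one plausible reading of temporal_sample_pos.

```python
import random


def window_start(t, window, pos, total):
    """Start index of the window used to generate latent frame t.

    Assumption: the new frame sits `pos` slots into the window, and the
    window is clamped so it never runs past the end of the sequence.
    """
    return min(max(0, t - pos), max(0, total - window))


def sample_long(sample_length, window=4, pos=1, tokens_per_frame=4, vocab=16, seed=0):
    """Generate `sample_length` latent frames, one frame at a time."""
    rng = random.Random(seed)

    def sample_next_frame(context):
        # Hypothetical stand-in for the transformer: it ignores the
        # context and draws random token ids.
        return [rng.randrange(vocab) for _ in range(tokens_per_frame)]

    frames, context_sizes = [], []
    for t in range(sample_length):
        start = window_start(t, window, pos, sample_length)
        context = frames[start:t]  # only frames inside the current window
        context_sizes.append(len(context))
        frames.append(sample_next_frame(context))
    return frames, context_sizes


frames, context_sizes = sample_long(sample_length=16)
print(len(frames), max(context_sizes))
```

In this sketch the context never grows beyond the window, so the cost of generating one more latent frame stays constant however long the video becomes.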

Training

Example usages for training the VQGAN and the transformers are shown below. The flags you are most likely to change across settings are:

  • data_path: path to the dataset folder.
  • default_root_dir: path to save the checkpoints and the tensorboard logs.
  • vqvae: path to the trained VQGAN checkpoint.
  • resolution: the resolution of the training video clips.
  • sequence_length: number of frames in each training video clip.
  • discriminator_iter_start: the training step at which the GAN losses are switched on.
  • image_folder: should be used when the dataset contains frames instead of videos, e.g. Sky Time-lapse.
  • unconditional: when no conditional information is available, e.g. Sky Time-lapse, use this flag.
  • sample_every_n_frames: number of frames to skip in the real video data, e.g. set it to 4 when training on the Taichi-HD dataset.
  • downsample: sample rate in the dimensions of time, height, and width.
  • no_random_restart: disables the random restart that re-initializes codebook tokens.
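
As a rough illustration of sample_every_n_frames, the sketch below assumes the flag keeps one frame out of every n when a clip is cut from a longer video. The helper name `subsample_clip` and its `start` argument are hypothetical, not part of the repository.

```python
def subsample_clip(frames, sequence_length, every_n, start=0):
    """Take `sequence_length` frames, keeping one frame out of every `every_n`."""
    span = (sequence_length - 1) * every_n + 1  # raw frames covered by the clip
    if start + span > len(frames):
        raise ValueError("video too short for this clip")
    return frames[start:start + span:every_n]


video = list(range(100))  # frame indices standing in for decoded frames
clip = subsample_clip(video, sequence_length=16, every_n=4)
print(clip[:4], len(clip))
```

Under this assumption, a 16-frame clip with a stride of 4 covers 61 raw frames, so each training clip spans a much longer stretch of motion than 16 consecutive frames would.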

VQGAN

python train_vqgan.py --embedding_dim 256 --n_codes 16384 --n_hiddens 32 --downsample 4 8 8 --no_random_restart \
                      --gpus 8 --sync_batchnorm --batch_size 2 --num_workers 32 --accumulate_grad_batches 6 \
                      --progress_bar_refresh_rate 500 --max_steps 2000000 --gradient_clip_val 1.0 --lr 3e-5 \
                      --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                      --resolution 128 --sequence_length 16 --discriminator_iter_start 10000 --norm_type batch \
                      --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1  --gan_feat_weight 4
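
When adapting this command to a different number of GPUs, it can help to track the effective batch size. The sketch below assumes --batch_size is the per-GPU batch size, as is usual for distributed training in PyTorch Lightning; the helper name is hypothetical.

```python
def effective_batch_size(gpus, batch_size, accumulate_grad_batches=1):
    """Samples that contribute to one optimizer step."""
    return gpus * batch_size * accumulate_grad_batches


# Values from the VQGAN command above: 8 GPUs, batch size 2, 6 accumulation steps.
print(effective_batch_size(8, 2, 6))
# Halving the GPUs and doubling the accumulation keeps the product unchanged.
print(effective_batch_size(4, 2, 12))
```

Both calls give 96 samples per optimizer step, which is one way to keep the learning rate comparable on smaller machines.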

Transformer

TATS-base Transformer

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 3 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 16 --max_steps 2000000

To train a conditional transformer, remove the --unconditional flag and use the following flags:

  • cond_stage_key: the kind of conditional information to use. It can be label, text, or stft.
  • stft_vqvae: path to the trained VQGAN checkpoint for STFT features.
  • text_cond: use this flag to indicate BPE encoded text.

TATS-hierarchical Transformer

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 3 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1280 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 20 --spatial_length 128 --n_unmasked 256 --max_steps 2000000

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 4 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 64 --sample_every_n_latent_frames 4 --spatial_length 128 --max_steps 2000000
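
The --block_size values in the transformer commands above appear to follow from the clip shape and the VQGAN downsample factors. The sketch below is an illustration under assumptions: the helper `latent_tokens` is hypothetical, the downsample factors are taken from the VQGAN command (4 8 8), and --sample_every_n_latent_frames is assumed to keep one latent frame out of every n.

```python
def latent_tokens(sequence_length, resolution, downsample=(4, 8, 8), every_n_latent=1):
    """Number of tokens the transformer sees for one clip."""
    t, h, w = downsample
    if sequence_length % (t * every_n_latent) or resolution % h or resolution % w:
        raise ValueError("clip shape is not divisible by the downsample factors")
    latent_frames = sequence_length // t // every_n_latent
    return latent_frames * (resolution // h) * (resolution // w)


print(latent_tokens(16, 128))                    # TATS-base command
print(latent_tokens(20, 128))                    # first hierarchical command
print(latent_tokens(64, 128, every_n_latent=4))  # second hierarchical command
print(latent_tokens(16, 256))                    # same clip length at resolution 256
```

Under these assumptions the three commands give 1024, 1280, and 1024 tokens, matching their --block_size values, while 16-frame clips at resolution 256 would need 4096 tokens.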

Acknowledgments

Our code is partially built upon VQGAN and VideoGPT.

Citation

@article{ge2022long,
  title={Long video generation with time-agnostic vqgan and time-sensitive transformer},
  author={Ge, Songwei and Hayes, Thomas and Yang, Harry and Yin, Xi and Pang, Guan and Jacobs, David and Huang, Jia-Bin and Parikh, Devi},
  journal={arXiv preprint arXiv:2204.03638},
  year={2022}
}

License

TATS is licensed under the MIT license, as found in the LICENSE file.

tats's People

Contributors

songweige

tats's Issues

Environment Setting

In my setup, installing PyTorch with conda and pytorch-lightning with pip led to incompatible versions, and the code failed to run.

Here is my setup:

  conda create -n tats python=3.8
  conda activate tats
  pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
  pip install pytorch-lightning
  pip install einops ftfy h5py imageio imageio-ffmpeg regex scikit-video tqdm av

However, I still get the following warnings:

IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (259, 259) to (272, 272) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).
[swscaler @ 0x6c83700] Warning: data is not aligned! This can lead to a speed loss

and

~/miniconda3/envs/tats/lib/python3.8/site-packages/torchvision/io/video.py:162: UserWarning: The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.
  warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

Does this affect the final result?
Looking forward to your suggestions and help.
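
The sizes in the first warning follow from rounding each dimension up to a multiple of macro_block_size. The sketch below only reproduces that arithmetic; `padded_size` is a hypothetical helper, not an imageio function.

```python
def padded_size(size, macro_block_size=16):
    """Round each dimension up to the next multiple of macro_block_size."""
    return tuple(-(-s // macro_block_size) * macro_block_size for s in size)


print(padded_size((259, 259)))  # the size reported in the warning
print(padded_size((256, 256)))  # already a multiple of 16, so unchanged
```

The first call returns (272, 272), the size named in the warning. As the message itself says, the resize applies to the frames handed to the ffmpeg writer, and it can be avoided by saving frames whose sides are multiples of 16 or by setting macro_block_size to 1.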

Training on single GPU

Hi!

Thanks for the great work. I have been trying to train on a single GPU but it keeps throwing this error:

"Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Is it possible to configure the model to train on a single GPU?

full error message:

Traceback (most recent call last):
File "/content/TATS/scripts/train_vqgan.py", line 70, in <module>
main()
File "/content/TATS/scripts/train_vqgan.py", line 66, in main
trainer.fit(model, data)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 738, in fit
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
self.fit_loop.run()
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
batch_output = self.batch_loop.run(batch, batch_idx)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 219, in advance
self.optimizer_idx,
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 259, in _run_optimization
closure()
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
self._result = self.closure(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
step_output = self._step_fn()
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
return self.training_type_plugin.training_step(*step_kwargs.values())
File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 213, in training_step
return self.model.training_step(*args, **kwargs)
File "/content/TATS/scripts/tats/tats_vqgan.py", line 182, in training_step
recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(x, optimizer_idx)
File "/content/TATS/scripts/tats/tats_vqgan.py", line 118, in forward
logits_image_fake, pred_image_fake = self.image_discriminator(frames_recon)
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/content/TATS/scripts/tats/tats_vqgan.py", line 463, in forward
res.append(model(res[-1]))
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 731, in forward
world_size = torch.distributed.get_world_size(process_group)
File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 867, in get_world_size
return _get_group_size(group)
File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 325, in _get_group_size
default_pg = _get_default_group()
File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 430, in _get_default_group
"Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Question on training the transformer

Hello, I have a question regarding the training of the transformer. Would there be a leakage of ground truth information during the training of the transformer due to the receptive field of the 3D convolutions across time? The information contained in the last token's embedding might be present in other tokens preceding it. Would that cause inefficient learning of the token indices in the transformer?

The Training of Interpolation Transformer

Dear author:

In the training of the Interpolation Transformer, given that the latent space is 5 * 16 * 16, I found that the first 16 * 16 and the last 16 * 16 tokens contribute to the gradient. But at inference time, the first and last 16 * 16 tokens are given. So, in my opinion, the first 16 * 16 and the last 16 * 16 tokens should not take part in back-propagation during training. Please correct me if I'm wrong.

Kang
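
To make the proposal in this question concrete, the sketch below lists which token positions would still be supervised if the given first and last latent frames were excluded from the loss. It illustrates the question only and does not describe what the repository does; `supervised_positions` is a hypothetical helper.

```python
def supervised_positions(n_frames=5, tokens_per_frame=16 * 16, n_given_edge_frames=1):
    """Token positions left in the loss when the edge frames are treated as given."""
    given = n_given_edge_frames * tokens_per_frame
    total = n_frames * tokens_per_frame
    return range(given, total - given)


positions = supervised_positions()
print(positions.start, positions.stop, len(positions))
```

For a 5 * 16 * 16 latent this leaves positions 256 to 1023, that is 768 of the 1280 tokens, in the loss.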

FileNotFoundError: [Errno 2] No such file or directory: 'datasets/UCF-101/train/metadata_16.pkl'

Hi.
Thank you for your great work!
Unfortunately, I got an error when I ran the VQGAN train script.

Traceback (most recent call last):
  File "scripts/train_vqgan.py", line 70, in <module>
    main()
  File "scripts/train_vqgan.py", line 21, in main
    data.train_dataloader()
  File "mydir/TATS/scripts/tats/data.py", line 296, in train_dataloader
    return self._dataloader(True)
  File "mydir/TATS/scripts/tats/data.py", line 275, in _dataloader
    dataset = self._dataset(train)
  File "mydir/TATS/scripts/tats/data.py", line 270, in _dataset
    dataset = Dataset(self.args.data_path, self.args.sequence_length,
  File "mydir/TATS/scripts/tats/data.py", line 60, in __init__
    pickle.dump(clips.metadata, open(cache_file, 'wb'))
FileNotFoundError: [Errno 2] No such file or directory: 'datasets/UCF-101/train/metadata_16.pkl'

Although I downloaded the UCF-101 dataset from the official website, the UCF-101 folder contains neither the train folder nor metadata_16.pkl.
How can I preprocess the dataset?
The command I ran is as follows:

python scripts/train_vqgan.py \
--embedding_dim 256 \
--n_codes 16384 \
--n_hiddens 32 \
--downsample 4 8 8 \
--no_random_restart \
--gpus 8 \
--sync_batchnorm \
--batch_size 2 \
--num_workers 32 \
--accumulate_grad_batches 6 \
--progress_bar_refresh_rate 500 \
--max_steps 2000000 \
--gradient_clip_val 1.0 \
--lr 3e-5 \
--data_path datasets/UCF-101 \
--default_root_dir ckpt \
--resolution 128 \
--sequence_length 16 \
--discriminator_iter_start 10000 \
--norm_type batch \
--perceptual_weight 4 \
--image_gan_weight 1 \
--video_gan_weight 1 \
--gan_feat_weight 4 

Thanks.

The Results about Audio to Video Generation Using AudioSet-Drum.

Hi,

Congratulations on your excellent work!

I am interested in using audio to generate video. I found that the provided script for audio-conditioned generation uses only the audio as the condition, not both the audio and 15 frames as described in the paper. I tried modifying the script to feed the audio latent indices together with the latent indices of 15 frames, generate the 16th frame, and evaluate it, but the SSIM and PSNR scores are not good. Is my approach reasonable, and could you kindly provide the scripts? I am also struggling to generate the 45th frame with audio conditioning. Can it be obtained with the sample_long_video.py script?

By the way, I compute SSIM and PSNR following the scripts from the CCVS model. Could you give me an idea of how to get final results close to those in the paper?

Thank you for your help!

Minglu

On the Interpolation Transformer

Hi, thanks for the awesome work.

I have been going through the codes to better understand the paper.
I wish to understand how the interpolation transformer is implemented, but could not find it.

I assume the transformer (GPT) in Net2NetTransformer in tats/tats_transformer.py is the autoregressive transformer, but there does not seem to be a separate transformer for interpolation.

I found the causal attention, but not the interpolation causal attention.

Is it not provided or am I missing something? Could you kindly specify where to look?

Question about the training procedure when the dataset is divided into train/test sets

Hi.
Thanks for your great work!

I have recently started studying the video domain, and I have a question about the training procedure.

I noticed that the UCF-101 dataset has a train set and a test set.

However, since this is video generation rather than prediction, the train/test split seems meaningless: there is no ground truth, and quality can only be measured with FVD.

For a dataset with a train/test split (like UCF-101), do you train your model on the train set only, or on both?

Thanks.

kind regards.

Unavailable MUGEN checkpoints

Hi, the Dropbox link for the MUGEN checkpoint seems to be unavailable now. Could you please share another link? Thanks!

The dropout of Transformer

Dear authors:

I found the dropout (embd_pdrop, resid_pdrop, attn_pdrop) is set to 0 during the GPT training.
To verify my observations, I downloaded the TATS-base of UCF101 and Sky-Timelapse from the homepage, the embd_pdrop, resid_pdrop, attn_pdrop were all set to 0.

A value of 0 means dropout is disabled. Is this correct, or am I missing something?

Kang

Code issues in CausalSelfAttention module

I have a few queries regarding the CausalSelfAttention module in the codebase.

  1. In lines https://github.com/songweige/TATS/blob/main/tats/modules/gpt.py#L100C13-L102C86,
mask[:, :config.n_unmasked+1] = 1
mask[:, -config.n_unmasked+1:] = 1
mask[-config.n_unmasked+1:, config.n_unmasked+1:-config.n_unmasked+1] = 0

The masking seems to be incorrect. I believe the corrected code should be -

mask[:, :config.n_unmasked+1] = 1
mask[:, -(config.n_unmasked+1):] = 1
mask[-(config.n_unmasked+1):, config.n_unmasked+1:-(config.n_unmasked+1)] = 0
  2. In lines https://github.com/songweige/TATS/blob/main/tats/modules/gpt.py#L122C9-L123C76,
if layer_past is None:
    att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))

We are only masking when layer_past is None. But when it is not, no masking is applied which would imply we are not performing causal attention anymore. Why is that the case?
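
The first point in this issue comes down to operator precedence in slice bounds: unary minus binds tighter than addition, so `-n + 1` means `-(n - 1)`, not `-(n + 1)`. A minimal illustration on a plain list, where `n_unmasked = 3` is an arbitrary example value (which of the two slices the mask is meant to use is not settled here):

```python
n_unmasked = 3
row = list(range(10))

tail_as_written = row[-n_unmasked + 1:]      # parsed as row[-(n_unmasked - 1):]
tail_as_proposed = row[-(n_unmasked + 1):]   # parenthesized form from the issue

print(tail_as_written)   # [8, 9]
print(tail_as_proposed)  # [6, 7, 8, 9]

# Edge case: with n_unmasked = 1 the written form becomes row[0:], the whole row.
print(row[-1 + 1:] == row)  # True
```

The two forms differ by two elements, so they select different columns and rows of the mask whenever n_unmasked is positive.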

Great Work!!!!!

Few Queries.....

(a) Could you please provide the evaluation code for reproducing Tables 1(a), 1(b), 1(c), and 1(d)?
(b) Could you please tell me the total compute hours needed to train the full model?

audio-video generation on custom dataset

Thank you so much for your excellent work!
I have a question about generating audio-guided videos on a custom dataset. I noticed that a pre-trained VQGAN for audio encoding is required, but the script for training such a model seems to be missing from the repo. Could you please share the script if it is convenient for you?
Thank you again for your contribution to the community!

Downsample and Upsample ratio for T dimension

Hi, thanks for your great work. I am training on the mmnist dataset with a downsampling parameter of 4 8 8. Both the input and the output have 10 frames in the T dimension, so the encoder reduces the T dimension to 2 and the decoder cannot restore it to 10. What should I do when the T dimension of the dataset is not a power of 2?
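
One common workaround, offered here as a sketch and not as the authors' answer, is to pad the clip in time to a multiple of the temporal downsample factor and crop the reconstruction back afterwards. The helpers `padded_length` and `pad_clip` are hypothetical, and repeating the last frame is just one possible padding choice.

```python
import math


def padded_length(t, factor):
    """Smallest multiple of `factor` that is at least `t`."""
    return math.ceil(t / factor) * factor


def pad_clip(frames, factor):
    """Repeat the last frame until the clip length is a multiple of `factor`."""
    target = padded_length(len(frames), factor)
    return frames + [frames[-1]] * (target - len(frames))


clip = list(range(10))          # frame indices standing in for 10 frames
padded = pad_clip(clip, 4)      # 12 frames, i.e. 3 latent frames at factor 4
restored = padded[:len(clip)]   # crop a 12-frame reconstruction back to 10
print(len(padded), restored == clip)
```

Choosing a temporal downsample factor that divides the clip length (for example 2 for 10 frames) would avoid the padding altogether.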

Checkpoints for TATS-Hierarchical

Hi, thanks for the great work.
Do you have any plan to share the checkpoints and the inference code for TATS-Hierarchical?

If it's hard to share, may I get some metrics or values to monitor about TATS-Hierarchical so that I can reproduce them?

Mugen checkpoint

Great work and thanks for kindly releasing both codes and pretrained models! I could access all the checkpoints except the one on Mugen dataset. Could you please upload it to google drive instead of dropbox?

Training with custom dataset

I really appreciate the work you have done in this repo. I have a custom dataset of mp4 videos with txt descriptions by timestamp (300 GB zipped). How can I use my custom data to train a text-to-video generator? The main theme of my data is nearly the same as the MUGEN dataset's, with subjects, actions, objects...

1-) How should I preprocess my data (e.g. a csv file pointing to each mp4 and its corresponding txt, a json file, or another option)?
2-) How exactly can I train the TATS and VQGAN models? Are there any official scripts? If not, can you help me with it?
3-) The data is huge and I only have a laptop with an NVIDIA 3080 Max-Q laptop GPU. Do you think it is possible to train a good model in a reasonable time, or should I look for a service like AWS? (A 3080 is about 4 times better than a Tesla K80 GPU, but what really matters in cases like this is the number of GPUs.)

Thank you for answers!

Loss starts to diverge when the discriminator is enabled

I am trying to train VQGAN on the UCF-101 dataset with 4 A100s and 24 samples per device. The reconstruction and perceptual losses converge normally until the discriminator is enabled at 10k steps. In addition, the commitment loss diverges from the start. How can I fix it?
[two images attached]

Training on FaceForensics

Hi!
I am currently working on using the code for the FaceForensics dataset. I have been able to train VQGAN, but I encountered an issue while training the TATS-base Transformer. Here is the command I am using for training:

python scripts/train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 4 --sync_batchnorm --batch_size 2 --unconditional \
                        --vqvae exp_1/lightning_logs/version_78873/checkpoints/latest_checkpoint.ckpt \
                        --data_path data/ffs_processed --default_root_dir exp_1_tats --image_folder \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 256 --sequence_length 16 --max_steps 2000000

However, I got an error:
Traceback (most recent call last):
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 77, in <module>
main()
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 73, in main
trainer.fit(model, data)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1306, in _run_train
self._run_sanity_check(self.lightning_module)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1370, in _run_sanity_check
self._evaluation_loop.run()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 444, in validation_step
return self.model(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 250, in validation_step
loss, acc1, acc5 = self.shared_step(batch, batch_idx)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 237, in shared_step
logits, target = self(x, c, cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 126, in forward
logits, _ = self.transformer(cz_indices[:, :-1], cbox=cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 208, in forward
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
AssertionError: Cannot forward, model block size is exhausted.

where t=4096 and self.block_size=1024. When I tried to increase the block_size argument to 4096, I received another error message (even with a batch_size of 1):
Traceback (most recent call last):
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 77, in <module>
main()
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 73, in main
trainer.fit(model, data)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
self.fit_loop.run()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
batch_output = self.batch_loop.run(batch, batch_idx)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
result = self._run_optimization(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
lightning_module.optimizer_step(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1651, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 163, in optimizer_step
optimizer.step(closure=closure, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/optim/optimizer.py", line 113, in wrapper
return func(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/optim/adamw.py", line 119, in step
loss = closure()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 148, in _wrap_closure
closure_result = closure()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
step_output = self._step_fn()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
return self.training_type_plugin.training_step(*step_kwargs.values())
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 439, in training_step
return self.model(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 81, in forward
output = self.module.training_step(*inputs, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 243, in training_step
loss, acc1, acc5 = self.shared_step(batch, batch_idx)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 237, in shared_step
logits, target = self(x, c, cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 126, in forward
logits, _ = self.transformer(cz_indices[:, :-1], cbox=cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 218, in forward
x = self.blocks(x)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 153, in forward
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 121, in forward
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 1; 31.74 GiB total capacity; 29.71 GiB already allocated; 45.12 MiB free; 30.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Can you suggest any potential solutions that might work?
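
For context on why the out-of-memory error appears even at batch size 1: the failing line materializes the full attention-score tensor `q @ k.transpose(-2, -1)`, whose size grows with the square of the sequence length. The sketch below is a rough estimate only. It assumes 16 attention heads and float32 scores, neither of which is confirmed by the traceback, and the helper name `attention_scores_mib` is made up for illustration.

```python
def attention_scores_mib(batch_size, n_head, seq_len, bytes_per_elem=4):
    """Size in MiB of one (batch, head, T, T) attention-score tensor."""
    n_elements = batch_size * n_head * seq_len * seq_len
    return n_elements * bytes_per_elem / 2**20


# Assumed values: 16 heads, float32 (4 bytes per element), batch size 1.
for seq_len in (1024, 4096):
    size = attention_scores_mib(batch_size=1, n_head=16, seq_len=seq_len)
    print(f"T={seq_len}: {size:.0f} MiB per attention-score tensor")
```

Under those assumptions a single score tensor at T=4096 is 1024 MiB, which would match the "Tried to allocate 1024.00 MiB" in the error, and it is 16 times larger than at T=1024. A tensor of this size is created in every transformer block and kept for the backward pass, so memory is exhausted regardless of batch size. General mitigations for this pattern are shorter token sequences (fewer frames or a lower latent resolution), half-precision training, or activation checkpointing; whether the TATS scripts expose these options is not something the traceback shows.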
