
DALL-E in Pytorch

Train DALL-E w/ DeepSpeed Join us on Discord
Released DALLE Models
Web-Hostable DALLE Checkpoints

Yannic Kilcher's video

Implementation / replication of DALL-E (paper), OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations.


Quick Start

Deep Daze or Big Sleep are great alternatives!

For generating video and audio, please see NÜWA

Appreciation

This library could not have been possible without the contributions of janEbert, Clay, robvanvolt, Romain Beaumont, and Alexander! 🙏

Status

  • Hannu has managed to train a small 6 layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens)

  • Kobiso, a research engineer from Naver, has trained on the CUB200 dataset here, using full and deepspeed sparse attention

  • (3/15/21) afiaka87 has managed one epoch using a reversible DALL-E and the dVAE here

  • TheodoreGalanos has trained on 150k layouts with the following results

  • Rom1504 has trained on 50k fashion images with captions with a really small DALL-E (2 layers) for just 24 hours with the following results

  • afiaka87 trained for 6 epochs on the same dataset as before thanks to the efficient 16k VQGAN with the following results

Thanks to the amazing "mega b#6696" you can generate from this checkpoint in colab - Run inference on the Afiaka checkpoint in Colab

  • (5/2/21) First 1.3B DALL-E from 🇷🇺 has been trained and released to the public! 🎉

  • (4/8/22) Moving onwards to DALLE-2!

Install

$ pip install dalle-pytorch

Usage

Train VAE

import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,        # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,       # codebook dimension
    hidden_dim = 64,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_loss = True)
loss.backward()

# train with a lot of data to learn a good codebook

Train DALL-E with pretrained VAE from above

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 12,                 # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(text)
images.shape # (4, 3, 256, 256)

To prime with a starting crop of an image, simply pass two more arguments

img_prime = torch.randn(4, 3, 256, 256)

images = dalle.generate_images(
    text,
    img = img_prime,
    num_init_img_tokens = (14 * 32)  # you can set the size of the initial crop, defaults to a little less than ~1/2 of the tokens, as done in the paper
)

images.shape # (4, 3, 256, 256)

You may also want to generate text using DALL-E. For that, call this function, passing in the tokenizer used for training:

text_tokens, texts = dalle.generate_texts(tokenizer, text)

OpenAI's Pretrained VAE

You can also skip the training of the VAE altogether, using the pretrained model released by OpenAI! The wrapper class should take care of downloading and caching the model for you auto-magically.

import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()       # loads pretrained OpenAI VAE

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 1,                  # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

Taming Transformer's Pretrained VQGAN VAE

You can also use the pretrained VAE offered by the authors of Taming Transformers! Currently only the VAE with a codebook size of 1024 is offered, with the hope that it may train a little faster than OpenAI's, which has a size of 8192.

In contrast to OpenAI's VAE, it also has an extra layer of downsampling, so the image sequence length is 256 instead of 1024. Since attention cost grows roughly quadratically with sequence length, this works out to about a 16x reduction in training cost. Whether it will generalize as well as the original DALL-E is up to the citizen scientists out there to discover.

Update - it works!

from dalle_pytorch import VQGanVAE

vae = VQGanVAE()

# the rest is the same as the above example

The default VQGAN is the codebook-size-1024 one trained on ImageNet. If you wish to use a different one, you can pass the .ckpt and .yaml files via vqgan_model_path and vqgan_config_path. These options can be used both in the train_dalle.py script and as arguments to the VQGanVAE class, as sketched below. Other pretrained VQGANs can be found in the Taming Transformers readme. If you want to train a custom one, you can follow this guide
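
As a minimal sketch (the paths below are hypothetical placeholders for your own checkpoint and config):

from dalle_pytorch import VQGanVAE

vae = VQGanVAE(
    vqgan_model_path = './path/to/custom_vqgan.ckpt',    # your VQGAN checkpoint
    vqgan_config_path = './path/to/custom_vqgan.yaml'    # the matching config
)

# the rest is the same as the VQGanVAE example above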

Adjust text conditioning strength

Recently there has surfaced a new technique for guiding diffusion models without a classifier. The gist of the technique involves randomly dropping out the text condition during training, and at inference time, deriving the rough direction from unconditional to conditional distributions.

Katherine Crowson outlined in a tweet how this could work for autoregressive attention models. I have decided to include her idea in this repository for further exploration. One only has to account for two extra keyword arguments on training (null_cond_prob) and generation (cond_scale).

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(
    text,
    images,
    return_loss = True,
    null_cond_prob = 0.2  # firstly, set this to the probability of dropping out the condition, 20% is recommended as a default
)

loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(
    text,
    cond_scale = 3. # secondly, set this to a value greater than 1 to increase the conditioning beyond average
)

images.shape # (4, 3, 256, 256)

That's it!

Ranking the generations

Train CLIP

import torch
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()

To get the similarity scores from your trained Clipper, just do

images, scores = dalle.generate_images(text, mask = mask, clip = clip)

scores.shape # (4,)
images.shape # (4, 3, 256, 256)

# do your topk here, in paper they sampled 512 and chose top 32

Or you can just use the official CLIP model to rank the images from DALL-E, as sketched below.
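
A rough sketch of such reranking, assuming the openai/CLIP package (pip install git+https://github.com/openai/CLIP.git); the prompt, checkpoint name, and top-k value here are just placeholders:

import torch
import clip
from torchvision.transforms import ToPILImage

model, preprocess = clip.load('ViT-B/32', device = 'cpu')

# `images` is the (batch, 3, 256, 256) tensor returned by dalle.generate_images(text)
pil_images = [ToPILImage()(img.clamp(0., 1.)) for img in images]
image_input = torch.stack([preprocess(img) for img in pil_images])
text_input = clip.tokenize(['fireflies in a field under a full moon'])

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_input)

# cosine similarity between each generation and the prompt, then keep the top k
image_features = image_features / image_features.norm(dim = -1, keepdim = True)
text_features = text_features / text_features.norm(dim = -1, keepdim = True)
scores = (image_features @ text_features.T).squeeze(-1)
best = scores.topk(min(2, scores.shape[0])).indices   # indices of the best generations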

Scaling depth

In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, in order for users to attempt to scale depth at the cost of compute. Reversible networks allow you to scale to any depth at no memory cost, but a little over 2x compute cost (each layer is rerun on the backward pass).

Simply set the reversible keyword to True for the DALLE class

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True  # <-- reversible networks https://arxiv.org/abs/2001.04451
)

Sparse Attention

The blogpost alluded to a mixture of different types of sparse attention, used mainly on the image (while the text presumably had full causal attention). I have done my best to replicate these types of sparse attention based on the scant details released. Primarily, it seems as though they are doing causal axial row / column attention, combined with a causal convolution-like attention.

By default DALLE will use full attention for all layers, but you can specify the attention type per layer as follows.

  • full - full attention

  • axial_row - axial attention, along the rows of the image feature map

  • axial_col - axial attention, along the columns of the image feature map

  • conv_like - convolution-like attention, for the image feature map

The sparse attention only applies to the image. Text will always receive full attention, as said in the blogpost.

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True,
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')  # cycles between these four types of attention
)

Deepspeed Sparse Attention

You can also train with Microsoft Deepspeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.

First, you need to install Deepspeed with Sparse Attention

$ sh install_deepspeed.sh

Next, you need to install the pip package triton. It will need to be a version < 1.0 because that's what Microsoft used.

$ pip install triton==0.4.2

If both of the above succeeded, now you can train with Sparse Attention!

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 8,
    attn_types = ('full', 'sparse')  # interleave sparse and dense attention for 64 layers
)

Training

This section will outline how to train the discrete variational autoencoder as well as the final multi-modal transformer (DALL-E). We are going to use Weights & Biases for all the experiment tracking.

(You can also do everything in this section in a Google Colab, link below)

Train in Colab

$ pip install wandb

Followed by

$ wandb login

VAE

To train the VAE, you just need to run

$ python train_vae.py --image_folder /path/to/your/images

If you installed everything correctly, a link to the experiments page should show up in your terminal. You can follow your link there and customize your experiment, like the example layout below.

You can of course open up the training script at ./train_vae.py, where you can modify the constants, what is passed to Weights & Biases, or any other tricks you know to make the VAE learn better.

Model will be saved periodically to ./vae.pt

In the experiment tracker, you will have to monitor the hard reconstruction, as we are essentially teaching the network to compress images into discrete visual tokens for use in the transformer as a visual vocabulary.

Weights and Biases will allow you to monitor the temperature annealing, image reconstructions (encoder and decoder working properly), as well as to watch out for codebook collapse (where the network decides to only use a few tokens out of what you provide it).

Once you have trained a decent VAE to your satisfaction, you can move on to the next step with your model weights at ./vae.pt.
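
As a rough sketch, the checkpoint can later be reloaded like this (this mirrors the loading code in train_dalle.py; the 'hparams' and 'weights' keys are an assumption based on the current scripts):

import torch
from dalle_pytorch import DiscreteVAE

loaded_obj = torch.load('./vae.pt')

# assumed checkpoint layout: hyperparameters and weights saved side by side
vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

vae = DiscreteVAE(**vae_params)
vae.load_state_dict(weights)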

DALL-E Training

Training using an Image-Text-Folder

Now you just have to invoke the ./train_dalle.py script, indicating which VAE model you would like to use, as well as the path to your folder of images and text.

The dataset I am currently working with contains a folder of images and text files, arbitrarily nested in subfolders, where each text file name corresponds to an image name, and where each text file contains multiple descriptions, delimited by newlines. The script will find and pair all the image and text files with the same names, and randomly select one of the textual descriptions during batch creation (a sketch of this pairing logic follows the example below).

ex.

๐Ÿ“‚image-and-text-data
 โ”ฃ ๐Ÿ“œcat.png
 โ”ฃ ๐Ÿ“œcat.txt
 โ”ฃ ๐Ÿ“œdog.jpg
 โ”ฃ ๐Ÿ“œdog.txt
 โ”ฃ ๐Ÿ“œturtle.jpeg
 โ”— ๐Ÿ“œturtle.txt

ex. cat.txt

A black and white cat curled up next to the fireplace
A fireplace, with a cat sleeping next to it
A black cat with a red collar napping
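
For illustration only, here is a sketch of the pairing logic described above; it is not the actual loader used by train_dalle.py:

import random
from pathlib import Path

def find_pairs(folder):
    # pair every text file with an image of the same name, regardless of nesting
    pairs = []
    for text_path in Path(folder).glob('**/*.txt'):
        for ext in ('.png', '.jpg', '.jpeg'):
            image_path = text_path.with_suffix(ext)
            if image_path.exists():
                pairs.append((image_path, text_path))
                break
    return pairs

def sample_caption(text_path):
    # each text file holds several newline-delimited descriptions; pick one at random per batch
    descriptions = [line for line in text_path.read_text().splitlines() if line.strip()]
    return random.choice(descriptions)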

If you have a dataset with its own directory structure for tying together image and text descriptions, do let me know in the issues, and I'll see if I can accommodate it in the script.

$ python train_dalle.py --vae_path ./vae.pt --image_text_folder /path/to/data

You likely will not finish DALL-E training as quickly as you did your Discrete VAE. To resume from where you left off, just run the same script, but with the path to your DALL-E checkpoints.

$ python train_dalle.py --dalle_path ./dalle.pt --image_text_folder /path/to/data

Training using WebDataset

WebDataset files are regular .tar(.gz) files which can be streamed and used for DALLE-pytorch training. You just need to provide the image key (first comma-separated argument) and caption key (second comma-separated argument) after the --wds argument. The --image_text_folder argument then points to your .tar(.gz) file instead of the data folder.

$ python train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz)

Distributed training with deepspeed works the same way, e.g.:

$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz) --fp16 --deepspeed

If your dataset is split into several .tar(.gz) shards, this is also supported:

$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/shardfolder --fp16 --deepspeed

You can stream the data from an HTTP server or Google Cloud Storage like this:

$ deepspeed train_dalle.py --image_text_folder "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar" --wds jpg,json --taming --truncate_captions --random_resize_crop_lower_ratio=0.8 --attn_types=full --epochs=2 --fp16 --deepspeed

To convert your image-text folder to the WebDataset format, you can use one of several methods: four examples are given at https://www.youtube.com/watch?v=v_PacO-3OGQ, and there is a small helper script that also supports splitting your dataset into shards of .tar.gz files at https://github.com/robvanvolt/DALLE-datasets/blob/main/wds_create_shards.py.
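
As a rough sketch of the idea, assuming the webdataset package (pip install webdataset) and the folder layout from the earlier example, shard creation could look like this; you would then pass --wds jpg,txt to match the keys used below:

import webdataset as wds
from pathlib import Path

# writes ./shards/data-000000.tar, ./shards/data-000001.tar, ... with up to 10k samples each
sink = wds.ShardWriter('./shards/data-%06d.tar', maxcount = 10000)

for image_path in Path('./image-and-text-data').glob('**/*.jpg'):
    caption_path = image_path.with_suffix('.txt')
    if not caption_path.exists():
        continue
    sink.write({
        '__key__': image_path.stem,          # key shared by the image and its caption
        'jpg': image_path.read_bytes(),
        'txt': caption_path.read_text()
    })

sink.close()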

DALL-E with OpenAI's VAE

You can now also train DALL-E without having to train the Discrete VAE at all, courtesy of OpenAI open-sourcing their model. You simply have to invoke the train_dalle.py script without specifying the --vae_path.

$ python train_dalle.py --image_text_folder /path/to/coco/dataset

DALL-E with Taming Transformer's VQVAE

Just use the --taming flag. Highly recommended you use this VAE over the OpenAI one!

$ python train_dalle.py --image_text_folder /path/to/coco/dataset --taming

Generation

Once you have successfully trained DALL-E, you can then use the saved model for generation!

$ python generate.py --dalle_path ./dalle.pt --text 'fireflies in a field under a full moon'

You should see your images saved as ./outputs/{your prompt}/{image number}.jpg

To generate multiple images, just pass in your text with '|' character as a separator.

ex.

$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone|a cat chasing mice|a frog eating a fly'

Note that DALL-E is a full image+text language model. As a consequence, you can also generate text using a DALL-E model.

$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone' --gentext

This will complete the provided text, save it to caption.txt, and generate the corresponding images.

Docker

You can use a docker container to make sure the version of Pytorch and Cuda are correct for training DALL-E. Docker and Docker Container Runtime should be installed.

To build:

docker build -t dalle docker

To run in an interactive shell:

docker run --gpus all -it --mount src="$(pwd)",target=/workspace/dalle,type=bind dalle:latest bash

Distributed Training

DeepSpeed

Thanks to janEbert, the repository is now equipped so you can train DALL-E with Microsoft's Deepspeed!

You can simply replace any $ python <file>.py [args...] command with

$ deepspeed <file>.py [args...] --deepspeed

to use the aforementioned DeepSpeed library for distributed training, speeding up your experiments.

Modify the deepspeed_config dictionary in train_dalle.py or train_vae.py according to the DeepSpeed settings you'd like to use for each one. See the DeepSpeed configuration docs for more information.
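
For instance, a minimal sketch of the kind of dictionary you might end up with (the values are placeholders, not recommendations; consult the DeepSpeed docs for the full schema):

deepspeed_config = {
    'train_batch_size': 32,              # total batch size across all GPUs
    'gradient_accumulation_steps': 1,
    'gradient_clipping': 0.5,
    'fp16': {
        'enabled': True                  # pairs with the --fp16 flag
    },
    'zero_optimization': {
        'stage': 1                       # ZeRO stage; 0 disables ZeRO
    }
}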

DeepSpeed - 32 and 16 bit Precision

As of DeepSpeed version 0.3.16, ZeRO optimizations can be used with single-precision floating point numbers. If you are using an older version, you'll have to pass the --fp16 flag to be able to enable ZeRO optimizations.

DeepSpeed - Apex Automatic Mixed Precision.

Automatic mixed precision is a stable alternative to fp16 which still provides a decent speedup. In order to run with Apex AMP (through DeepSpeed), you will need to install DeepSpeed using either the Dockerfile or the bash script.

Then you will need to install apex from source. This may take a while, and you may see some compilation warnings which can be ignored.

sh install_apex.sh

Now, run train_dalle.py with deepspeed instead of python as done here:

deepspeed train_dalle.py \
    --taming \
    --image_text_folder 'DatasetsDir' \
    --distr_backend 'deepspeed' \
    --amp

Horovod

Horovod offers a stable way for data parallel training.

After installing Horovod, replace any $ python <file>.py [args...] command with

$ horovodrun -np <num-gpus> <file>.py [args...] --distributed_backend horovod

to use the Horovod library for distributed training, speeding up your experiments. This will multiply your effective batch size per training step by <num-gpus>, so you may need to rescale the learning rate accordingly.
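
A sketch of the usual linear scaling rule, assuming horovod.torch is importable (the base learning rate is just a placeholder):

import horovod.torch as hvd

hvd.init()

base_lr = 3e-4                      # hypothetical single-GPU learning rate
scaled_lr = base_lr * hvd.size()    # effective batch size grows with the number of workers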

Custom Tokenizer

This repository supports custom tokenization with YouTokenToMe, if you wish to use it instead of the default simple tokenizer. Simply pass in an extra --bpe_path when invoking train_dalle.py and generate.py, with the path to your BPE model file.

The only requirement is that you use 0 as the padding during tokenization

ex.

$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.model

To create a BPE model file from scratch, firstly

$ pip install youtokentome

Then you need to prepare a big text file that is a representative sample of the type of text you want to encode. You can then invoke the youtokentome command-line tools. You'll also need to specify the vocab size you wish to use, in addition to the corpus of text.

$ yttm bpe --vocab_size 8000 --data ./path/to/big/text/file.txt --model ./path/to/bpe.model

That's it! The BPE model file is now saved to ./path/to/bpe.model and you can begin training!
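
As a quick sanity check, here is a sketch of loading the model with youtokentome and padding with 0, its default <PAD> id (the caption is just a placeholder):

import youtokentome as yttm

bpe = yttm.BPE(model = './path/to/bpe.model')

ids = bpe.encode(['a cat chasing mice'], output_type = yttm.OutputType.ID)[0]

# pad (or truncate) to the model's text_seq_len with 0, YouTokenToMe's default <PAD> id
text_seq_len = 256
ids = ids[:text_seq_len] + [0] * (text_seq_len - len(ids))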

Chinese

You can train with a pretrained Chinese tokenizer offered by Huggingface 🤗 by simply passing in an extra flag --chinese

ex.

$ python train_dalle.py --chinese --image_text_folder ./path/to/data
$ python generate.py --chinese --text '追老鼠的猫'

Citations

@misc{ramesh2021zeroshot,
    title   = {Zero-Shot Text-to-Image Generation}, 
    author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
    year    = {2021},
    eprint  = {2102.12092},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{unpublished2021clip,
    title  = {CLIP: Connecting Text and Images},
    author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
    year   = {2021}
}
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{esser2021taming,
    title   = {Taming Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Robin Rombach and Björn Ommer},
    year    = {2021},
    eprint  = {2012.09841},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{ho2021classifierfree,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho and Tim Salimans},
    booktitle = {NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications},
    year    = {2021},
    url     = {https://openreview.net/forum?id=qw8AKxfYbI}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/RiversHaveWings/status/1478093658716966912}
}
@article{Liu2023BridgingDA,
    title   = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
    author  = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.08612}
}

Those who do not want to imitate anything, produce nothing. - Dali

dalle-pytorch's People

Contributors

afiaka87, asvskartheek, borzunov, grantcelley, haskie-lambda, janebert, jules-samaran, kobiso, lucidrains, mehdidc, muennighoff, naxalpha, robvanvolt, rom1504, sdtblck, sic98, snoop2head, sorrge, theocoombes, tillfalko, wcshin-git, ylsung, yuliang-liu


dalle-pytorch's Issues

Cross-entropy does not get the right input dimensions

I am trying to train the DALLE model on the COCO dataset. The dataset setups are the same as #45. The setup for DALLE is like the follows:

import ...
from dalle.dalle_pytorch import DALLE
from transformers import BertTokenizer
import transformers

def train_dalle(epoch, model, loader, optimizer, tokenizer, device):
    model.train()
    
    for index, (x, y) in enumerate(loader):
        x = x.to(device)
        optimizer.zero_grad()
        
        tokenized = tokenizer(y, return_tensors='pt', padding=True)
        loss = model(
            tokenized['input_ids'].to(device), 
            x, 
            mask=tokenized['attention_mask'].bool().to(device), 
            return_loss=True
        )
        loss.backward()
        optimizer.step()
        
        print(f'Epoch {epoch:2} [{index:4}/{len(loader):4}]: Loss: {loss:.4}')


DEVICE = 'cuda'
EPOCHS = 30

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dalle = DALLE(
    dim = 1024,
    vae = vae,                                 # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = tokenizer.vocab_size,    # vocab size for text
    text_seq_len = 256,                        # text sequence length
    depth = 12,                                # should aim to be 64
    heads = 16,                                # attention heads
    dim_head = 64,                             # attention head dimension
    attn_dropout = 0.1,                        # attention dropout
    ff_dropout = 0.1                           # feedforward dropout
)

dalle.to(DEVICE)

optimizer = optim.Adam(dalle.parameters(), lr=0.001, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.98)

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}'.center(50, '='))
    print(f'Learning Rate: {optimizer.param_groups[0]["lr"]: .3}')
    train_dalle(epoch, dalle, train_loader, optimizer, tokenizer, DEVICE)
    scheduler.step()

But I got the error:

ValueError                                Traceback (most recent call last)
<ipython-input-8-f8cec408bdd6> in <module>
     27         print(f'Epoch {epoch}/{EPOCHS}'.center(50, '='))
     28         print(f'Learning Rate: {optimizer.param_groups[0]["lr"]: .3}')
---> 29         train_dalle(epoch, dalle, train_loader, optimizer, tokenizer, DEVICE)
     30         scheduler.step()
     31 

<ipython-input-7-7b03345d1110> in train_dalle(epoch, model, loader, optimizer, tokenizer, device)
     11             x,
     12             mask=tokenized['attention_mask'].bool().to(device),
---> 13             return_loss=True
     14         )
     15         loss.backward()

~/miniconda3/envs/dalle/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/usr/itetnas04/data-scratch-01/zaishi/data/AllModalities/notebooks/dalle/dalle_pytorch.py in forward(self, text, image, mask, return_loss)
    450         labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)
    451 
--> 452         loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
    453         return loss

~/miniconda3/envs/dalle/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470 

~/miniconda3/envs/dalle/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2272         if target.size()[1:] != input.size()[2:]:
   2273             raise ValueError('Expected target size {}, got {}'.format(
-> 2274                 out_size, target.size()))
   2275         input = input.contiguous()
   2276         target = target.contiguous()

ValueError: Expected target size (8, 1042), got torch.Size([8, 1041])

I am not sure what is causing this error.

confusion

Forgive my confusion. Is there a guide on how to use this? I am used to systems like darknet where I have images and annotations, then I train on them. Finally I have a cfg and weights file which I can use to find examples in an image.

I ran both scripts (train DALL-E and train VAE) but nothing has changed in my folder. How do I get from here to the point where I can give it a sentence and have it produce an image for me?

VAE does not generate reasonable images

I am trying to train the VAE for the COCO dataset. But the generated output looks strange (input above, output below):

image

My code:

import os
import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader
from dalle.dalle_pytorch import DiscreteVAE

IMAGE_SIZE = 256
TRAIN_ANNOT_PATH = '...'
VAL_ANNOT_PATH = '...'
TRAIN_IMAGE_PATH = '...'
VAL_IMAGE_PATH = '...'

class COCODataset(Dataset):
    def __init__(self, annot_path, image_path, image_size=256):
        self._image_path = image_path
        
        with open(annot_path) as file:
            json_file = json.load(file)
        
        self._image_size = image_size
        self._metadata = json_file['images']
        self._captions = {entry['image_id']: entry for entry in json_file['annotations']}
    
    def __getitem__(self, index):
        metadata = self._metadata[index]
        caption = self._captions[metadata['id']]
        image_path = os.path.join(self._image_path, metadata['file_name'])
        image = Image.open(image_path).convert('RGB')
        image = self._crop_image(image)
        x = np.asarray(image) / 255
        return torch.Tensor(x).permute(2, 0, 1), caption['caption']
    
    def _crop_image(self, image):
        width, height = image.size
        min_length = min(width, height)
        
        # center crop
        left = (width - min_length)/2
        top = (height - min_length)/2
        right = (width + min_length)/2
        bottom = (height + min_length)/2
        image = image.crop((left, top, right, bottom))
        
        # resize
        image = image.resize((self._image_size, self._image_size))
        
        return image
    
    def __len__(self):
        return len(self._metadata)
        

def tensor_to_image(tensor):
    tensor = tensor.permute(1, 2, 0)
    arr = ((tensor.numpy())*255).astype(np.uint8)
    return Image.fromarray(arr)

BATCH_SIZE = 16
train_dataset = COCODataset(TRAIN_ANNOT_PATH, TRAIN_IMAGE_PATH, image_size=IMAGE_SIZE)
val_dataset = COCODataset(VAL_ANNOT_PATH, VAL_IMAGE_PATH, image_size=IMAGE_SIZE)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)

def train_vae(epoch, model, loader, optimizer, device):
    model.train()
    
    for index, (x, _) in enumerate(loader):
        x = x.to(device)
        optimizer.zero_grad()
        
        loss = model(x, return_loss=True)
        loss.backward()
        optimizer.step()
        
        print(f'Epoch {epoch:2} [{index:4}/{len(loader):4}]: Loss: {loss:.4}')

DEVICE = 'cuda'
vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = 3,          # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,       # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,      # codebook dimension
    hidden_dim = 64,         # hidden dimension
    num_resnet_blocks = 1,   # number of resnet blocks
    temperature = 0.9,       # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False # straight-through for gumbel softmax. unclear if it is better one way or the other
).to(DEVICE)

EPOCHS = 2
optimizer = optim.Adam(vae.parameters(), lr=0.001, weight_decay=0.0)

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}'.center(50, '='))
    train_vae(epoch, vae, train_loader, optimizer, DEVICE)

vae.eval()
x, y = val_dataset[0]
codes = vae.get_codebook_indices(x.unsqueeze(0).cuda())
generated = vae.decode(codes)[0].detach().cpu()
plt.imshow(tensor_to_image(x))
plt.show()
plt.imshow(tensor_to_image(generated))
plt.show()

I have trained the VAE for 2 epochs (10k+ steps). The training loss decreased from 3.0+ to about 0.07. The output looks like a monotone color even on training-set images. Am I doing anything wrong?

Attention Layer

How are you doing the sparse attention mentioned in the blog?
"The attention mask at each of its 64 self-attention layers allows each image token to attend to all text tokens. DALLยทE uses the standard causal mask for the text tokens, and sparse attention for the image tokens with either a row, column, or convolutional attention pattern, depending on the layer. "

You are alternating between dense and sparse attention instead of alternating between all sparse.

Can you explain that part of the code?

Also, how are you padding the text if the input text is shorter than 256 tokens? Are you changing the sparse attention for each input? Or how are you masking text tokens with different padding for left-to-right generation when a batch contains texts of different lengths?

Image and description fetcher

Hi!

I put together a quick script that fetches images from Bing using descriptions from FFHQ.

It works in the following fashion:

  1. Fetch descriptions from FFHQ
  2. Download one image for each description from Bing
  3. Resize images to specified dimensions.
  4. Save in easy to use format: filename:description

Examples of images with corresponding information:
image
1069164182.png:A group of young children walking

image
6458676875.png:A man holding flowers runs from

It takes 3.46 hours to scrape 10 000 image+desc pairs. What do you think? Could the fetched data be suitable to use for training DALL-E?

Missing KL divergence term in the loss for the VAE?

Thanks for the repo! Was skimming around and found

loss = F.mse_loss(img, out)

where I'm a bit nervous that you may be missing the KL Divergence term to the prior in the -ELBO, which in this case would be a diffuse uniform.

  • Eric's tutorial here
  • Eric's implementation here implements this under KL_qp
  • Another random implementation here
  • My own implementation when I was deriving this today. Though it skips a few steps of derivation, could have been clearer

Without this term you're potentially going to see "index collapse", with many of the dimensions going unused.

In any case, in my preliminary experiments I'm finding this version a little finicky to tune so far; the VQ-VAE version worked more out of the box and trained faster for me, though only after a somewhat gross data-dependent init scheme.

Tensor device error when computing kl_div in VAE

Hello, @lucidrains :)

I saw 706f06d this commit, and I got below error when I trained with it.

Traceback (most recent call last):
  File "/root/.pycharm_helpers/pydev/pydevd.py", line 1434, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/train_VAE.py", line 237, in <module>
    main()
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/train_VAE.py", line 164, in main
    loss = get_loss(images)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/train_VAE.py", line 145, in get_loss
    loss = F.smooth_l1_loss(images, recons) + vae(images, return_loss=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/models/model_arch.py", line 173, in forward
    kl_div = (qy * (log_qy - g)).sum(dim = (1, 2)).mean()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The reason is that g is constructed on the CPU rather than on the current device.

g = torch.log(torch.Tensor([1. / num_tokens]))

This can be solved by

device = img.device
....
g = torch.log(torch.tensor([1. / num_tokens], device=device))

Please check if it is a correct way to fix :)

Influence of batch size on training convergence in CLIP

I asked the same question in the official CLIP repository.

According to the paper they use a large batch size of ~32k samples which means that the raw untrained network initially has a chance of ~1/32k of predicting the correct pair.

I am wondering how the convergence/learning process would differ if instead a binary classification problem were formulated, and the network were presented alternately with matching text/image pairs and non-matching pairs and tasked with predicting whether those samples are actually in agreement or not.

In other words, what does the softmax over ~32k entries prior to the cross-entropy calculation bring to the table that cannot be achieved more conveniently by using a sigmoid and binary cross entropy to predict matching/non-matching pairs? As a side effect, this would also remove the dependence on the batch size, which seems to be rather crucial.

Easy to use dataset for DALL-E

The COCO dataset is a high-quality dataset both in terms of images and text. Each image has multiple captions and it consists of around 100 000 images. The script takes around 40 minutes to complete on Google Colab Pro. I have created a Google Colab which does the following:

  1. Fetches the COCO images and captions
  2. Allows the user to specify image dimensions which the images will be resized to
  3. Stores information in two files for easy access

od-captions.txt (<image_path> : <image_caption>)

train2017/000000203564.jpg : A bicycle replica with a clock as the front wheel.
train2017/000000322141.jpg : A room with blue walls and a white sink and door.
train2017/00000016977.jpg : A car that seems to be parked illegally behind a legally parked car
train2017/000000106140.jpg : A large passenger airplane flying through the air.
train2017/000000571635.jpg : A bathroom with a toilet, sink, and shower.

od-captionsonly.txt (<image_caption>)

A bicycle replica with a clock as the front wheel.
A room with blue walls and a white sink and door.
A car that seems to be parked illegally behind a legally parked car
A large passenger airplane flying through the air.
A bathroom with a toilet, sink, and shower.

Here is an example image and caption:
image
The man at bat readies to swing at the pitch while the umpire looks on.

I have written this hoping it will be somewhat compatible with @htoyryla's code. I think it should work out-of-the-box. @htoyryla's training script is also included in the Colab. Feel free to use this to generate a dataset for DALL-E and/or to continue on @htoyryla's work.

about version of vq-vae

Hello, I'm wondering which VQ-VAE model you are using. Is it VQ-VAE-1 or 2?
Thanks in advance

Continue training on Dalle model

I haven't found a function for continuing training on the DALL-E model - is it analogous to loading the VAE in train_dalle.py?

vae_path = Path(args.vae_path)
assert vae_path.exists(), 'VAE model file does not exist'

loaded_obj = torch.load(str(vae_path))

vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

vae = DiscreteVAE(**vae_params)
vae.load_state_dict(weights)

hparams missing

  File "train_dalle.py", line 43, in <module>
    vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
KeyError: 'hparams'

After training the VAE I wanted to train DALL-E next, but after successfully loading the VAE, hparams is not defined; see the terminal output above. Is this a problem with the trained VAE, or is the bug in train_dalle.py?

Thanks in advance! Great project! :-)

Can't install successfully

Because of the network and my country's policy, I can't run the command "git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed" successfully on my Ubuntu machine,

picture-1

so I used a VPN to download the file "DeepSpeed-master" (picture-2) on my Windows 10 machine and then copied it into Linux,

picture-2

but there is a problem, described in picture-3 and picture-4.

picture-3

picture-4

I don't know why I get "python: command not found", and I am certain that setup.py is in the same directory as the file "install.sh".
I only downloaded the file "DeepSpeed-master"; is there any other file that needs to be downloaded?

ignore_index defined but not used

ignore_index is defined but not used.
I assume the purpose is to ignore the padding from the text. If that's the case, then instead of using:
F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
we have to define an nn.CrossEntropyLoss module:
self.loss = nn.CrossEntropyLoss(reduction='none', ignore_index=self.ignore_index)
and then call:
self.loss(rearrange(logits, 'b n c -> b c n'), labels)

Finally, the default value of ignore_index should be 0, since that's the pad_id of the default tokenizer.

CLIP reranks images not visual tokens

I think, according to the DALL-E blog post, CLIP is used to rerank the decoded images rather than the visual tokens. Therefore, a CLIP model pre-trained on raw images should be used.

CLIP different text and image embedding sizes

As per the CLIP paper's pseudo-code, it supports different embedding sizes for the image encoder and the text encoder. To do this:

  1. They have two additional linear layers - One for Text Encoded to Text embedded and another for images.
  2. These multimodal embeddings are L2 Normalized
  3. The scaling factor used in the current implementation (dim**-0.5) is different from pseudo-code, which uses a learned temperature parameter.

Partial Pseudo-Code -

I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

Typical training results

Hi, great repo and thanks for sharing your work!

I'm trying your bird dataset example from the Colab with OpenAI's pretrained VAE. I wasn't able to get meaningful results so far on the Colab or on my own vm (Tesla T4 GPU).

13 epochs into train_dalle.py and I am only seeing these kinds of results:
image

On my vm I ran $ python train_dalle.py --image_text_folder /parent/to/birds/dataset/directory without changing any of the code (only replaced wandb with another experiment tracking framework, but I doubt that should make a difference)

Should the bird dataset work better with the pretrained VAE? Can you share some results or common training parameters/times/number of epochs?

Deepspeed sparse attention error

Hello!
I've been using DeepSpeed sparse attention pretty well, but I have been getting this error since the 0.0.59 release.

Traceback (most recent call last):
  File "train_DALLE.py", line 300, in <module>
    main()
  File "train_DALLE.py", line 219, in main
    loss = dalle(caption_tokens, images, mask=mask, return_loss=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/models/model_arch.py", line 421, in forward
    out = self.transformer(tokens, mask = mask)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/dalle_pytorch/transformer.py", line 106, in forward
    return self.layers(x, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/dalle_pytorch/reversible.py", line 139, in forward
    x = x + f(x, **f_args)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/dalle_pytorch/transformer.py", line 34, in forward
    return self.fn(self.norm(x), **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/dalle_pytorch/attention.py", line 314, in forward
    if self.noncausal_attn_len:
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 779, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'SparseAttention' object has no attribute 'noncausal_attn_len'

I guess the error occurs because self.noncausal_attn_len was removed from the Attention class in commit 95ce537.

I hope DeepSpeed sparse attention will continue to be supported, because it reduces the model size by half, gives a 2x speedup, and the final performance seems better :)

DALL-E Image Embedding

A token is any symbol from a discrete vocabulary; for humans, each English letter is a token from a 26-letter alphabet. DALL·E's vocabulary has tokens for both text and image concepts. Specifically, each image caption is represented using a maximum of 256 BPE-encoded tokens with a vocabulary size of 16384, and the image is represented using 1024 tokens with a vocabulary size of 8192.
The images are preprocessed to 256x256 resolution during training. Similar to VQVAE, each image is compressed to a 32x32 grid of discrete latent codes using a discrete VAE that we pretrained using a continuous relaxation. We found that training using the relaxation obviates the need for an explicit codebook, EMA loss, or tricks like dead code revival, and can scale up to large vocabulary sizes.

We can use OpenAI's CLIP implementation to filter the good samples, but I would assume they didn't use it to create the embedding.
So we could assume they used some kind of VQ-VAE? For example https://github.com/openai/vdvae or https://github.com/NVlabs/NVAE ?

So this repo should have 2-step training:
Step 1 - Pretrain an autoencoder to tokenize the images. We could go small first and do it with a 16x16 embedding and a relatively low vocab size (2k-4k?).
Step 2 - Train the decoder transformer. Here we should have a preprocessing step to convert the image-text pairs to tokens: some Huggingface tokenizer for the text and the encoder of the VQ-VAE for the image.

We hope that someone will offer pretrained model weights for CLIP to remove bad samples during inference. If it was trained on something like the Microsoft dataset, then it should be general enough for most use cases.

Some Open Questions:

  • They use sparse attention for the image part. We could just use full attention for the whole network for now, or go fully sparse?
  • If it's not a VQ-VAE, which GANs work well with discrete latent values?
  • If it's a VQ-VAE, it's some kind of hierarchical one. Does DALL-E model only the first latent level, with the rest just randomly sampled during reconstruction?

Impact of the convolutional Encoder/Decoder

I am wondering about the actual importance of the encoder/decoder architecture for training the dVAE. How much of a factor is size here? Phil, you are using a very simple, pure, but shallow encoder without any normalization layers interleaved. Others use quite intricate structures with all sorts of nifty tricks between layers. The only common factor seems to be that they all use resnet-like skip connections in one form or another.

  • What would be a reason to leave out norm layers in the encoder/decoder?
  • What role does channel depth play? (Your encoder uses the same channel dim along its entire length.)
  • In the official DALL-E implementation they even go so far as to use max-pooling ops, which are known to discard vital information.

Trained on COCO

I trained a 6 layer 8 head model on the COCO dataset (15 epochs, 82k images, 1x Tesla V100, batch size 12, no CLIP filtering). The results look as follows (conditioned on text "A car on the street"):
(8 generated samples omitted)
There are more or less some structures, but they do not match the description "A car on the street" well (I can only recognize a "car" in a few samples). Maybe the model and the dataset are still too small to learn the text conditioning.

General questions to the algorithmic understanding

I have been trying to get a grasp of the DALL-E code recently. However, there are a couple of things I can't quite wrap my head around, and since the paper is not published yet, I was wondering if we could maybe clarify them here.

So there is the VAE training which basically features the codebook in the bottleneck and is trained a priori.

Next, DALL-E receives text and image pairs, embeds them, and adds positional encodings individually to both modalities.
However, the image data is not embedded as in, e.g., ViT, but by downsampling it via the encoder of the VAE (without accumulating gradients), taking the argmax over the feature dimension for each downsampled image patch, and finally indexing into the previously trained codebook.

The resulting representations of both modalities are then concatenated along the token dimension. And while every word of the text input is one token, the height and width of the VAE-encoded image yields the number of image tokens.

The combined embedding is then passed into a single transformer, which calculates self-attention not only within each modality but also across both modalities, if I am not mistaken.

A masking of the form

mask = torch.ones_like(text).bool()

results in unmasked attention calculation, right?

A final MLP maps the transformer output to all potential token possibilities (both text and image).

Then I don't understand the masking:

  logits_mask = (
         ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
         ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
         ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
     )

shouldn't there be one more row concerned with the text input and one fewer row for the image input?

For the following config with 3 text input tokens

vae = DiscreteVAE(
    image_size = 64,
    num_layers = 5,
    num_tokens = 10,
    codebook_dim = 256,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).cuda()

dalle = DALLE(
    dim = 256,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 4,    # vocab size for text
    text_seq_len = 3,         # text sequence length
    depth = 6,                 # should aim to be 64
    heads = 8,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
).cuda()

text = torch.randint(0, 4, (1, 3)).cuda()
images = torch.randn(1, 3, 64, 64).cuda()
mask = torch.ones_like(text).bool().cuda()

the mask looks like this

tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False]]])

shouldn't it be:

tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False]]])

The purpose of the masking is so that image tokens don't contribute to the predictions of text and vice versa.
The code proceeds by constructing labels from the integer text tokens and from the codebook indices of the VAE image encoding.

But what is it we are actually trying to predict with this classification task here?
It is a 2d CrossEntropyLoss where for each token (either text or image) we are trying to predict ... exactly what?
I think I am missing the intuition here...

And then, why is the label vector neglecting the very first label entry but using the EOS entry?

    loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])

Maybe someone can help me (and others) better understand what's going on here. Thank you in advance.

Simple training fails (?)

Hi,

I made a toy synthetic dataset with various shapes and their text descriptions and trained DALL-E on it: https://gist.github.com/sorrge/ef0b4cbf53c3496a09596d53663655f9

The dataset is deterministic (no noise, no variation), with 9216 samples in total. Pictures are 32x32, and texts are up to 7 words. VAE is trained well w.r.t. reconstruction. I then train DALL-E on the whole dataset, so the task amounts to memorizing 9216 sequences of up to 23 tokens. I suppose that this should be easy enough for a network of this size.

However, it can't reach 100% accuracy, even after visibly converging after 500 epochs. The frequency of getting a perfect token reconstruction is around 55%.

I've investigated where the mistakes are made. It turns out that almost all errors happen in the first image token (see plots at the bottom of the notebook). Moreover, the errors are strongly correlated with the presence of masked text tokens at the end (where the text is shorter than the 7 token maximum). It never makes mistakes when the text has no masked tokens. This leads me to believe that there is something wrong with how I use masking, or there is a bug related to masking.

Can you help me to track down the problem? Also, this notebook can serve as a tutorial for DALL-E training, after the problem is figured out. It is self-contained, and the whole run takes about 4 hours.

Thanks!
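In case it helps with debugging, here is a minimal sketch of how the text mask is usually built for padded captions (the pad id and the variable names are my assumptions, not the notebook's code):

import torch

PAD_ID = 0
text = torch.tensor([[5, 12, 3, 0, 0, 0, 0],    # short caption padded to the 7-token maximum
                     [7,  9, 4, 2, 8, 6, 1]])   # full-length caption
mask = text != PAD_ID                           # True for real tokens, False for padding
# loss = dalle(text, images, mask = mask, return_loss = True)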

Axial Positional Embedding defined wrong?

I think Axial Positional Embedding is initialized wrong.

self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size))
where image_size is the image size we pass to the VAE (default 256), so the maximum sequence length becomes 256*256.
It wouldn't make a big difference, but this should be set to the size of the VAE's discrete latent grid, e.g. 16 (16x16 = 256 image tokens).
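A hedged sketch of the suggested fix (the variable names are mine, not the repo's): derive the axial shape from the VAE's latent grid rather than from the pixel resolution.

from axial_positional_embedding import AxialPositionalEmbedding

dim = 1024
image_size = 256
num_layers = 3                                # the VAE halves the resolution this many times
fmap_size = image_size // (2 ** num_layers)   # e.g. 256 // 8 = 32, i.e. a 32 x 32 token grid
image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (fmap_size, fmap_size))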

DataParallel Support for DallE (RuntimeError: Unconvertible NCCL type)

Hello! I was trying to train DALL-E with DataParallel on a multi-GPU machine but encountered the following error:

Traceback (most recent call last):
  File "minimal_reproduction.py", line 67, in <module>
    main()
  File "minimal_reproduction.py", line 62, in main
    loss = dalle(caption_ids, image_tokens, mask = mask, return_loss = True)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 160, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 165, in replicate
    return replicate(module, device_ids, not torch.is_grad_enabled())
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/replicate.py", line 103, in replicate
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/replicate.py", line 67, in _broadcast_coalesced_reshape
    return comm.broadcast_coalesced(tensors, devices)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/comm.py", line 56, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: Unconvertible NCCL type

Does this implementation of DALL-E support multi-GPU training with DataParallel? Or is the error between the keyboard and the chair 😄? I tested this with dalle-pytorch==0.0.59.
I didn't encounter this on my local machine, but that only has a single GPU so not exactly apples-to-apples.


Here's a gist with a training script to reproduce the error: https://gist.github.com/johnmccain/ae2d9f376abff90dfd5cd4c9b36fa7e5
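The wrapping in question is just the standard nn.DataParallel pattern; a rough sketch with placeholder hyperparameters (not the gist's exact code):

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(image_size = 256, num_layers = 3, num_tokens = 8192,
                  codebook_dim = 512, hidden_dim = 64)
dalle = DALLE(dim = 512, vae = vae, num_text_tokens = 10000,
              text_seq_len = 256, depth = 2, heads = 8, dim_head = 64)

dalle = torch.nn.DataParallel(dalle).cuda()

text = torch.randint(0, 10000, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
mask = torch.ones(4, 256).bool().cuda()

# The failure happens when DataParallel replicates the module's buffers for the forward pass.
loss = dalle(text, images, mask = mask, return_loss = True)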

Environment (from https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py):

PyTorch version: 1.7.1+cu110
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB
GPU 4: Tesla V100-SXM2-16GB
GPU 5: Tesla V100-SXM2-16GB
GPU 6: Tesla V100-SXM2-16GB
GPU 7: Tesla V100-SXM2-16GB

Nvidia driver version: 450.80.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] dalle-pytorch==0.0.59
[pip3] numpy==1.17.4
[pip3] torch==1.7.1+cu110
[pip3] torchaudio==0.7.2
[pip3] torchvision==0.8.2+cu110
[conda] Could not collect

Pretrained models

The image generation takes a good amount of time because of the training, as far as I understand.

When the pretrained models are released, how big will the pretrained model be, how long will image generation take, and how much computing power will it require?

To my understanding, the training of the model usually takes a long time, but with pretrained models the results would be there "in an instant" more or less - or is there more to it?

Best regards

missing clip?

I was able to train the vae, but I am now stuck on the dalle training step. It is telling me:
loading state dict from vae.pt
preloading txt files...
processing 72496/72496
done preloading txt onehots
Train Epoch: 1 [0/72496 (0%)] Loss: 0.569296
Traceback (most recent call last):
  File "dalle.py", line 615, in <module>
    main()
  File "dalle.py", line 608, in main
    train_dalle(vae, args)
  File "dalle.py", line 331, in train_dalle
    if clip is not None:
UnboundLocalError: local variable 'clip' referenced before assignment
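Without seeing your dalle.py I can only guess, but an UnboundLocalError like this usually means the branch that assigns clip never ran. A hedged sketch of the usual fix (the flag name and loading step are placeholders, not the script's actual code):

import torch

def train_dalle(vae, args):
    clip = None                              # define it up front so every path reaches a bound name
    if getattr(args, 'clip_path', None):     # hypothetical flag, replace with your own condition
        clip = torch.load(args.clip_path)    # hypothetical way of loading a CLIP checkpoint
    # ... training loop ...
    if clip is not None:
        pass                                 # CLIP-based ranking would go here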

Results

I have trained DiscreteVAE on the 128x128 FFHQ dataset, using this configuration:

vae = DiscreteVAE(
    num_layers = 2,
    num_tokens = 4096,
    dim = 1024,
    hidden_dim = 256
)

Here are the results after 3 epochs (top original, bottom reconstructed):

[sample images: originals on top, reconstructions below]

Pretrained models

I myself have trained a solid VAE model and am currently working on the DALL-E model. How come no one has posted a pretrained model yet? It seems like a waste of resources if everyone has to start from scratch.

I'm on day 3 of training and still at a loss of 2.5. I will have to increase the total number of images and include sparse attention next, as the 2.5 looks like a plateau with the current settings.

Ready to use?

Am I right in assuming this repo is not ready for generating images?

Colab?

I have a dataset but not the coding skills to train on my own. Would you mind sharing a Colab notebook? Thanks in advance.

performance optimization in train_vae.py and train_dalle.py scripts

Wondering if anyone has configuration tips that improve the performance of VAE and DALL-E training. I'm using the COCO dataset, which has 118,287 images. Training currently processes about 500 images per minute, which would take about 4 hours for a single epoch. I have an RTX 3090 and am only utilizing 25% of its VRAM. I have the train_vae.py script configured to run a batch size of 10 with 32 workers:

dl = DataLoader(ds, BATCH_SIZE, num_workers=32, shuffle = True)

I've found that increasing my batch size (all the way to 56) increased GPU utilization but ran 5x slower. Wondering if there are other settings I can play around with to improve performance, or whether I should just trim my dataset down considerably.
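For anyone tuning the same script, a hedged sketch of DataLoader settings that often help keep a single fast GPU fed (the numbers are guesses, not values I have benchmarked on COCO):

from torch.utils.data import DataLoader

dl = DataLoader(
    ds,
    batch_size = 24,            # raise gradually until throughput stops improving
    num_workers = 8,            # roughly the number of physical cores; 32 can oversubscribe the CPU
    shuffle = True,
    pin_memory = True,          # faster host-to-GPU copies
    persistent_workers = True,  # avoid re-spawning workers every epoch (PyTorch >= 1.7)
    drop_last = True,
)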

Trained DALL-E result

Hi, you showed good training results.

I trained on Flickr8k for 12 epochs, but the inference result is just a pink background. I can't understand why.

How many epochs did you train for, and what data did you use?

Thanks

RuntimeError: dlsym unable to load function

Hi, when I was running the example code for CLIP, I got the following error. It seems to me that something in triton is not working. I went through a lot of websites, but there were not many discussions about this. Could anyone please show me how to fix it? Many thanks!

RuntimeError Traceback (most recent call last)
in
20 mask = torch.ones_like(text).bool().to(device)
21
---> 22 loss = clip(text, images, text_mask = mask, return_loss = True)
23 loss.backward()

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/dalle_pytorch/dalle_pytorch.py in forward(self, text, image, text_mask, return_loss)
213 image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))
214
--> 215 enc_text = self.text_transformer(text_emb, mask = text_mask)
216 enc_image = self.visual_transformer(image_emb)
217

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/dalle_pytorch/transformer.py in forward(self, x, **kwargs)
200
201 def forward(self, x, **kwargs):
--> 202 return self.layers(x, **kwargs)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/dalle_pytorch/reversible.py in forward(self, x, **kwargs)
137
138 for (f, g), (f_args, g_args) in layers_and_args:
--> 139 x = x + f(x, **f_args)
140 x = x + g(x, **g_args)
141 return x

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/dalle_pytorch/transformer.py in forward(self, x, **kwargs)
34
35 def forward(self, x, **kwargs):
---> 36 return self.fn(self.norm(x), **kwargs)
37
38 class GEGLU(nn.Module):

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/dalle_pytorch/transformer.py in forward(self, x, mask)
158 attn_mask[ind, ind] = 0.
159
--> 160 out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
161 out = rearrange(out, 'b h n d -> b n (h d)')
162 out = self.to_out(out)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/sparse_self_attention.py in forward(self, query, key, value, rpe, key_padding_mask, attn_mask)
150
151 # attention scores
--> 152 attn_output_weights = sparse_dot_sdd_nt(query, key)
153 attn_output_weights = sparse_softmax(
154 attn_output_weights,

~/anaconda3/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/matmul.py in call(self, a, b)
744 db_packs,
745 self.bench,
--> 746 time_db)
747 self.time_c = time_c[0]
748 self.time_da = time_da[0]

~/anaconda3/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/matmul.py in forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, c_bench, c_time, da_lut, da_num_locks, da_width, da_packs, da_bench, da_time, db_lut, db_num_locks, db_width, db_packs, db_bench, db_time)
548 c_packs,
549 c_bench,
--> 550 c_time)
551 # save for backward
552 ctx.save_for_backward(a, b)

~/anaconda3/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/matmul.py in _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs, bench, time)
226 width - off_width),
227 AS0],
--> 228 bench=bench)
229 total = total + current if bench else None
230 time[0] = total

~/anaconda3/lib/python3.7/site-packages/triton/kernel.py in call(self, *args, **kwargs)
84 params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
85 torch.cuda.synchronize()
---> 86 torch.ops.triton.launch_kernel(self.op_id, device, params)

RuntimeError: dlsym unable to load function

codebook keeps getting trained during DALLE training

self.image_emb = vae.codebook

Right now, neither an appropriate no_grad call nor manually calling codebook.requires_grad_(False) prevents the pretrained VAE codebook from getting further adjusted during the subsequent DALL-E training procedure.

I doubt this is meant to be the case.

Training of the VAE encoder part is rightfully disabled by the associated decorator, but this does not pertain to the codebook. Maybe I am missing something here? I just wanted to draw attention to this point.
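For concreteness, this is the kind of explicit freeze I would have expected to be sufficient (a sketch only; vae and dalle stand for the objects from the usual training setup, and this is not the repo's code):

import torch

# Freeze the shared codebook (self.image_emb = vae.codebook points at the same parameters) ...
for p in vae.codebook.parameters():
    p.requires_grad_(False)

# ... and only hand trainable parameters to the DALL-E optimizer.
optimizer = torch.optim.Adam(
    (p for p in dalle.parameters() if p.requires_grad),
    lr = 3e-4,
)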

CLIP: Loss in implementation vs. in paper

Hey!

In your implementation of CLIP, F.cross_entropy is called once with an implicit F.log_softmax
being computed along dimension 1 (because F.cross_entropy internally calls F.log_softmax(input, dim=1)):

sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
labels = torch.arange(b, device = device)
loss = F.cross_entropy(sim, labels)
return loss

Whereas in the CLIP-Paper, they compute the cross-entropy-loss on both dimension 0 and 1 and then take an average:

[figure: the symmetric cross-entropy loss pseudocode from the CLIP paper]

So unless I am missing something here, one possible solution would just be something like

sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp 
labels = torch.arange(b, device = device) 
loss_t = F.cross_entropy(sim, labels) 
loss_i = F.cross_entropy(sim.T, labels) 
return (loss_t + loss_i) / 2. 

which, in this case, with sim being of shape (n, n), should work?
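For anyone who wants to sanity-check the shapes, here is a self-contained toy version of the symmetric loss with random normalized latents (a sketch, not the repo's code):

import torch
import torch.nn.functional as F
from torch import einsum

b, d = 8, 64
text_latents  = F.normalize(torch.randn(b, d), dim = -1)
image_latents = F.normalize(torch.randn(b, d), dim = -1)
temp = 1.0

sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
labels = torch.arange(b)

loss_t = F.cross_entropy(sim, labels)      # softmax over images for each text
loss_i = F.cross_entropy(sim.t(), labels)  # softmax over texts for each image
loss = (loss_t + loss_i) / 2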

Arbitrary vocab size in text embedding might lead to undesired outcomes.

The BPE tokenizer you provide (the one from CLIP) has a vocab size of around 40k.
When training DALL-E on your own dataset, the default vocab size of the text embedding is 10k.
In most cases this won't lead to any problems; however, depending on the words, you might have tokens with an id over 10k. One example is "mental health", which has a token id around 20k.
I went through the code and haven't seen anything that deals with this problem. If that's the case, here are my suggestions:

  1. Just set the default value of the text embedding size to the vocab size of the tokenizer, or at least warn the user in the README.md.
  2. Before training the model, go through the data loader, get the largest id in the dataset, and size the model accordingly (see the sketch after this list).
  3. A more elegant solution would be to build a set with all the ids in the dataset and define the size of the embedding accordingly. The disadvantage is that we need an extra map to translate between the ids, since they won't necessarily match the embedding indices.

I could implement solution 2 or 3 if needed, but I think just making the user aware of the problem should be enough.
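As a reference point, here is a hedged sketch of suggestion 2 (tokenized_texts, vae, and the model hyperparameters are placeholders, not code from this repo):

import torch
from dalle_pytorch import DALLE

# Scan the tokenized captions once and size the text embedding from the largest id present.
max_id = 0
for tokens in tokenized_texts:               # placeholder: however your dataset stores caption tokens
    max_id = max(max_id, int(tokens.max()))

dalle = DALLE(dim = 512, vae = vae,          # vae assumed to be the pretrained DiscreteVAE
              num_text_tokens = max_id + 1,  # embedding must cover ids 0 .. max_id
              text_seq_len = 256, depth = 2, heads = 8, dim_head = 64)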
