
fast-dit's Introduction

Scalable Diffusion Models with Transformers (DiT)
Improved PyTorch Implementation

Paper | Project Page | Run DiT-XL/2 on Hugging Face Spaces | Open in Colab


This repo features an improved PyTorch implementation of the paper Scalable Diffusion Models with Transformers.

Setup

First, download and set up the repo:

git clone https://github.com/chuanyangjin/fast-DiT.git
cd fast-DiT

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate DiT

Sampling


Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with sample.py. Weights for the pre-trained model you select are downloaded automatically. The script has various arguments to switch between the 256x256 and 512x512 models, adjust the number of sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

python sample.py --image-size 512 --seed 1
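The classifier-free guidance scale (cfg-scale) blends a class-conditional and an unconditional noise prediction. A minimal sketch of the standard guidance formula on plain Python lists; `apply_cfg` is a hypothetical helper, and sample.py may apply guidance to only part of the model output, so treat this as an illustration of the idea rather than the script's exact code:

```python
def apply_cfg(eps_cond, eps_uncond, cfg_scale):
    """Classifier-free guidance: move the prediction away from the unconditional one.

    eps_cond / eps_uncond: equal-length lists of floats (flattened noise predictions).
    cfg_scale == 1.0 reproduces eps_cond (no guidance); larger values guide harder.
    """
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("predictions must have the same length")
    return [u + cfg_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


cond = [0.5, -0.2, 1.0]   # prediction given the class label
uncond = [0.1, 0.0, 0.4]  # prediction given the "null" label
print(apply_cfg(cond, uncond, 1.0))  # no guidance
print(apply_cfg(cond, uncond, 4.0))  # stronger guidance
```

With cfg_scale=1 the unconditional term cancels out, which is why a scale of 1 means "without guidance".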

For convenience, our pre-trained DiT models can be downloaded directly here as well:

DiT Model | Image Resolution | FID-50K | Inception Score | Gflops
XL/2      | 256x256          | 2.27    | 278.24          | 119
XL/2      | 512x512          | 3.04    | 240.82          | 525
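The "/2" and "/4" in model names such as DiT-XL/2 and DiT-L/4 are the patch size used to split the VAE latent into transformer tokens. A quick sketch of the resulting sequence lengths, assuming the Stable Diffusion VAE that DiT uses (each spatial side downsampled by 8); `dit_num_tokens` is a hypothetical helper. The 512x512 model processes four times as many tokens, which is consistent with its much higher Gflops:

```python
def dit_num_tokens(image_size: int, patch_size: int, vae_downsample: int = 8) -> int:
    """Number of transformer tokens for a square image.

    Assumes the VAE shrinks each spatial side by `vae_downsample` and the latent
    is split into non-overlapping patch_size x patch_size patches.
    """
    if image_size % vae_downsample:
        raise ValueError("image size must be divisible by the VAE downsampling factor")
    latent_size = image_size // vae_downsample
    if latent_size % patch_size:
        raise ValueError("latent size must be divisible by the patch size")
    return (latent_size // patch_size) ** 2


print(dit_num_tokens(256, 2))  # 256  (DiT-XL/2, 256x256: 32x32 latent)
print(dit_num_tokens(512, 2))  # 1024 (DiT-XL/2, 512x512: 64x64 latent)
print(dit_num_tokens(256, 4))  # 64   (DiT-L/4, 256x256)
```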

Custom DiT checkpoints. If you've trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
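EMA weights are an exponential moving average of the model parameters, updated after each optimizer step. A minimal sketch of that update, assuming the decay of 0.9999 used by the original DiT training script; parameters are a plain dict of floats instead of tensors, and `update_ema` here is an illustrative stand-alone helper rather than the repo's exact function:

```python
def update_ema(ema_params: dict, model_params: dict, decay: float = 0.9999) -> None:
    """In-place EMA update: ema <- decay * ema + (1 - decay) * model."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


ema = {"w": 0.0}
model = {"w": 1.0}
for _ in range(3):
    update_ema(ema, model, decay=0.5)  # small decay only to make the effect visible
print(ema["w"])  # 0.875
```

With a decay as high as 0.9999 the EMA changes slowly, which smooths out step-to-step noise in the weights and is why sampling typically uses the EMA copy.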

Training

Preparation Before Training

To extract ImageNet features with 1 GPU on one node:

torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features

Training DiT

We provide a training script for DiT in train.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning.
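Training adds noise to the pre-extracted latents and teaches the model to predict that noise from the noisy latent, the timestep, and the class label. A minimal stdlib sketch of the forward (noising) step, assuming the linear beta schedule over 1000 steps (beta from 1e-4 to 0.02) used by the original DiT; the helper names are hypothetical, and DiT's full objective also learns the noise variance, which this sketch omits:

```python
import math
import random


def linear_alpha_bars(num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def q_sample(x0, t, alpha_bars, rng=random):
    """Noise clean latents x0 to timestep t; return (x_t, eps).

    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, with eps ~ N(0, 1).
    The network is trained to predict eps (e.g. with an MSE loss).
    """
    a = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return x_t, eps


alpha_bars = linear_alpha_bars()
x_t, eps = q_sample([0.3, -1.2, 0.8], t=500, alpha_bars=alpha_bars, rng=random.Random(0))
```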

To launch DiT-XL/2 (256x256) training with 1 GPU on one node:

accelerate launch --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features

To launch DiT-XL/2 (256x256) training with N GPUs on one node:

accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
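With N processes, each GPU handles only a slice of the global batch. A minimal sketch of that split, assuming (as the original DiT train.py enforces for its global batch size) that the global batch must divide evenly across processes; `per_process_batch_size` is a hypothetical helper:

```python
def per_process_batch_size(global_batch_size: int, num_processes: int) -> int:
    """Split a global batch evenly across the processes started by the launcher."""
    if global_batch_size % num_processes != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by {num_processes} processes"
        )
    return global_batch_size // num_processes


print(per_process_batch_size(256, 8))  # 32 samples per GPU
```

Keeping the global batch fixed while changing N keeps the optimization comparable across GPU counts; only the per-GPU memory load changes.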

Alternatively, you can use the scripts located in the training options folder to extract features and train.

PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results than the JAX-trained models, within reasonable random variation. Some data points:

DiT Model | Train Steps | FID-50K (JAX Training) | FID-50K (PyTorch Training) | PyTorch Global Training Seed
XL/2      | 400K        | 19.5                   | 18.1                       | 42
B/4       | 400K        | 68.4                   | 68.9                       | 42
B/4       | 400K        | 68.4                   | 68.3                       | 100

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the mse VAE decoder and without guidance (cfg-scale=1).
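Sampling with 250 DDPM steps uses a subset of the timesteps the model was trained on. A simplified sketch of picking evenly spaced timesteps, assuming a 1000-step training schedule (DiT's default); `evenly_spaced_timesteps` is a hypothetical helper, similar in spirit to the respacing in the ADM-derived diffusion code but not guaranteed to round identically:

```python
def evenly_spaced_timesteps(num_timesteps: int = 1000, num_sampling_steps: int = 250) -> list:
    """Pick num_sampling_steps roughly evenly spaced timesteps from [0, num_timesteps - 1]."""
    if not 1 < num_sampling_steps <= num_timesteps:
        raise ValueError("need 1 < num_sampling_steps <= num_timesteps")
    stride = (num_timesteps - 1) / (num_sampling_steps - 1)
    return [round(i * stride) for i in range(num_sampling_steps)]


steps = evenly_spaced_timesteps()
print(len(steps), steps[:4], steps[-1])
```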

Improved Training Performance

Compared with the original implementation, we add several features that speed up training and reduce memory use, including gradient checkpointing, mixed precision training, and pre-extracted VAE features, resulting in a 95% speed increase and a 60% memory reduction on DiT-XL/2. Some data points, using a global batch size of 128 on an A100:

Gradient checkpointing | Mixed precision training | Feature pre-extraction | Training speed | Memory
❌                     | ❌                       | ❌                     | -              | out of memory
✔                      | ❌                       | ❌                     | 0.43 steps/sec | 44045 MB
✔                      | ✔                        | ❌                     | 0.56 steps/sec | 40461 MB
✔                      | ✔                        | ✔                      | 0.84 steps/sec | 27485 MB
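The relative gains can be recomputed from the table. A small worked example that uses the gradient-checkpointing-only row as the baseline (the first configuration that fits in memory); the row labels are shorthand introduced here. The speedup reproduces the 95% figure; the memory saving relative to this row is about 38%, so the table alone does not reproduce the 60% figure, which presumably uses a baseline not listed:

```python
ROWS = {
    "checkpointing": (0.43, 44045),                       # (steps/sec, MB)
    "checkpointing+fp16": (0.56, 40461),
    "checkpointing+fp16+pre-extraction": (0.84, 27485),
}


def improvements(rows, baseline):
    """Percent speedup and percent memory saving of each row relative to `baseline`."""
    base_speed, base_mem = rows[baseline]
    return {
        name: ((speed / base_speed - 1.0) * 100, (1.0 - mem / base_mem) * 100)
        for name, (speed, mem) in rows.items()
    }


for name, (speedup, saving) in improvements(ROWS, "checkpointing").items():
    print(f"{name}: +{speedup:.0f}% speed, -{saving:.1f}% memory")
```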

Evaluation (FID, Inception Score, etc.)

We include a sample_ddp.py script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over N GPUs, run:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000

There are several additional options; see sample_ddp.py for details.
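ADM's evaluation suite reads the samples from a single .npz file. A minimal sketch of that layout, assuming (as in the original DiT sample_ddp.py) that all images are stacked into one uint8 array of shape (N, H, W, 3) stored under the key arr_0; `create_npz` is an illustrative helper, and random pixels stand in for real samples:

```python
import os
import tempfile

import numpy as np


def create_npz(samples: np.ndarray, npz_path: str) -> str:
    """Save stacked uint8 images of shape (N, H, W, 3) under the key "arr_0"."""
    if samples.dtype != np.uint8 or samples.ndim != 4 or samples.shape[-1] != 3:
        raise ValueError("expected a uint8 array of shape (N, H, W, 3)")
    if not npz_path.endswith(".npz"):
        npz_path += ".npz"  # np.savez would append the suffix anyway
    np.savez(npz_path, arr_0=samples)
    return npz_path


fake = np.random.default_rng(0).integers(0, 256, size=(4, 8, 8, 3), dtype=np.uint8)
path = create_npz(fake, os.path.join(tempfile.mkdtemp(), "samples.npz"))
loaded = np.load(path)["arr_0"]
```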


fast-dit's Issues

How to train model using mixed precision fp16?

Hello author, I tried training the original DiT model but ran out of memory. I found your repository, which implements DiT with lower memory use. In the README, I saw you use a mixed_precision argument, but I couldn't find it anywhere in the code. I just want to copy the model architecture file and adapt it to my own work. Could you tell me which model architecture uses less memory? It is a bit confusing to me, so I just want to clarify.

CustomDataset file order

Thanks for the great work! In L102 of train.py, the feature files and label files are obtained using os.listdir. But os.listdir does not guarantee that the returned list is sorted. Wouldn't this cause a mismatch between the feature files and label files? Thank you.

Jiahao
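One way to make the pairing deterministic is to sort both listings and check that they contain the same filenames before pairing them. A minimal sketch, assuming each feature file and its label file share a filename (e.g. 0.npy in both directories); `paired_feature_files` is a hypothetical helper, and the demo creates empty placeholder files:

```python
import os
import tempfile


def paired_feature_files(features_dir: str, labels_dir: str) -> list:
    """Return (feature_path, label_path) pairs matched by filename.

    os.listdir returns entries in arbitrary order, so both listings are sorted
    and checked to contain identical filenames before pairing.
    """
    feature_files = sorted(os.listdir(features_dir))
    label_files = sorted(os.listdir(labels_dir))
    if feature_files != label_files:
        raise ValueError("feature and label directories contain different filenames")
    return [
        (os.path.join(features_dir, name), os.path.join(labels_dir, name))
        for name in feature_files
    ]


root = tempfile.mkdtemp()
feat_dir, label_dir = os.path.join(root, "features"), os.path.join(root, "labels")
for d in (feat_dir, label_dir):
    os.makedirs(d)
    for i in (10, 2, 1):
        open(os.path.join(d, f"{i}.npy"), "w").close()
pairs = paired_feature_files(feat_dir, label_dir)
```

Sorting is lexicographic ("10.npy" comes before "2.npy"), which is harmless here because both lists use the same order and the explicit equality check catches any mismatch.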

inefficient data loader

I just wanted to point out that the data loader in this implementation seems to be a lot less efficient than it could have been. Right now, the code writes each encoded image into a separate .npy file and during training loads each file in a batch separately. That's a lot of inefficient file I/O. You could have just saved all pre-extracted features in a single array/tensor and loaded a single file into RAM (or even into GPU RAM) once before starting training. The entire ImageNet takes up only 5 GB of memory if you store it in uint8 in this way, e.g.: https://huggingface.co/datasets/cloneofsimo/imagenet.int8.
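A minimal NumPy sketch of the suggestion: stack the per-sample latents into one array, save it once, and memory-map it at training time. The optional uint8 step is a simple linear quantization over an assumed clip range, chosen here for illustration; it is not necessarily the scheme used by the linked dataset, and all helper names are hypothetical:

```python
import os
import tempfile

import numpy as np


def consolidate(feature_arrays, labels):
    """Stack per-sample latent features into one float32 array plus one label array."""
    features = np.stack([np.asarray(f, dtype=np.float32) for f in feature_arrays])
    labels = np.asarray(labels, dtype=np.int64)
    if len(features) != len(labels):
        raise ValueError("need exactly one label per feature array")
    return features, labels


def quantize_uint8(x, clip):
    """Lossy map of values in [-clip, clip] to 0..255; values outside are clipped."""
    x = np.clip(x, -clip, clip)
    return np.round((x + clip) / (2 * clip) * 255).astype(np.uint8)


def dequantize_uint8(q, clip):
    return q.astype(np.float32) / 255 * (2 * clip) - clip


rng = np.random.default_rng(0)
per_sample = [rng.standard_normal((4, 32, 32)) for _ in range(8)]  # stand-ins for .npy files
features, labels = consolidate(per_sample, list(range(8)))

out_dir = tempfile.mkdtemp()
np.save(os.path.join(out_dir, "features.npy"), features)
np.save(os.path.join(out_dir, "labels.npy"), labels)
# Memory-mapped load: one file open, pages read lazily instead of one file per sample.
features_mm = np.load(os.path.join(out_dir, "features.npy"), mmap_mode="r")
```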

how to condition on an image?

I was wondering whether it is possible to condition the input on an image. Since the scale has size (B, dim), how can we scale the input of size (B, L, dim) using an image? Should I patchify the image to size (B, p*p, dim)?
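In DiT's adaLN blocks, the conditioning vector (timestep embedding plus class embedding) has shape (B, dim), and a small MLP maps it to per-sample shift and scale vectors of the same shape. These reach every token by inserting a length-1 token axis, so they broadcast over L. A NumPy sketch of that broadcasting, mirroring the modulate helper in the original DiT models.py (which uses unsqueeze(1) on PyTorch tensors):

```python
import numpy as np


def modulate(x, shift, scale):
    """adaLN-style modulation of tokens by a per-sample conditioning vector.

    x: (B, L, D) token features; shift, scale: (B, D).
    [:, None, :] turns (B, D) into (B, 1, D), which broadcasts over all L tokens.
    """
    return x * (1 + scale[:, None, :]) + shift[:, None, :]


x = np.ones((2, 3, 4))  # B=2 samples, L=3 tokens, D=4 channels
shift = np.zeros((2, 4))
scale = np.stack([np.ones(4), np.zeros(4)])  # sample 0 doubled, sample 1 unchanged
out = modulate(x, shift, scale)
```

This injects one vector per sample into every token; spatially detailed image conditioning would need a different mechanism, such as extra image tokens.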

Finetune possible?

There is one overall question: is it possible to train the model on custom data (not ImageNet) with only one class, which is not in ImageNet?
This leads to the following questions:

  1. Why do we need extract_features.py? This file doesn't exist in the original code from FB. Is it right that extract_features.py is the part of the old train.py that creates features from images with the VAE encoder, so it doesn't need many of the trailing arguments, such as epochs and model?
  2. Is it possible to train the model on custom data starting from the ImageNet weights?
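For single-class data, the class-conditional setup still applies with the number of classes set to 1. In the original DiT, class labels are randomly replaced by an extra "null" label during training (class dropout probability 0.1 by default), which is what makes classifier-free guidance possible at sampling time. A stdlib sketch of that label dropout; `drop_labels` is a hypothetical helper:

```python
import random


def drop_labels(labels, num_classes, dropout_prob=0.1, rng=random):
    """Replace each label with the null index `num_classes` with probability dropout_prob.

    The label embedding table then needs num_classes + 1 rows; the extra row is the
    unconditional embedding used for classifier-free guidance.
    """
    return [num_classes if rng.random() < dropout_prob else y for y in labels]


# Single-class dataset: label 0 is the real class, label 1 is "unconditional".
print(drop_labels([0] * 10, num_classes=1, rng=random.Random(0)))
```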

Training Cost

Thanks for your efforts. I just want to know how many hours it would take to train an XL model on ImageNet 256x256 with one A100 GPU.
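A rough estimate follows from the throughput figures in the README: at 0.84 steps/sec (DiT-XL/2, global batch size 128, one A100, all optimizations enabled), 400K steps, the length of the runs in the PyTorch results table, take about 132 hours. This assumes constant throughput and ignores evaluation and checkpoint saving; those 400K-step results also used 8x A100s, so treat this only as an order-of-magnitude figure. `training_hours` is a hypothetical helper:

```python
def training_hours(train_steps: int, steps_per_sec: float) -> float:
    """Wall-clock hours for train_steps at a constant throughput."""
    return train_steps / steps_per_sec / 3600


print(f"{training_hours(400_000, 0.84):.0f} hours")  # 132 hours
```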

Training on 3D data

Hi,

Any advice on how to efficiently adapt the code to work with 3D images?

Thank you

Gradient checkpointing

Hi, can you point out where gradient checkpointing happens in the code? I cannot figure it out.

Thanks!
