

Scalable Diffusion Models with Transformers (DiT)
Official PyTorch Implementation

Paper | Project Page | Run DiT-XL/2: Hugging Face Spaces | Colab


This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring diffusion models with transformers (DiTs). You can find more visualizations on our project page.

Scalable Diffusion Models with Transformers
William Peebles, Saining Xie
UC Berkeley, New York University

We train latent diffusion models, replacing the commonly used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops (through increased transformer depth/width or increased number of input tokens) consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
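To make "number of input tokens" concrete, the sketch below counts the latent patches a DiT sees for a given image size and patch size (the "/2" in DiT-XL/2 is the patch size). It assumes, following the paper rather than this README, that images are first encoded by a VAE that downsamples each side by 8 (so a 256x256 image becomes a 32x32 latent) and that the latent is then cut into non-overlapping p×p patches. `num_tokens` is a hypothetical helper, not a function in this repo.

```python
def num_tokens(image_size: int, patch_size: int, vae_downsample: int = 8) -> int:
    """Number of latent patches (transformer tokens) for a square image.

    vae_downsample=8 is an assumption taken from the DiT paper, not this README.
    """
    if image_size % vae_downsample != 0:
        raise ValueError("image_size must be divisible by the VAE downsampling factor")
    latent_size = image_size // vae_downsample
    if latent_size % patch_size != 0:
        raise ValueError("latent size must be divisible by the patch size")
    patches_per_side = latent_size // patch_size
    return patches_per_side * patches_per_side


# Smaller patches mean more tokens, and therefore more Gflops per forward pass.
for p in (2, 4, 8):
    print(f"256x256, patch {p}: {num_tokens(256, p)} tokens")
print(f"512x512, patch 2: {num_tokens(512, 2)} tokens")
```

Halving the patch size quadruples the token count, which is one of the two Gflops levers the abstract describes (the other being depth/width).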

This repository contains:

  • ๐Ÿช A simple PyTorch implementation of DiT
  • โšก๏ธ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)
  • ๐Ÿ’ฅ A self-contained Hugging Face Space and Colab notebook for running pre-trained DiT-XL/2 models
  • ๐Ÿ›ธ A DiT training script using PyTorch DDP

An implementation of DiT directly in Hugging Face diffusers can also be found here.

Setup

First, download and set up the repo:

git clone https://github.com/facebookresearch/DiT.git
cd DiT

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate DiT
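For a CPU-only setup, deleting the two requirements by hand is enough. If you prefer to script it, the sketch below drops any dependency line whose package name is cudatoolkit or pytorch-cuda. It is a line-based filter, not a YAML parser, and assumes one conda dependency per line in the usual "- name=version" form; the sample input is a made-up stand-in, not the actual contents of environment.yml.

```python
import re

# Packages the README says can be removed for CPU-only use.
CUDA_ONLY_PACKAGES = {"cudatoolkit", "pytorch-cuda"}


def strip_cuda_deps(env_yaml: str) -> str:
    """Remove list entries such as '  - cudatoolkit=11.7' from environment.yml text.

    Assumption: one dependency per line; the package name ends at the first
    character that is not a letter, digit, '_', '.', or '-'.
    """
    kept = []
    for line in env_yaml.splitlines():
        match = re.match(r"\s*-\s*([A-Za-z0-9_.\-]+)", line)
        if match and match.group(1) in CUDA_ONLY_PACKAGES:
            continue  # skip the GPU-only requirement
        kept.append(line)
    return "\n".join(kept) + "\n"


example = """name: DiT
dependencies:
  - python >= 3.8
  - pytorch >= 1.13
  - cudatoolkit=11.7
  - pytorch-cuda=11.7
"""
print(strip_cuda_deps(example))
```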

Sampling


Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with sample.py. Weights for the pre-trained model you select are downloaded automatically. The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

python sample.py --image-size 512 --seed 1
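The classifier-free guidance scale controls how strongly samples are pushed toward the requested class. The sketch below shows the standard classifier-free guidance combination of a conditional and an unconditional model prediction, written over plain Python lists so it runs without PyTorch. `apply_cfg` and the toy numbers are hypothetical, and sample.py's actual implementation may differ in details (for example, which output channels guidance is applied to).

```python
from typing import List


def apply_cfg(cond: List[float], uncond: List[float], scale: float) -> List[float]:
    """Classifier-free guidance, elementwise: uncond + scale * (cond - uncond).

    scale == 1.0 returns the conditional prediction unchanged;
    scale > 1.0 extrapolates further away from the unconditional prediction.
    """
    if len(cond) != len(uncond):
        raise ValueError("predictions must have the same length")
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond_eps = [1.0, -0.5, 0.25]    # toy class-conditional noise prediction
uncond_eps = [0.5, 0.0, 0.25]   # toy unconditional (null-class) prediction
print(apply_cfg(cond_eps, uncond_eps, 4.0))  # -> [2.5, -2.0, 0.25]
```

Where the two predictions agree (the last element), guidance has no effect; where they disagree, a larger scale amplifies the difference.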

For convenience, our pre-trained DiT models can be downloaded directly here as well:

DiT Model | Image Resolution | FID-50K | Inception Score | Gflops
XL/2      | 256x256          | 2.27    | 278.24          | 119
XL/2      | 512x512          | 3.04    | 240.82          | 525

Custom DiT checkpoints. If you've trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
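EMA weights are an exponential moving average of the weights seen during training, which usually give better samples than the raw weights. The sketch below shows the EMA update on a dictionary of scalar "parameters" and one way a sampler could prefer EMA weights from a saved checkpoint. The decay of 0.9999, the checkpoint keys "model" and "ema", and both helper names are assumptions for illustration; check train.py and sample.py for the actual checkpoint format.

```python
from typing import Dict

Params = Dict[str, float]


def update_ema(ema: Params, model: Params, decay: float = 0.9999) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * model."""
    for name, value in model.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


def select_weights(checkpoint: dict) -> Params:
    """Prefer EMA weights when present (the key names here are hypothetical)."""
    return checkpoint["ema"] if "ema" in checkpoint else checkpoint["model"]


model = {"w": 1.0}
ema = dict(model)        # EMA starts as a copy of the initial weights
model["w"] = 2.0         # pretend an optimizer step changed the weight
update_ema(ema, model, decay=0.5)
print(ema["w"])          # -> 1.5
ckpt = {"model": model, "ema": ema}
print(select_weights(ckpt) is ema)  # -> True
```

With a decay close to 1, the EMA changes slowly and smooths out step-to-step noise in the training weights.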

Training DiT

We provide a training script for DiT in train.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with N GPUs on one node:

torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
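With torchrun, each of the N processes trains on its own slice of every global batch. The sketch below shows one common way to derive per-process settings from the rank and world size (torchrun exports these as the RANK and WORLD_SIZE environment variables). The global batch size of 256, the divisibility requirement, and the per-rank seed formula are assumptions for illustration, not a description of train.py's exact arguments.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RankConfig:
    rank: int
    local_batch_size: int
    seed: int


def rank_config(rank: int, world_size: int, global_batch_size: int = 256,
                global_seed: int = 0) -> RankConfig:
    """Split a global batch evenly across DDP processes and give each a distinct seed."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by the number of GPUs")
    return RankConfig(
        rank=rank,
        local_batch_size=global_batch_size // world_size,
        # Distinct seeds so random draws (noise, timesteps, flips) differ across ranks.
        seed=global_seed * world_size + rank,
    )


# Enumerate the settings an 8-GPU node (N=8) would use.
for r in range(8):
    print(rank_config(r, world_size=8))
```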

Note
This script is a PyTorch reimplementation of DiT training and has only been partially tested. We have trained DiT-XL/2 (256x256) from scratch for 90K iterations; its loss curve closely matches that of the JAX implementation, and FID at 50K iterations is very similar. If you encounter any bugs, please open an issue!

Training could likely be sped up significantly by:

  • using Flash Attention in the DiT model
  • using torch.compile in PyTorch 2.0

Basic features that would be nice to add:

  • Monitor FID and other metrics
  • Generate and save samples from the EMA model periodically
  • Resume training from a checkpoint
  • AMP/bfloat16 support

Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling with different floating-point precisions. We re-evaluated our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID versus 2.27 in the paper).

BibTeX

@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}

Acknowledgments

We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. William Peebles is supported by the NSF Graduate Research Fellowship.

This codebase borrows from OpenAI's diffusion repos, most notably ADM.

License

The code and model weights are licensed under CC-BY-NC. See LICENSE.txt for details.
