Relay Diffusion: Unifying diffusion process across resolutions for image synthesis
Official PyTorch Implementation

We propose Relay Diffusion Model (RDM) as a better framework for diffusion generation. RDM transfers a low-resolution image or noise into an equivalent high-resolution one via blurring diffusion and block noise. Therefore, the diffusion process can continue seamlessly in any new resolution or model without restarting from pure noise or low-resolution conditioning.

RDM achieves state-of-the-art FID on CelebA-HQ and state-of-the-art sFID on ImageNet-256 (with an FID of 1.87 on ImageNet-256)!

For a formal introduction, read our paper: Relay Diffusion: Unifying diffusion process across resolutions for image synthesis.

Setup

Environment

Download the repo and setup the environment with:

git clone https://github.com/THUDM/RelayDiffusion.git
cd RelayDiffusion
conda env create -f environment.yml
conda activate rdm

We enable xformers.ops.memory_efficient_attention to reduce training cost by about 15%. If you do not need it, you can remove xformers from environment.yml.

Linux servers with NVIDIA A100 GPUs are recommended. However, by setting a smaller --batch-gpu (batch size on a single GPU), you can still run the inference and training scripts on less powerful GPUs, as in the example below.
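For example, the ImageNet first-stage training command from the Training section below can be run with a lower per-GPU batch while keeping the overall batch size unchanged (the value 8 is only illustrative; in the EDM codebase this repo builds on, the gap between --batch and --batch-gpu is covered by gradient accumulation):

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-64x64.zip --eff-attn=True \
    --cond=1 --batch=4096 --batch-gpu=8 --lr=1e-4 --ema=50 --dropout=0.1 --fp16=1 --ls=25 \
    --arch=adm --precond=edm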

Dataset

We preprocess and organize datasets in the same format as EDM. For CelebA-HQ, follow Progressive Growing of GANs for Improved Quality, Stability, and Variation to construct the high-quality subset of CelebA. For ImageNet, download the data from the official site.

To convert the original data into the organized format ready for training at $64\times 64$ or $256\times 256$ resolution, run:

python dataset_tool.py \
    --source=/path/to/original/data \
    --dest=/path/to/output/data.zip \
    --transform=center-crop \
    --resolution=64x64  # or --resolution=256x256
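For example, assuming the raw ImageNet images have been placed in datasets/imagenet-raw/ (an illustrative path), the two dataset archives used later in this README can be built with:

python dataset_tool.py --source=datasets/imagenet-raw/ --dest=datasets/imagenet-64x64.zip \
    --transform=center-crop --resolution=64x64
python dataset_tool.py --source=datasets/imagenet-raw/ --dest=datasets/imagenet-256x256.zip \
    --transform=center-crop --resolution=256x256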

Inference & Evaluation

Sample Generation

To generate samples from RDM models, run:

torchrun --standalone --nproc_per_node=1 generate.py --sampler_stages=both --outdir=/path/to/output/dir/ \
    --network_first=/path/to/1st/ckpt --network_second=/path/to/2nd/ckpt

To generate $N$ images, set --seed=[K]-[K+N-1] with a randomly picked $K$. You can set --nproc_per_node to the number of available GPUs to parallelize generation.
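For instance, to generate 50,000 samples on 8 GPUs starting from seed 0 (the output directory and checkpoint paths are placeholders):

torchrun --standalone --nproc_per_node=8 generate.py --sampler_stages=both --seed=0-49999 --outdir=generations/samples/ \
    --network_first=/path/to/1st/ckpt --network_second=/path/to/2nd/ckpt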

If you want to generate final samples from existing first-stage results (i.e., run only the second-stage model), set --sampler_stages=second and point --indir at the directory of first-stage results.
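A sketch of the command shape in this mode, assuming only the second-stage checkpoint is required and with placeholder directories:

torchrun --standalone --nproc_per_node=1 generate.py --sampler_stages=second --outdir=generations/final_samples/ \
    --indir=generations/first_stage_results/ --network_second=/path/to/2nd/ckpt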

In addition, the arguments for configuring the first stage are:

  • num_steps_first: number of sampling steps.
  • sigma_min_first & sigma_max_first: lowest & highest noise level.
  • rho_first: time step exponent.
  • cfg_scale_first: scale of classifier-free guidance.
  • S_churn: stochasticity strength.
  • S_min & S_max: min & max noise level.
  • S_noise: noise inflation.

The arguments for configuring the second stage are listed below, followed by a combined example command:

  • num_steps_second: number of sampling steps.
  • sigma_min_second & sigma_max_second: lowest & highest noise level.
  • blur_sigma_max_second: maximum sigma of blurring schedule.
  • rho_second: time step exponent.
  • cfg_scale_second: scale of classifier-free guidance.
  • up_scale_second: scale of upsampling.
  • truncation_sigma_second & truncation_t_second: truncation point of noise & time schedule.
  • s_block_second: strength of block noise addition.
  • s_noise_second: strength of stochasticity.
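As a shape-only illustration of how these options combine (the numeric values below are placeholders rather than the tuned settings shipped with the checkpoints, and each option is assumed to be passed as a --flag with the same name):

torchrun --standalone --nproc_per_node=1 generate.py --sampler_stages=both --outdir=/path/to/output/dir/ \
    --network_first=/path/to/1st/ckpt --network_second=/path/to/2nd/ckpt \
    --num_steps_first=40 --cfg_scale_first=1.5 --S_churn=10 \
    --num_steps_second=20 --up_scale_second=4 --cfg_scale_second=1.5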

Evaluation Metrics

We quantitatively measure sample quality with metrics including Fréchet Inception Distance (FID), spatial FID (sFID), Inception Score (IS), Precision and Recall. For sFID, IS, Precision and Recall, we re-implement the calculation pipeline based on the TensorFlow formulation from ADM.

First, run the following commands to build activation files from the generated samples and the reference dataset:

torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/sample/dir/ --dest=eval-refs/activations_sample.npz --batch=64 # build sample activations
torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/path/to/dataset.zip --dest=eval-refs/activations_ref.npz --batch=64 # build reference activations

Then, to calculate metrics from the pre-built activations, run:

torchrun --standalone --nproc_per_node=1 evaluate.py calc --batch=64 \
    --activations_sample=eval-refs/activations_sample.npz \
    --activations_ref=eval-refs/activations_ref.npz \
    [-m fid] [-m sfid] [-m is] [-m pr]  # assign metrics to be calculated

Performance Reproduction

RDM achieves competitive results in comparison with previous SoTA models:

Dataset    Resolution  Training Samples  FID   sFID  IS      Precision  Recall
CelebA-HQ  256x256     47M               3.15  -     -       0.77       0.55
ImageNet   256x256     1250M             1.87  3.97  278.75  0.81       0.59

We provide the best pre-trained checkpoints of RDM together with their sampler settings for reproducing this performance:

  • CelebA-HQ $256\times 256$:

    Download the checkpoints of the first stage and the second stage, place them in ckpts/, then generate samples and their activations with the following commands:

    torchrun --standalone --nproc_per_node=8 generate_celebahq.py --outdir=generations/celebahq_samples/ \
        --network_first=ckpts/celebahq_first_stage.pt \
        --network_second=ckpts/celebahq_second_stage.pt
    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=generations/celebahq_samples/ --dest=eval-refs/celebahq_act_sample.npz 

    Generate activation data from the CelebA-HQ zip, or download our pre-built version from here:

    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=datasets/celebahq-256x256.zip --dest=eval-refs/celebahq_act_ref.npz 

    Calculate the metrics with:

    python evaluate.py calc -m fid -m pr \
        --activations_sample=eval-refs/celebahq_act_sample.npz \
        --activations_ref=eval-refs/celebahq_act_ref.npz
  • ImageNet $256\times 256$:

    Download the checkpoints of the first stage and the second stage, place them in ckpts/, then generate samples and their activations with the following commands:

    torchrun --standalone --nproc_per_node=8 generate_imagenet.py --outdir=generations/imagenet_samples/ \
        --network_first=ckpts/imagenet_first_stage.pkl \
        --network_second=ckpts/imagenet_second_stage.pt
    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=generations/imagenet_samples/ --dest=eval-refs/imagenet_act_sample.npz 

    Generate activation data from the ImageNet zip:

    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=datasets/imagenet-256x256.zip --dest=eval-refs/imagenet_act_ref.npz 

    Calculate FID, sFID and IS with:

    python evaluate.py calc -m fid -m sfid -m is \
        --activations_sample=eval-refs/imagenet_act_sample.npz \
        --activations_ref=eval-refs/imagenet_act_ref.npz

    For the calculation of Precision and Recall on ImageNet, we follow ADM and use 10,000 reference samples. You can download the activation data we produced from here. Then run the following command:

    python evaluate.py calc -m pr \
        --activations_sample=eval-refs/imagenet_act_sample.npz \
        --activations_ref=eval-refs/imagenet_act_1w_ref.npz

Training

You can follow the instructions of EDM to train a new first-stage model (standard diffusion). Taking ImageNet as an example, run:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-64x64.zip --eff-attn=True \
	--cond=1 --batch=4096  --batch-gpu=32 --lr=1e-4 --ema=50 --dropout=0.1 --fp16=1 --ls=25 \
	--arch=adm --precond=edm

If you want to train a second-stage model (blurring diffusion), set --precond=blur along with the other arguments that configure blurring diffusion. The command becomes:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-256x256.zip --eff-attn=True \
	--cond=1 --batch=4096  --batch-gpu=8 --lr=1e-4 --dropout=0.1 --fp16=1 --ls=1 \
	--arch=adm --precond=blur --up-scale=4 --block-scale=0.15 --prob-length=0.93 --blur-sigma-max=3.0

For CelebA-HQ, train a first-stage model with:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/CelebA-HQ-64x64.zip --eff-attn=True \
	--cond=0 --batch=1024  --batch-gpu=32 --lr=1e-4 --dropout=0.15 --augment=0.2 --ls=1 \
	--arch=adm --precond=edm

And train a second-stage model with:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/CelebA-HQ-256x256.zip --eff-attn=True \
	--cond=0 --batch=1024  --batch-gpu=8 --lr=1e-4 --dropout=0.2 --augment=0.2 --fp16=1 --ls=1 \
	--arch=adm --precond=blur --up-scale=4 --block-scale=0.15 --prob-length=0.89 --blur-sigma-max=2.0

Citation

@article{teng2023relay,
  title={Relay Diffusion: Unifying diffusion process across resolutions for image synthesis},
  author={Teng, Jiayan and Zheng, Wendi and Ding, Ming and Hong, Wenyi and Wangni, Jianqiao and Yang, Zhuoyi and Tang, Jie},
  journal={arXiv preprint arXiv:2309.03350},
  year={2023}
}

Acknowledgements

This implementation is based on https://github.com/NVlabs/edm (the EDM codebase). Thanks a lot!
