yang-song / score_sde

Official code for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)

Home Page: https://arxiv.org/abs/2011.13456

License: Apache License 2.0

Languages: Jupyter Notebook 94.35%, Python 5.65%
Topics: score-matching, stochastic-differential-equations, generative-models, score-based-generative-modeling, controllable-generation, inverse-problems, jax, flax, iclr-2021, diffusion-models

score_sde's Introduction

Score-Based Generative Modeling through Stochastic Differential Equations


This repo contains the official implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

schematic
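
To make this concrete, here is a minimal NumPy sketch (not taken from this codebase) of Euler-Maruyama sampling from the reverse-time VE SDE. score_fn, the geometric noise schedule, and all hyperparameters below are illustrative placeholders rather than the repo's actual API.

import numpy as np

def reverse_ve_sde_sample(score_fn, shape, sigma_min=0.01, sigma_max=50.0, n_steps=1000, seed=0):
    # Sample from the reverse-time VE SDE with a plain Euler-Maruyama discretization.
    # score_fn(x, sigma) should approximate the score of the sigma-perturbed data distribution.
    rng = np.random.default_rng(seed)
    x = sigma_max * rng.standard_normal(shape)   # start from the VE prior, roughly N(0, sigma_max^2 I)
    dt = -1.0 / n_steps                          # negative step: integrating from t=1 back towards t=0
    for t in np.linspace(1.0, 1e-3, n_steps):
        sigma = sigma_min * (sigma_max / sigma_min) ** t        # geometric noise schedule sigma(t)
        g2 = 2.0 * sigma ** 2 * np.log(sigma_max / sigma_min)   # g(t)^2 = d[sigma^2(t)]/dt
        drift = -g2 * score_fn(x, sigma)                        # reverse-time drift (forward drift is 0 for VE)
        x = x + drift * dt + np.sqrt(g2 * -dt) * rng.standard_normal(shape)
    return x

The predictor-corrector samplers in this codebase refine this kind of update with additional Langevin correction steps.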

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, latent code manipulation, and brings new conditional generation abilities (including but not limited to class-conditional generation, inpainting and colorization) to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images (samples below). In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.

FFHQ samples

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, and correctors.

PyTorch version

Please find a PyTorch implementation here, which supports everything except class-conditional generation with a pre-trained classifier.

JAX vs. PyTorch

In general, the PyTorch version consumes less memory but runs slower than the JAX version. Below is a benchmark for training an NCSN++ cont. model with the VE SDE on 4x Nvidia Tesla V100 GPUs (32 GB each).

| Framework | Time (seconds per step) | Total memory usage (GB) |
|---|---|---|
| PyTorch | 0.56 | 20.6 |
| JAX (n_jitted_steps=1) | 0.30 | 29.7 |
| JAX (n_jitted_steps=5) | 0.20 | 74.8 |
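
The trade-off above comes from fusing several training steps into one XLA computation. Below is a minimal sketch of the idea, not the repo's actual training loop; train_step, state, and the batch layout are placeholder assumptions.

import jax
import jax.numpy as jnp

def make_multi_step(train_step, n_jitted_steps):
    # train_step: (state, batch) -> (new_state, loss). The batches passed to multi_step must have a
    # leading axis of size n_jitted_steps so lax.scan can consume one batch per inner step.
    def multi_step(state, batches):
        def body(carry_state, batch):
            new_state, loss = train_step(carry_state, batch)
            return new_state, loss
        state, losses = jax.lax.scan(body, state, batches)
        return state, jnp.mean(losses)
    return jax.jit(multi_step)

Running n_jitted_steps training steps per call amortizes Python and dispatch overhead, but the fused computation keeps more intermediates on device at once, which is why memory usage grows in the table above.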

How to run the code

Dependencies

Run the following to install a subset of the necessary Python packages for our code:

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide the stats file for CIFAR-10. You can download cifar10_stats.npz and save it to assets/stats/. Check out issue #5 for how to compute stats files for new datasets.
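
For context, the stats file caches Inception statistics of the real dataset so that FID can be computed against generated samples. Below is a minimal sketch of the FID computation itself, assuming you already have two arrays of pooled Inception features; the exact arrays/keys stored in cifar10_stats.npz may differ from this illustration.

import numpy as np
from scipy import linalg

def fid_from_pool_features(real_pools, fake_pools):
    # Frechet distance between Gaussians fitted to Inception pool features of shape (N, D).
    mu_r, mu_f = real_pools.mean(axis=0), fake_pools.mean(axis=0)
    cov_r = np.cov(real_pools, rowvar=False)
    cov_f = np.cov(fake_pools, rowvar=False)
    covmean = linalg.sqrtm(cov_r @ cov_f)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r + cov_f - 2.0 * covmean))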

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the path of a config file is a combination of the following dimensions:

    • dataset: One of cifar10, celeba, celebahq, celebahq_256, ffhq_256, ffhq.
    • model: One of ncsn, ncsnv2, ncsnpp, ddpm, ddpmpp.
    • continuous: train the model with continuously sampled time steps.
  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts training a new model, or resumes training an existing model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in workdir/checkpoints-meta. When set to "eval", it can perform an arbitrary combination of the following:

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute their Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in assets/stats.

    • Compute the log-likelihood on the training or test dataset.

    These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
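
For example (the config path below matches the VE NCSN++ config referenced in the issues further down; the workdir and eval_folder values are arbitrary):

# train a model, resuming automatically from workdir/checkpoints-meta if it exists
python main.py --config configs/ncsnpp/cifar10_continuous_ve.py \
  --workdir exp/cifar10_ncsnpp_ve --mode train

# evaluate: draw samples, compute FID/IS, and compute bits/dim on the test set
python main.py --config configs/ncsnpp/cifar10_continuous_ve.py \
  --workdir exp/cifar10_ncsnpp_ve --eval_folder eval --mode eval \
  --config.eval.enable_sampling --config.eval.enable_bpd --config.eval.dataset=test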

How to extend the code

  • New SDEs: inherit the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional, and the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
  • New predictors: inherit the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be directly used in sampling.get_pc_sampler for Predictor-Corrector sampling and in all other controllable generation methods in controllable_generation.py (see the sketch after this list).
  • New correctors: inherit the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be directly used in sampling.get_pc_sampler and in all other controllable generation methods in controllable_generation.py.
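
As an illustration, a new predictor might look roughly like the sketch below. The import path, the exact update_fn signature, and the self.rsde helper are assumptions to be checked against sampling.Predictor in this repo, and the registered name is made up.

import jax
import jax.numpy as jnp
from sampling import Predictor, register_predictor  # assumed import path

@register_predictor(name='my_euler_maruyama')  # hypothetical name
class MyEulerMaruyamaPredictor(Predictor):
    # One Euler-Maruyama step of the reverse-time SDE (assumes image-shaped x: [B, H, W, C]).
    def update_fn(self, rng, x, t):
        dt = -1.0 / self.rsde.N                 # assumed: reverse-SDE wrapper exposing N discretization steps
        z = jax.random.normal(rng, x.shape)
        drift, diffusion = self.rsde.sde(x, t)  # assumed: reverse-time drift and per-sample diffusion
        x_mean = x + drift * dt
        x = x_mean + diffusion[:, None, None, None] * jnp.sqrt(-dt) * z
        return x, x_mean

After registration, setting config.sampling.predictor to the registered name should select it, assuming the config plumbing matches the built-in predictors.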

Pretrained checkpoints

All checkpoints are provided in this Google drive.

Instructions: You may find two checkpoints for some models. The first checkpoint (with a smaller number) is the one for which we reported FID scores in Table 3 of our paper (corresponding to the FID and IS columns in the table below). The second checkpoint (with a larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in Table 2 of our paper (the FID (ODE) and NLL (bits/dim) columns in the table below). The former corresponds to the smallest FID during training (evaluated every 50k iterations); the latter is the last checkpoint of training.

Per Google's policy, we cannot release our original CelebA and CelebA-HQ checkpoints. That said, I have re-trained models on FFHQ 1024px, FFHQ 256px and CelebA-HQ 256px with personal resources, and they achieved similar performance to our internal checkpoints.

Here is a detailed list of checkpoints and their results reported in the paper. FID (ODE) corresponds to the sample quality of a black-box ODE solver applied to the probability flow ODE.

| Checkpoint path | FID | IS | FID (ODE) | NLL (bits/dim) |
|---|---|---|---|---|
| ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
| ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
| ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
| vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
| vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
| vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
| vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
| vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
| subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
| subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
| subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |

| Checkpoint path | Samples |
|---|---|
| ve/bedroom_ncsnpp_continuous | bedroom_samples |
| ve/church_ncsnpp_continuous | church_samples |
| ve/ffhq_1024_ncsnpp_continuous | ffhq_1024 |
| ve/ffhq_256_ncsnpp_continuous | ffhq_256_samples |
| ve/celebahq_256_ncsnpp_continuous | celebahq_256_samples |

Demonstrations and tutorials

| Link | Description |
|---|---|
| Open In Colab | Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (JAX + FLAX) |
| Open In Colab | Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (PyTorch) |
| Open In Colab | Tutorial of score-based generative models in JAX + FLAX |
| Open In Colab | Tutorial of score-based generative models in PyTorch |

Tips

  • When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via config.training.n_jitted_steps. For CIFAR-10, we recommend using config.training.n_jitted_steps=5 when your GPU/TPU has sufficient memory; otherwise we recommend using config.training.n_jitted_steps=1. Our current implementation requires config.training.log_freq to be divisible by n_jitted_steps for logging and checkpointing to work normally.
  • The snr (signal-to-noise ratio) parameter of LangevinCorrector behaves somewhat like a temperature parameter. Larger snr typically results in smoother samples, while smaller snr gives more diverse but lower-quality samples. Typical values of snr are 0.05–0.2, and it requires tuning to hit the sweet spot.
  • For VE SDEs, we recommend choosing config.model.sigma_max to be the maximum pairwise distance between data samples in the training dataset (a sketch for computing this follows these tips).
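
Below is a minimal sketch for estimating that maximum pairwise distance, chunked so the full N x N distance matrix never materializes; the flattened-array data layout is an assumption.

import numpy as np

def max_pairwise_distance(data, chunk=1024):
    # data: array of shape (N, ...); returns the largest Euclidean distance between any two samples.
    flat = data.reshape(len(data), -1).astype(np.float64)
    sq_norms = (flat ** 2).sum(axis=1)
    max_sq = 0.0
    for start in range(0, len(flat), chunk):
        block = flat[start:start + chunk]
        # squared distances between this chunk and all samples: |a|^2 + |b|^2 - 2 a.b
        d2 = sq_norms[start:start + chunk, None] + sq_norms[None, :] - 2.0 * block @ flat.T
        max_sq = max(max_sq, float(d2.max()))
    return float(np.sqrt(max(max_sq, 0.0)))

The resulting value would then be used for config.model.sigma_max.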

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved Techniques for Training Score-Based Generative Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.

score_sde's People

Contributors: yang-song

score_sde's Issues

the jax-based code on multi-host tpu

Hi Yang,

That's great work. I would like to ask whether this code can run on a multi-host TPU (such as v3-32), and could you give me some advice on how to adapt the code for it?

Thank you very much!

Yong

Questions about Alg. 5 in the paper

I wondered why the step size $\epsilon$ of the corrector for the VP SDE needs to be multiplied by an extra $\alpha_i$ compared with $\epsilon$ in the VE SDE setting.

[Huawei] 2012 Lab-Technical Exchange & Project Cooperation & Talented Youth Program Invitation

Hi, Yang Song. I had the honor of reading your published paper and obtaining your contact information. I am Han Lu from the Huawei 2012 Lab. We are currently researching and exploring topics in audio, autonomous driving perception, CG (such as character animation, rendering, 3D reconstruction, motion synthesis, and character interaction), CV (multi-modal learning algorithms), ML, NLP, and related areas, while also recruiting talent in these fields (internships, full-time positions, consultants, etc.). I look forward to open and in-depth communication with you. My contact information: Email: [email protected]; Tel: 17710876257; WeChat: 1274634225 (I am Xiaobao). Thanks!
The 2012 Lab Central Media Technology Institute is Huawei's media technology innovation and engineering competence center.
It is responsible for technical research, innovation, and breakthroughs in the company's mobile phone camera, video, audio, and audio/video standards, ensuring that Huawei's media product technology continues to lead the industry. The Central Media Technology Institute has established R&D centers and professional laboratories in Japan, North America, Europe, and other overseas regions, as well as in Shenzhen, Hangzhou, Beijing, Shanghai, and other domestic cities.
We hope to establish contact with you and look forward to your reply!

Tips to train a model on a custom dataset

Hey,
I really like your work and I wanted to compare the results of an NCSNPP model to a GAN. I have a custom dataset of 31k images and I am using Google Colab (so 1x V100).

Do you have any helpful comments on how I should train an NCSNPP model with that specification? Ideally, the resolution would be between 64 and 256.

Thanks!

Attention Blocks in Denoiser

Hi!

I have noticed that the U-Net in ncsnpp.py handles attention blocks differently for the contracting and expanding paths.

When downsampling, the number of attention blocks matches the number of ResNet blocks for each channel multiplier (ch_mult).

for i_level in range(num_resolutions):
    # Residual blocks for this resolution
    for i_block in range(num_res_blocks):
        h = ResnetBlock(out_ch=nf * ch_mult[i_level])(hs[-1], temb, train)
        if h.shape[1] in attn_resolutions:
            h = AttnBlock()(h)
        hs.append(h)

In the expanding path, however, there is only a single attention block.

# Upsampling block
for i_level in reversed(range(num_resolutions)):
    for i_block in range(num_res_blocks + 1):
        h = ResnetBlock(out_ch=nf * ch_mult[i_level])(
            jnp.concatenate([h, hs.pop()], axis=-1), temb, train
        )

    if h.shape[1] in attn_resolutions:
        h = AttnBlock()(h)

Why is this beneficial?

Checkpoint for CelebA-HQ

Hi guys,

first of all, thanks for the code!
I wonder if you could upload a pretrained checkpoint for CelebA-HQ?
How long does it take to train on CelebA-HQ and how many GPUs are required?

Thanks,
Artsiom

about "JIT multiple training steps together"

Hello, Dr. Song

Thank you for sharing this excellent work.
I saw that a parameter "n_jitted_steps" is used during training, and the code comment says: "JIT multiple training steps together for faster training." Can you explain why and how multiple training steps are jitted together? Does "n_jitted_steps" affect performance, i.e., if I don't jit multiple training steps together, will the results be the same?
Thank you in advance.

Checkpoint on CelebA dataset

Thank you for your dedication to score matching models.
It has been particularly interesting to follow your recent paper on SDE + DSM.

I would like to know how the model performs on CelebA. You compared CelebA with CIFAR-10 in the Appendix, but there is no checkpoint for CelebA.
Do you have any plans to release a CelebA checkpoint?

FID score of conditional sampling

Hi,

Thanks a lot for your amazing work.
I'm currently reproducing your work for conditional sampling. I found that the FID scores of images sampled with conditional sampling (a VE score model with a wide ResNet classifier) are far higher than 2.20. Is there any suggestion for tuning the hyper-parameters to improve performance? Or can you provide the FID score for this experimental setting for reference? It seems that the paper only provides some visualization examples for this experiment.

Thanks.

Colorization Matrix

Hello.
What was the process for choosing the proposed 3x3 orthogonal matrix used for colorization?

The possibility of using maximum likelihood estimation for training

Hello,

We understand that using computed likelihood directly for maximum likelihood estimation can be computationally intensive. However, we are curious if this technique is theoretically and practically feasible. We would like to experiment with this approach on smaller models and datasets. Any advice or code assistance would be greatly appreciated.
:)
Once again, thank you for your outstanding work.

TypeError: can't multiply sequence by non-int of type 'BatchTracer'

Hi Yang,

I'm getting the following error:

File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/losses.py", line 111, in loss_fn
losses = jnp.square(batch_mul(score, std) + z)
File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/utils.py", line 42, in batch_mul
return jax.vmap(lambda a, b: a * b)(a, b)
TypeError: can't multiply sequence by non-int of type 'BatchTracer'

If I use chex.fake_pmap to be able to print inside the pmap, I see:

std:

Traced<ShapedArray(float32[8])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[ 0.21264698, 0.77755433, 0.27918625, 0.9448618 ,
10.666621 , 0.24025024, 12.233008 , 3.626547 ]], dtype=float32)
batch_dim = 0

z:

Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[[[[-2.01293421e+00, -2.17641640e+00, -1.23569024e+00],
[ 6.13737464e-01, 1.50414258e-01, -2.59380966e-01],
....

score:

(Traced<ShapedArray(float32[8,32,32,3])>with<JVPTrace(level=3/0)>
with primal = Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[[[[ 8.45328856e-08, -1.25030866e-07, -8.07002252e-08],
...

I tried to match your versions of libraries as much as possible. Same jax, flax, jaxlib version. Tensorflow_gpu=2.4.1. I use one GPU and config="configs/ncsnpp/cifar10_continuous_ve.py".

Update: If I do it on cifar10_continuous_vp.py it breaks here:

File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/models/utils.py", line 197, in score_fn
score = batch_mul(-model, 1. / std)
TypeError: bad operand type for unary -: 'tuple'

Requirements are broken due to new dependency releases

The repository cannot be installed at the moment probably due to new releases of dependencies. I suspect the issue is that the version of tensorflow_io is not specified and some versioning around the tensorflow-probability library is incorrectly specified somewhere upstream.

This seems to work but the versions likely differ from what you have used

ml-collections==0.1.0
tensorflow-gan==2.0.0
tensorflow_io==0.17.1
tensorflow_datasets==3.1.0
tensorflow==2.4.0
tensorflow-addons==0.12.0
tensorboard==2.4.0
absl-py==0.10.0
flax==0.3.1
jax==0.2.8
jaxlib==0.1.59
tensorflow-probability==0.12.2

Would it be possible for you to report the version of python that you are using and run pip freeze > requirements.txt to help with reproducibility?

figure source code

Dear Song,

Many thanks for sharing this awesome idea and the source code! I am wondering whether you can point me to the source code of the following figure in Song2021?

image

Kind regards
Feng

Why does the approximate equality in Eq. (24) hold?

Hi! When reading the proof that DDPM is a discretization of the VP SDE in Appendix B of https://openreview.net/pdf?id=PxTIG12RRHS, I don't understand why Eq. (24) holds. I know that when $x\ll 1$, $\sqrt{1-x} \approx 1-x/2$. However, in Eq. (24), $\beta(t+\Delta t)\Delta t$ does not seem to satisfy this condition, because $\beta(t)=N\beta_i$, and when $\Delta t\rightarrow 0$, $N\rightarrow \infty$. Could you explain why this approximate equality still holds?
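
For reference, the approximation being asked about (reconstructed here as best as possible from Appendix B; equation numbering may differ between paper versions) is

$$\mathbf{x}(t+\Delta t) = \sqrt{1-\beta(t+\Delta t)\Delta t}\,\mathbf{x}(t) + \sqrt{\beta(t+\Delta t)\Delta t}\,\mathbf{z}(t) \approx \mathbf{x}(t) - \tfrac{1}{2}\beta(t+\Delta t)\Delta t\,\mathbf{x}(t) + \sqrt{\beta(t+\Delta t)\Delta t}\,\mathbf{z}(t),$$

where the first-order expansion is applied to the product $\beta(t+\Delta t)\Delta t = \beta(t+\Delta t)/N$, which vanishes as $N \to \infty$ because the continuous schedule $\beta(\cdot)$ is held fixed in the limit; it is the discrete $\beta_i = \beta(i/N)/N$ that shrink with $N$, rather than $\beta(t)$ growing.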

Training NCSNPP on a custom 512x512 dataset

Hi!

Firstly thanks a lot for your amazing work and contribution!

I was really motivated to try your code on a real custom dataset of about 140K pictures at 512x512.
I would like to use NCSNPP, but I am having a really hard time understanding how I can load my own data.

Could you maybe describe what the pipeline for doing so looks like? i.e.,

  • What script should I run to create a dataset, and how?
  • Which config file should I edit, and which lines/attributes should I change, at least to have a basic starting point?

Thanks in advance!
ysig
