yang-song / score_sde

Official code for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)

Home Page: https://arxiv.org/abs/2011.13456

License: Apache License 2.0

Python 5.65% Jupyter Notebook 94.35%
score-matching stochastic-differential-equations generative-models score-based-generative-modeling controllable-generation inverse-problems jax flax iclr-2021 diffusion-models

score_sde's People

Contributors: yang-song

score_sde's Issues

TypeError: can't multiply sequence by non-int of type 'BatchTracer'

Hi Yang,

I'm getting the following error:

File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/losses.py", line 111, in loss_fn
  losses = jnp.square(batch_mul(score, std) + z)
File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/utils.py", line 42, in batch_mul
  return jax.vmap(lambda a, b: a * b)(a, b)
TypeError: can't multiply sequence by non-int of type 'BatchTracer'

If I use chex.fake_pmap to be able to print inside the pmap, I see:

std:

Traced<ShapedArray(float32[8])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[ 0.21264698, 0.77755433, 0.27918625, 0.9448618 ,
10.666621 , 0.24025024, 12.233008 , 3.626547 ]], dtype=float32)
batch_dim = 0

z:

Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[[[[-2.01293421e+00, -2.17641640e+00, -1.23569024e+00],
[ 6.13737464e-01, 1.50414258e-01, -2.59380966e-01],
....

score:

(Traced<ShapedArray(float32[8,32,32,3])>with<JVPTrace(level=3/0)>
with primal = Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[[[[ 8.45328856e-08, -1.25030866e-07, -8.07002252e-08],
...

I tried to match your library versions as closely as possible: the same jax, flax, and jaxlib versions, with tensorflow_gpu==2.4.1. I use one GPU and config="configs/ncsnpp/cifar10_continuous_ve.py".

Update: with cifar10_continuous_vp.py it breaks here instead:

File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/models/utils.py", line 197, in score_fn
  score = batch_mul(-model, 1. / std)
TypeError: bad operand type for unary -: 'tuple'
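Both tracebacks suggest the model output is a Python tuple rather than a single array. A minimal, dependency-free sketch of the failure mode (no JAX needed; `batch_mul` here is a stand-in for the repo's `jax.vmap`-based helper, and the names are illustrative):

```python
# Stand-in for score_sde's batch_mul, which vmaps an elementwise multiply
# over the batch dimension. Multiplying a Python tuple by a non-int (here
# a float scalar) raises the same TypeError class seen in the traceback.
def batch_mul(a, b):
    return a * b  # in the repo: jax.vmap(lambda a, b: a * b)(a, b)

score_array = 2.0          # a single array-like output: fine
score_tuple = (2.0, 3.0)   # a model that returns (output, aux): breaks

assert batch_mul(score_array, 0.5) == 1.0

try:
    batch_mul(score_tuple, 0.5)  # tuple * float
except TypeError as e:
    print(e)  # "can't multiply sequence by non-int of type 'float'"
```

If that reading is right, a likely fix is to unpack the model's return value, e.g. `(output, new_state)`, before passing it to `batch_mul`.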

Why does the approximate equality in Eq.(24) hold?

Hi! While reading the proof that DDPM is a discretization of the VP-SDE in Appendix B of https://openreview.net/pdf?id=PxTIG12RRHS, I don't understand why Eq.(24) holds. I know that when $x\ll 1$, $\sqrt{1-x} \approx 1-x/2$. However, in Eq.(24), $\beta(t+\Delta t)\Delta t$ does not seem to satisfy this condition, because $\beta(t)=N\beta_i$, and when $\Delta t\rightarrow 0$, $N\rightarrow \infty$. Could you explain why this approximate equality still holds?
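One way to see why the expansion can still be justified (a sketch under the paper's conventions, where $\Delta t = 1/N$ and $\beta(i/N) = N\beta_i$; this is my reading, not an official answer):

```latex
% With \Delta t = 1/N, the argument of the square root is
\beta(t+\Delta t)\,\Delta t \;=\; N\beta_{i+1}\cdot\frac{1}{N} \;=\; \beta_{i+1}.
% The DDPM noise schedule keeps each per-step \beta_i small and scales it
% like 1/N when the continuous limit \bar\beta(t) = N\beta_{tN} is held fixed,
% so \beta_{i+1} = O(1/N) \to 0 as N \to \infty. Hence
\sqrt{1-\beta_{i+1}} \;\approx\; 1-\tfrac{1}{2}\beta_{i+1}
% holds with error O(\beta_{i+1}^2) = O(1/N^2) per step.
```

So although $\beta(t) = N\beta_i$ diverges, the product $\beta(t+\Delta t)\Delta t$ is just the small per-step $\beta_{i+1}$, which is where the Taylor expansion applies.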

Questions about Alg. 5 in the paper

I wondered why the step size $\epsilon$ of the VP SDE corrector needs to be multiplied by an extra $\alpha_i$ compared with $\epsilon$ in the VE SDE setting.

Checkpoint on CelebA dataset

Thank you for your dedication to score matching models.
It has been particularly interesting to follow your recent paper on SDE+DSM.

I would like to know how CelebA performs.
You compared CelebA with CIFAR-10 in the Appendix, but there is no checkpoint for CelebA.
Do you have any plans to release a CelebA checkpoint?

Training NCSNPP on a custom 512x512 dataset

Hi!

Firstly thanks a lot for your amazing work and contribution!

I was really motivated to try your code on a real custom dataset of about 140K pictures at 512x512.
I would like to use NCSNPP, but I am having a really hard time understanding how to load my own data.

Could you describe the pipeline for doing so, step by step?
i.e.

  • which script should I run to create a dataset, and how, and
  • which config file should I edit and which lines/attributes should I change, at least to get a basic starting point?

Thanks in advance!
ysig
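For reference, here is a hedged sketch of the kind of config fields that typically need changing for a custom dataset. The field names mirror the repo's ml_collections-style configs, but the dataset name "my_dataset" and the values below are illustrative assumptions, not verified settings:

```python
# Illustrative config overrides for a custom 512x512 dataset, written as a
# plain dict for clarity. In the repo these would live in a config file
# under configs/; "my_dataset" is a hypothetical name you would also need
# to handle in the data-loading code (e.g. datasets.py).
config = {
    "data": {
        "dataset": "my_dataset",   # must match a branch you add in the data loader
        "image_size": 512,         # native resolution of your images
        "num_channels": 3,
    },
    "training": {
        "batch_size": 8,           # 512x512 NCSN++ is memory-hungry; start small
    },
}

assert config["data"]["image_size"] == 512
```

Beyond the config, you would add a loading branch in the data pipeline that yields image batches at the configured resolution, mirroring how the built-in datasets are handled.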

Colorization Matrix

Hello.
What was the process for choosing the proposed 3x3 orthogonal matrix used in the colorization task?
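One common way to build such a matrix (an assumption about the intent, not necessarily the authors' exact procedure) is to take the gray direction $(1,1,1)/\sqrt{3}$ as the first basis vector, so that the grayscale channel is decoupled, and then complete it to an orthonormal basis by Gram-Schmidt:

```python
import math

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def normalize(u):
    n = math.sqrt(dot(u, u))
    return [x / n for x in u]

def gram_schmidt(vectors):
    """Orthonormalize a list of vectors, keeping the first direction fixed."""
    basis = []
    for v in vectors:
        w = v[:]
        for b in basis:
            c = dot(v, b)  # remove the component along each earlier basis vector
            w = [wi - c * bi for wi, bi in zip(w, b)]
        basis.append(normalize(w))
    return basis

# First row spans the gray axis; the other two rows are filled in by
# orthonormalizing two arbitrary helper vectors against it.
gray = [1.0, 1.0, 1.0]
M = gram_schmidt([gray, [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# Rows are orthonormal: M M^T = I
for i in range(3):
    for j in range(3):
        expected = 1.0 if i == j else 0.0
        assert abs(dot(M[i], M[j]) - expected) < 1e-9
```

Any such completion works up to a rotation of the two chroma rows, which may explain why one particular matrix was proposed without further justification.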

the jax-based code on multi-host tpu

Hi Yang,

This is great work. I would like to ask whether this code can run on a multi-host TPU (such as a v3-32), and whether you could give me some advice on how to adapt the code for that.

Thank you very much!

Yong

Checkpoint for CelebA-HQ

Hi guys,

first of all, thanks for the code!
I wonder if you could upload a pretrained checkpoint for CelebA-HQ.
How long does it take to train on CelebA-HQ, and how many GPUs are required?

Thanks,
Artsiom

Tips to train a model on a custom dataset

Hey,
I really like your work, and I wanted to compare the results of an NCSNPP model to a GAN. I have a custom dataset of 31k images and I am using Google Colab (so a single V100).

Do you have any helpful tips on how to train an NCSNPP model with that setup? Ideally, the resolution would be between 64 and 256.

Thanks!

figure source code

Dear Song,

Many thanks for sharing this awesome idea and the source code! I am wondering whether you can point me to the source code of the following figure in Song2021?

[figure image omitted]

Kind regards
Feng

about "JIT multiple training steps together"

Hello, Dr. Song

Thank you for sharing this excellent work.
I saw that a parameter n_jitted_steps is used in training, and the code comment says: "JIT multiple training steps together for faster training." Can you explain why and how this "JIT multiple training steps together" works? Does n_jitted_steps affect performance? That is, if I don't JIT multiple training steps together, will the results be the same?
Thank you in advance.
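For context, the pattern looks roughly like this (a dependency-free sketch; in the repo the multi-step body would be compiled once with jax.jit, typically via lax.scan, and the function names here are illustrative). Jitting k steps together amortizes the per-step Python dispatch and device-launch overhead without changing the math, so results should match up to compilation-level numerics:

```python
# Illustrative sketch of "JIT multiple training steps together".
# train_step stands in for the real per-batch update; multi_step is the
# body that would be traced and compiled as ONE unit, so Python/dispatch
# overhead is paid once per n_jitted_steps steps instead of once per step.
def train_step(state, batch):
    # toy update: accumulate the batch value into the state
    return state + batch

def multi_step(state, batches):
    for b in batches:
        state = train_step(state, b)
    return state

# One step at a time vs. n_jitted_steps=4 produce identical results.
batches = [1, 2, 3, 4]
s1 = 0
for b in batches:
    s1 = train_step(s1, b)
s2 = multi_step(0, batches)
assert s1 == s2 == 10
```

So n_jitted_steps is a speed knob, not a modeling change; setting it to 1 should give the same training trajectory, just with more per-step overhead.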

[Huawei] 2012 Lab-Technical Exchange & Project Cooperation & Talented Youth Program Invitation

Hi Yang Song, I had the honor of reading your published paper, which is how I found your contact information. I am Han Lu from the Huawei 2012 Lab. We are currently doing research and technical exploration in audio, autonomous-driving perception, CG (such as character animation, rendering, 3D reconstruction, motion synthesis, and character interaction), CV (multi-modal learning algorithms), ML, NLP, and related topics, while also recruiting talent in these fields (internships, full-time positions, consultancies, etc.). I look forward to an open and in-depth exchange with you. My contact information: Email: [email protected]; Tel: 17710876257; WeChat: 1274634225 (I am Xiaobao). Thanks!
The 2012 Lab Central Media Technology Institute is Huawei's media technology innovation and engineering competence center.
It is responsible for technical research, innovation, and breakthroughs in the company's mobile-phone camera, video, audio, and audio/video standards, ensuring that Huawei's media product technology continues to lead the industry. The Central Media Technology Institute has established R&D centers and professional laboratories overseas in Japan, North America, and Europe, as well as domestically in Shenzhen, Hangzhou, Beijing, Shanghai, and other cities.
Hope to be able to establish contact with you and look forward to your reply!

Requirements are broken due to new dependency releases

The repository cannot be installed at the moment, probably due to new releases of its dependencies. I suspect the issue is that the version of tensorflow_io is not pinned, and that some versioning around the tensorflow-probability library is incorrectly specified somewhere upstream.

The following set seems to work, but the versions likely differ from what you used:

ml-collections==0.1.0
tensorflow-gan==2.0.0
tensorflow_io==0.17.1
tensorflow_datasets==3.1.0
tensorflow==2.4.0
tensorflow-addons==0.12.0
tensorboard==2.4.0
absl-py==0.10.0
flax==0.3.1
jax==0.2.8
jaxlib==0.1.59
tensorflow-probability==0.12.2

Would it be possible for you to report the Python version you are using and run pip freeze > requirements.txt to help with reproducibility?

FID score of conditional sampling

Hi,

Thanks a lot for your amazing work.
I'm currently reproducing your work on conditional sampling. I found that the FID score of images sampled via conditional sampling (a VE score model with a Wide ResNet classifier) is far higher than 2.20. Do you have any suggestions for tuning the hyper-parameters to improve performance? Or could you provide the FID score for this experimental setting for reference? It seems the paper only provides visualization examples for this experiment.

Thanks.
