Comments (7)
Hi @ryan-caesar-ramos, I really appreciate your willingness to share these insights with me. I later came across a few papers that all rely on the deterministic reconstruction of DDIM, specifically DDIB (from the same author as DDIM), Diffusion-CLIP, and DirectInversion. Among them, Diffusion-CLIP gives the direct form of the forward ODE as a clear and concise formula in its paper, and it looks like this.
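(Transcribed from memory, so please double-check against the Diffusion-CLIP paper; here $\alpha_t$ denotes the cumulative product of the noise schedule, as in DDIM.)

```latex
% Deterministic forward (inversion) DDIM step as given in Diffusion-CLIP,
% transcribed from memory:
x_{t+1} = \sqrt{\alpha_{t+1}}\, f_\theta(x_t, t) + \sqrt{1 - \alpha_{t+1}}\,\epsilon_\theta(x_t, t),
\qquad
f_\theta(x_t, t) = \frac{x_t - \sqrt{1 - \alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}}
```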
Again, thanks for your kind response; I hope this might help future readers of this post :)
Yes, that is exactly how you can do it. Note, though, that `seq` should generally start at 0 and end at around 999.
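A minimal sketch of how such a `seq` could be built (the 1000 training steps and 50 DDIM steps are assumed values, matching the repo's usual defaults):

```python
num_timesteps = 1000                       # DDPM training steps (assumed)
ddim_steps = 50                            # number of DDIM steps (assumed)
skip = num_timesteps // ddim_steps
seq = list(range(0, num_timesteps, skip))  # 0, 20, ..., 980
```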
Sorry, just to be clear, the trick is to switch `seq` with `seq_next` in the function, like this?
```python
def generalized_steps(x, seq, model, b, **kwargs):
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]
        # SWAP
        for i, j in zip(reversed(seq_next), reversed(seq)):
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())
            xt = xs[-1].to('cuda')
            et = model(xt, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))
            c1 = (
                kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            )
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
            xs.append(xt_next.to('cpu'))

    return xs, x0_preds
```
If I'm not mistaken, this would lead `c1` and subsequently `c2` to be `nan`? Because `(1 - at / at_next)` is negative when `j > i`, leading to the square root of a negative term? I understand that under deterministic sampling, `eta` would be 0, so I'm assuming we could technically let `c1` be 0. But is that the right way to implement reverse sampling?
@RyanCaesarRamos @winnechan Hi guys, I am also trying to reproduce the reconstruction section but found myself deeply lost as well. What I tried was using Eq. 13 in DDIM in the forward diffusion direction to get an expression for x_t in terms of x_{t-1}. By recursively applying that expression with the DDIM-specific parameters, we end up with white noise. (Note that in the usual forward process the noise is sampled from a Gaussian, whereas here it is deterministic, and the steps should be DDIM steps such as 1, 21, 41, 61, etc. if ddim_steps is 50 for 1000 DDPM steps.) Now that we have a deterministic latent code for an image x, we can use Eq. 13 (with sigma set to 0 at every t) to reconstruct the original image.
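For reference, the update I mean is the following (transcribed from memory, with $\alpha_t$ the cumulative product as in DDIM, so please double-check the exact form and equation number in the paper):

```latex
% DDIM sampling update; with \sigma_t = 0 it is deterministic, and solving it
% for the larger timestep gives the forward (inversion) direction.
x_{t-1} = \sqrt{\alpha_{t-1}}
          \left( \frac{x_t - \sqrt{1 - \alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right)
        + \sqrt{1 - \alpha_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t)
        + \sigma_t \epsilon_t
```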
However, this method completely failed to reconstruct the original image. Did you guys have any luck with this?
Hi @Randolph-zeng, I'm so sorry it took me this long to get back to you. IIRC, I got it to somewhat work in v0.3.0 of the diffusers library by adding a function to their DDIM sampler class (the results weren't amazing, but they at least made sense), but my code would break in later versions.
There are other people out there working on inversion for diffusion models. I haven't looked into it myself, but the diffusers library apparently has an implementation of the CycleDiffusion paper, which, if I'm not mistaken, proposes a DPM-Encoder for inversion. Here's a link to the docs.
Again, super sorry I didn't reply sooner. Please let me know if there's anything else I can help with!
Adding to that, DDIB has an official implementation of the reverse process here: https://github.com/suxuann/ddib/blob/main/guided_diffusion/gaussian_diffusion.py#L670-L741
For those who also want to implement this:
```python
def reverse_generalized_steps(x, seq, model, b, **kwargs):
    # Deterministic DDIM "inversion": apply the eta=0 update in the forward
    # direction (t -> next_t with next_t > t) to map an image to its latent.
    with torch.no_grad():
        n = x.size(0)
        # pair each timestep with the next larger one; 999 is the last
        # training timestep (assuming 1000 DDPM steps)
        seq_next = list(seq[1:]) + [999]
        x0_preds = []
        xs = [x]
        for i, j in zip(seq, seq_next):
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())
            xt = xs[-1].to('cuda')
            et = model(xt, t)
            # predicted x0 from the current sample
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))
            # sigma = 0, so no random noise is added (the c1 term is dropped)
            c2 = (1 - at_next).sqrt()
            xt_next = at_next.sqrt() * x0_t + c2 * et
            xs.append(xt_next.to('cpu'))

    return xs, x0_preds
```
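And a hypothetical usage sketch for reconstruction; `model`, `betas`, and `x0` (a batch of images in the model's input range) are assumed to exist, and `generalized_steps` is the original deterministic sampler from this repo:

```python
# assumed schedule: 1000 DDPM steps, 50 DDIM steps
skip = 1000 // 50
seq = list(range(0, 1000, skip))

# 1) invert: run the deterministic steps "forward" to obtain a latent x_T
xs, _ = reverse_generalized_steps(x0, seq, model, betas)
x_T = xs[-1].to('cuda')

# 2) reconstruct: run the usual deterministic (eta=0) reverse process from x_T
recon, _ = generalized_steps(x_T, seq, model, betas, eta=0)
x0_recon = recon[-1]
```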