
Comments (5)

fbcotter commented on May 18, 2024

I'm not quite sure what you want to replicate, perhaps you can link me to the tf code that does the upsampling?

Let's start from a different angle. Say you have a target of size [1, 3, 256, 256] and use the DWT to downsample it to get a pyramid of signals:

import torch
from pytorch_wavelets import DWT, IDWT

x = torch.randn(1, 3, 256, 256)
dwt = DWT(J=1, wave='sym6', mode='zero')
yl, yh = dwt(x)
yl.shape
>>> torch.Size([1, 3, 133, 133])
yh[0].shape
>>> torch.Size([1, 3, 3, 133, 133])

What's happened here is yl is obtained by convolving x with the lowpass analysis filters h0. For a sym6 wavelet, these have length 12 (try printing dwt.h0_col or dwt.h0_row). When convolving a 256 length signal with a 12 length filter, we get an output of (256 + 12 - 1) = 267. This is then downsampled by 2, giving us the 133 we see in yl. The same is done for the 3 bandpass filters, just with combinations of h0 and h1 for rows and columns.
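For reference, you can verify that length arithmetic directly (a quick check, continuing from the snippet above; h0_col is the lowpass analysis filter buffer mentioned above):

dwt.h0_col.numel()
>>> 12
(256 + 12 - 1) // 2   # full convolution length, then downsample by 2
>>> 133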

If you want to upsample a signal by 2 but don't have access to the bandpass coefficients, all we are doing is taking this yl, inserting zeros in every other sample, and smoothing with a lowpass filter. It's not too different from just using bilinear interpolation (although the lowpass filter will have a different frequency response).
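In 1D the idea looks roughly like this (my own sketch, with a crude 3-tap averaging filter standing in for the wavelet synthesis lowpass):

import torch
import torch.nn.functional as F

sig = torch.randn(1, 1, 8)
up = torch.zeros(1, 1, 16)
up[..., ::2] = sig                           # insert zeros in every other sample
lowpass = torch.full((1, 1, 3), 1 / 3)       # crude stand-in for the synthesis lowpass filter
smoothed = F.conv1d(up, lowpass, padding=1)  # smooth away the zero-stuffing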

To do this here, you would drop your signal into the yl coefficients. Note that after the initial convolution at the higher sample rate, the 256-length signal gained 5 samples before and 6 samples after; this becomes 2 samples before and 3 samples after when downsampled. Then:

idwt = IDWT(wave='sym6', mode='zero')
yh = [torch.zeros(1, 3, 3, 133, 133)]   # zero bandpass coefficients of the correct shape (a list, one entry per level)
yl = torch.zeros(1, 3, 133, 133)
yl[:, :, 2:-3, 2:-3] = your_signal      # your_signal has shape [1, 3, 128, 128]
upsampled = idwt((yl, yh))              # shape [1, 3, 256, 256]


kierkegaard13 commented on May 18, 2024

The code I'm specifically trying to replicate is here: https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py#L424. Here's the paper if you're interested: https://arxiv.org/abs/2006.06676.

Thanks for that explanation. I think that makes sense for the most part. I'm not sure I get the padding scheme though, and why the 256-length signal gains pixels after the convolution. I would expect the signal to be padded and then lose 11 pixels after the convolution with the sym6 filter. The upsampling procedure makes complete sense, but if I wanted to downsample to exactly half the pixels I'm not 100% sure how to do that.


fbcotter commented on May 18, 2024

Ok, I see what they're doing in their implementation. Re padding: this is a property of convolutions; check out the gif here, where two square pulses are convolved: https://en.wikipedia.org/wiki/Convolution#/media/File:Convolution_of_box_signal_with_itself2.gif. The output triangle signal has support before and after the input square.
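As a quick numerical illustration (my own sketch): fully convolving a length-N signal with a length-L filter gives N + L - 1 output samples.

import torch
import torch.nn.functional as F

a = torch.ones(1, 1, 4)               # length-4 box signal
k = torch.ones(1, 1, 3)               # length-3 box filter
F.conv1d(F.pad(a, (2, 2)), k).shape   # pad by L-1 on each side for a 'full' convolution
>>> torch.Size([1, 1, 6])             # 4 + 3 - 1 = 6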

In typical conv networks we just throw away the extra info before and after. This is usually not a problem, as the filters h are very small (3x3) compared to the signal x, so it's easier just to discard it.

Handling the padding in the DWT setting is a little tricky, but not too complex. What you want to do is:

import torch
from pytorch_wavelets import DWT, IDWT

x = torch.randn(1, 3, 128, 128)  # the signal to upsample
idwt = IDWT(wave='sym6', mode='zero')
dwt = DWT(J=1, wave='sym6', mode='zero')
# Pad x with zeros to the 133x133 shape the IDWT expects; passing [None] uses zero bandpass coefficients
upsampled = idwt((torch.nn.functional.pad(x, (2, 3, 2, 3)), [None, ]))  # shape [1, 3, 256, 256]
...
# The decimated output will be 133x133 - discard the extra border regions. These will not be zero,
# a bit like the triangle output from the gif.
downsampled = dwt(upsampled)[0][..., 2:-3, 2:-3]  # shape [1, 3, 128, 128]

If you then compare downsampled and x, they will be identical except for some errors at the borders, again a consequence of the padding.
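A quick way to check this (my own sketch; the 6-pixel crop for "away from the borders" is an arbitrary choice):

err = (downsampled - x).abs()
err[..., 6:-6, 6:-6].max()   # interior: numerical noise only
err.max()                    # noticeably larger, from the boundary effects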


kierkegaard13 commented on May 18, 2024

Great, thanks. That makes perfect sense. Out of curiosity, after looking at their implementation, do you think there are any differences between their approach and what you've explained here? The thing I was most confused by was the function tf.nn.depthwise_conv2d_backprop_input, because I couldn't find any example usage and, from the description in the docs, it seemed like it should return gradient information.


kierkegaard13 commented on May 18, 2024

I tested out the up/downsampling method you mentioned and it seems to work perfectly. I'm using pytorch-geometry to apply the projective transformation and it seems to replicate everything in the paper as far as I can tell. Thanks again for the help.

