Comments (5)
I'm not quite sure what you want to replicate; perhaps you can link me to the tf code that does the upsampling?
Let's start from a different angle. Say you have a target size of [1, 3, 256, 256] and use the DWT to downsample it to get a pyramid of signals:
import torch
from pytorch_wavelets import DWT, IDWT
x = torch.randn(1, 3, 256, 256)
dwt = DWT(J=1, wave='sym6', mode='zero')
yl, yh = dwt(x)
yl.shape
>>> torch.Size([1, 3, 133, 133])
yh[0].shape
>>> torch.Size([1, 3, 3, 133, 133])
What's happened here is that yl is obtained by convolving x with the lowpass analysis filters h0. For a sym6 wavelet, these have length 12 (try printing dwt.h0_col or dwt.h0_row). When convolving a 256-length signal with a 12-length filter, we get an output of length (256 + 12 - 1) = 267. This is then downsampled by 2, giving us the 133 we see in yl. The same is done for the 3 bandpass subbands, just with combinations of h0 and h1 for rows and columns.
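To sanity check that arithmetic (reusing the dwt object from the snippet above, and the h0_col attribute mentioned here):
dwt.h0_col.numel()
>>> 12
(256 + 12 - 1) // 2
>>> 133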
If you want to upsample a signal by 2 but don't have access to the bandpass coefficients, all we are doing is taking this yl, inserting zeros at every other sample, and smoothing with a lowpass filter. It's not too much different from just using bilinear interpolation (although the lowpass filter will have a different frequency response curve).
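Here is a minimal 1-D sketch of that zero-insertion picture (not the library's internal code; the toy signal and its length are made up, and it assumes pywt is installed for the sym6 reconstruction filter):
import torch
import torch.nn.functional as F
import pywt

sig = torch.randn(1, 1, 64)                                       # toy 1-D signal
g0 = torch.tensor(pywt.Wavelet('sym6').rec_lo).reshape(1, 1, -1)  # sym6 lowpass synthesis filter, length 12

up = torch.zeros(1, 1, 2 * sig.shape[-1])                         # insert a zero after every sample
up[..., ::2] = sig
smoothed = F.conv1d(up, g0, padding=g0.shape[-1] - 1)             # smooth with the lowpass filter (full conv)
smoothed.shape
>>> torch.Size([1, 1, 139])                                       # 128 + 12 - 1, the same growth as above
The IDWT below does essentially this (plus the bandpass branches), with the padding bookkeeping handled for you.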
To do this in this case, you would drop your signal into the yl coefficients. Note that after the initial convolution at the higher sample rate, the 256-length signal gained 5 samples before and 6 samples after; this became 2 samples before and 3 samples after when downsampled. Then:
idwt = IDWT(wave='sym6', mode='zero')
yh = [torch.zeros(1, 3, 3, 133, 133)]  # zero bandpass coefficients (one level, so a single-entry list)
yl = torch.zeros(1, 3, 133, 133)
yl[:, :, 2:-3, 2:-3] = your_signal     # 2 samples before, 3 after, matching the offsets above
upsampled = idwt((yl, yh))             # shape [1, 3, 256, 256]
The code I'm specifically trying to replicate is here: https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py#L424. Here's the paper if you're interested: https://arxiv.org/abs/2006.06676.
Thanks for that explanation. I think that makes sense for the most part. I'm not sure I get the padding scheme though, and why the 256-length signal gains pixels after the convolution. I would expect the signal to be padded and then lose 11 pixels after the convolution with the sym6 filter. The upsampling procedure makes complete sense, but if I wanted to downsample to exactly half the pixels I'm not 100% sure how to do that.
Ok, I see what they're doing in their implementation. Regarding padding: this is a property of convolutions. Check out the gif here, where two square pulses are convolved: https://en.wikipedia.org/wiki/Convolution#/media/File:Convolution_of_box_signal_with_itself2.gif. The output triangle signal has support before and after the input square.
In typical conv networks, we just throw away the extra info before and after. This is usually not a problem, as the filters h are usually very small (3x3) compared to the signal x, so it's just easier to discard this.
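The same support growth is easy to see numerically in 1-D (a toy example, not from the thread):
import torch
import torch.nn.functional as F

box = torch.ones(1, 1, 8)              # length-8 box signal
filt = torch.ones(1, 1, 3) / 3         # length-3 averaging filter
full = F.conv1d(box, filt, padding=2)  # 'full' convolution: pad by len(filt) - 1
full.shape
>>> torch.Size([1, 1, 10])             # 8 + 3 - 1: the output has support before and after the input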
To handle the padding in the DWT setting, it's a little tricky, but not too complex. What you want to do is:
x # shape [1, 3, 128, 128]
idwt = IDWT(wave='sym6', mode='zero')
dwt = DWT(J=1, wave='sym6', mode='zero')
# pad x to make it the expected 133x133 shape before giving it to the IDWT (use zeros to do this)
upsampled = idwt((torch.nn.functional.pad(x, (2, 3, 2, 3)), [None, ])) # shape [1, 3, 256, 256]
...
# The decimated version will have output 133x133 - discard the extra regions. These will not be zero, a bit like the
# triangle output from the gif.
downsampled = dwt(upsampled)[0][..., 2:-3, 2:-3]
If you then compare downsampled and x, they will be identical, except for some errors at the borders (again, a complication of the padding).
Great, thanks. That makes perfect sense. Out of curiosity, after looking at their implementation, do you think there are any differences between their approach and what you've explained here? The thing I was most confused about was the function tf.nn.depthwise_conv2d_backprop_input, because I couldn't find any example usage, and it seemed like it should return gradient information from the description in the docs.
I tested out the up/downsampling method you mentioned and it seems to work perfectly. I'm using pytorch-geometry to apply the projective transformation and it seems to replicate everything in the paper as far as I can tell. Thanks again for the help.