Do you mean why I take the negative strides? It's because PyTorch's conv2d does cross-correlation rather than true convolution, so to prepare for this later, I flip the filters.
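The identity behind the flip can be checked in one dimension with NumPy: cross-correlating with a reversed filter gives the same result as true convolution with the original filter. This is a minimal sketch; the signal `x` and the asymmetric filter `h` are hypothetical values chosen so that flipping visibly matters, and `np.correlate` stands in for the sliding-window behaviour of PyTorch's conv2d.

```python
import numpy as np

# Hypothetical 1-D signal and an asymmetric filter (a symmetric one would
# hide the difference between correlation and convolution).
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
h = np.array([1.0, 2.0, 3.0])

# Cross-correlation: slide h over x without flipping it (what conv2d does).
corr = np.correlate(x, h, mode="valid")

# True convolution: h is flipped before sliding.
conv = np.convolve(x, h, mode="valid")

# Reversing h with a negative-stride slice, then correlating,
# reproduces the true convolution.
conv_via_corr = np.correlate(x, h[::-1], mode="valid")

print(corr)           # [14. 20. 26. 32.]
print(conv)           # [10. 16. 22. 28.]
print(conv_via_corr)  # [10. 16. 22. 28.]
```

The same reasoning applies per axis in 2-D, which is why both spatial axes get the `[::-1]` slice.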
from pytorch_wavelets.
Thanks for your answer. In DWTForward you flip the filters, but in DWTInverse you don't flip them.
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F


def dwt(x):
    ll = np.array([[0.5, 0.5], [0.5, 0.5]])
    lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
    hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
    filts = np.stack([ll[None,], lh[None,],
                      hl[None,], hh[None,]],
                     axis=0)
    filts = np.copy(filts)
    weight = nn.Parameter(
        torch.tensor(filts).to(torch.get_default_dtype()),
        requires_grad=False)
    C = x.shape[1]
    filters = torch.cat([weight, ] * C, dim=0)
    xs = torch.from_numpy(x).to(torch.float)
    y = F.conv2d(xs, filters, groups=C, stride=2)
    return y.numpy()


def idwt(y):
    ll = np.array([[0.5, 0.5], [0.5, 0.5]])
    lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
    hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
    filts = np.stack([ll[None,], lh[None,],
                      hl[None,], hh[None,]],
                     axis=0)
    filts = np.copy(filts)
    weight = nn.Parameter(
        torch.tensor(filts).to(torch.get_default_dtype()),
        requires_grad=False)
    C = int(y.shape[1] / 4)
    filters = torch.cat([weight, ] * C, dim=0)
    ys = torch.from_numpy(y).to(torch.float)
    b = F.conv_transpose2d(ys, filters, groups=C, stride=2)
    return b.numpy()


x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
x = np.array(x)
x = np.expand_dims(x, 0)
x = np.expand_dims(x, 0)
d1 = dwt(x)
d2 = dwt(d1)
i2 = idwt(d2)
i1 = idwt(i2)
print(i1)

Output:
[[[[ 1.  2.  3.  4.]
   [ 5.  6.  7.  8.]
   [ 9. 10. 11. 12.]
   [13. 14. 15. 16.]]]]
I changed your code to test it, building the filters from the PyWavelets Haar wavelet:
import numpy as np
import pywt

w = pywt.Wavelet('haar')
ll = np.outer(w.dec_lo, w.dec_lo)
lh = np.outer(w.dec_hi, w.dec_lo)
hl = np.outer(w.dec_lo, w.dec_hi)
hh = np.outer(w.dec_hi, w.dec_hi)
If the DWT and IDWT either both flip the filters or both don't flip them, the output is the same as the input. Can you tell me why? Thank you.
I don't flip the filters in the inverse because I use conv_transpose2d, which does true convolution. The end result is that both the forward and inverse transforms do proper convolution with non-flipped filters.
Your code example works because you've effectively swapped the analysis and synthesis filters. If you run dwt(x) and compare the output to pywt.dwt2(x, 'haar'), you will see that your wavelet coefficients are wrong.
Here's an extra bit of info that might explain it further:
Note that for the Haar wavelet, as for all orthogonal wavelets, analysis = flipped(synthesis).
Your dwt: the filters are correct, but you use correlation rather than convolution, and correlation = convolution with the flipped analysis filters = convolution with the synthesis filters.
Your idwt: you use the analysis filters with true convolution.
Result: the analysis and synthesis filters are swapped.
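Why does the round trip still reconstruct the input? With 2x2 kernels and stride 2, every output sample sees a disjoint 2x2 block, so the four Haar filters act as a 4x4 matrix on each block, and a stride-2 transposed convolution applies that matrix's transpose. The NumPy sketch below illustrates this under that assumption. `block_analysis` and `block_synthesis` are hypothetical helpers written only for this illustration; they mimic the stride-2 behaviour of `F.conv2d` and `F.conv_transpose2d` rather than calling PyTorch. Both the unflipped and the flipped filter banks turn out to be orthonormal, so using the same bank in both directions always inverts, but the coefficients differ: flipping negates the lh and hl outputs.

```python
import numpy as np

# The four 2x2 Haar filters from the thread, stacked as (4, 2, 2).
ll = np.array([[0.5, 0.5], [0.5, 0.5]])
lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
filts = np.stack([ll, lh, hl, hh])
flipped = filts[:, ::-1, ::-1]  # same flip as [None, ::-1, ::-1] in the thread


def block_analysis(x, f):
    """Hypothetical helper: stride-2 correlation with 2x2 kernels.

    Each output sample sees one disjoint 2x2 block of x, so this is a
    4x4 matrix (the flattened filters) applied to every block.
    """
    H, W = x.shape
    blocks = x.reshape(H // 2, 2, W // 2, 2).transpose(0, 2, 1, 3)  # (i, j, p, q)
    return np.einsum("ijpq,kpq->kij", blocks, f)


def block_synthesis(c, f):
    """Hypothetical helper: adjoint of block_analysis (stride-2 transposed conv)."""
    _, h, w = c.shape
    blocks = np.einsum("kij,kpq->ijpq", c, f)
    return blocks.transpose(0, 2, 1, 3).reshape(2 * h, 2 * w)


x = np.arange(1.0, 17.0).reshape(4, 4)  # the 4x4 test image from the thread

for name, bank in (("unflipped", filts), ("flipped", flipped)):
    M = bank.reshape(4, 4)
    rec = block_synthesis(block_analysis(x, bank), bank)
    # Orthonormal filter bank => the adjoint is the exact inverse.
    print(name, np.allclose(M @ M.T, np.eye(4)), np.allclose(rec, x))
# unflipped True True
# flipped True True

# The coefficients are not the same: flipping negates the lh and hl bands.
c_plain = block_analysis(x, filts)
c_flip = block_analysis(x, flipped)
print(c_plain[1:3, 0, 0], c_flip[1:3, 0, 0])  # [4. 1.] [-4. -1.]
```

Mixing the two conventions (flipped filters in one direction, unflipped in the other) would break reconstruction here, because the lh and hl signs would no longer cancel.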
Thanks for your reply.
I want to build a DWT-CNN-IDWT network, that is, a CNN that operates in the Haar wavelet domain.
Have I made any mistakes?
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWTForward(nn.Module):
    def __init__(self):
        super().__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        # Flip each filter in both spatial axes so conv2d's
        # cross-correlation becomes a true convolution
        filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                          hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        C = x.shape[1]
        # Repeat the 4 filters for every input channel (depthwise, groups=C)
        filters = torch.cat([self.weight, ] * C, dim=0)
        y = F.conv2d(x, filters, groups=C, stride=2)
        return y


class DWTInverse(nn.Module):
    def __init__(self):
        super().__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                          hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        # Each group of 4 subband channels maps back to one output channel
        C = int(x.shape[1] / 4)
        filters = torch.cat([self.weight, ] * C, dim=0)
        y = F.conv_transpose2d(x, filters, groups=C, stride=2)
        return y
That's great to hear. So you don't need to rewrite DWTForward and DWTInverse yourself. If you wanted to do something like that, you could try something like the following:
from pytorch_wavelets import DWT, IDWT
import torch
import torch.nn as nn


class Layer(nn.Module):
    def __init__(self, C, F):
        super().__init__()  # required before assigning submodules
        # Single-level Haar DWT and its inverse
        self.dwt = DWT(J=1, wave='haar')
        # One learnable 3x3 conv per subband, mapping C -> F channels
        self.ll_gain = nn.Conv2d(C, F, 3, padding=1)
        self.lh_gain = nn.Conv2d(C, F, 3, padding=1)
        self.hl_gain = nn.Conv2d(C, F, 3, padding=1)
        self.hh_gain = nn.Conv2d(C, F, 3, padding=1)
        self.idwt = IDWT(wave='haar')

    def forward(self, x):
        # yl: lowpass band; yh[0]: the three highpass bands stacked on dim 2
        yl, yh = self.dwt(x)
        yl = self.ll_gain(yl)
        lh = self.lh_gain(yh[0][:, :, 0])
        hl = self.hl_gain(yh[0][:, :, 1])
        hh = self.hh_gain(yh[0][:, :, 2])
        # Restack the filtered highpass bands in the layout IDWT expects
        yh = (torch.stack((lh, hl, hh), dim=2), )
        y = self.idwt((yl, yh))
        return y
I wrote a paper recently about learning in the wavelet space, although I used a Dual Tree Complex Wavelet transform rather than the DWT with a Haar wavelet. You can see the paper behind it and the code here.
Thanks for your reply. I use deep learning to solve low-level vision problems in the wavelet domain, and your work has given me a lot of ideas. Thank you!
@fbcotter To confirm: the DWT and IDWT modules are differentiable and support .backward(), etc.? (That seems to be true from a simple test I ran.)
Okay, oops, the docs state this already.
Is there a functional API available for the wavelet transforms?