Comments (9)

fbcotter commented on May 18, 2024

Do you mean why I take the negative strides? It's because PyTorch's conv2d does cross-correlation rather than true convolution, so to prepare for this later, I flip the filters.
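To make the flipping argument concrete, here is a small NumPy sketch (1-D for brevity, and independent of pytorch_wavelets). It shows that cross-correlation with a flipped filter equals true convolution with the original filter. np.correlate stands in for what F.conv2d computes, and the filter taps are the orthonormal Haar high-pass analysis filter, assumed to match pywt.Wavelet('haar').dec_hi.

```python
import numpy as np

# Orthonormal Haar high-pass analysis filter
# (assumed equal to pywt.Wavelet('haar').dec_hi)
s = 1 / np.sqrt(2)
dec_hi = np.array([-s, s])

x = np.array([1.0, 2.0, 3.0, 4.0])

# True convolution: y[n] = sum_k h[k] * x[n - k]
true_conv = np.convolve(x, dec_hi, mode="valid")

# Cross-correlation (the operation F.conv2d performs) with the flipped filter
corr_flipped = np.correlate(x, dec_hi[::-1], mode="valid")

# Cross-correlation with the unflipped filter: for this antisymmetric
# filter the result has the opposite sign
corr_unflipped = np.correlate(x, dec_hi, mode="valid")

print(true_conv)       # every entry equals -1/sqrt(2)
print(corr_flipped)    # same as true_conv
print(corr_unflipped)  # same magnitude, opposite sign
```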

from pytorch_wavelets.

happycaoyue commented on May 18, 2024

> Do you mean why I take the negative strides? It's because pytorch conv2d does cross-correlation rather than true convolution, so to prepare for this later, I flip the filters

Thanks for your answer. In DWTForward you flip the filters, but in DWTInverse you don't.

happycaoyue commented on May 18, 2024
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F


def dwt(x):
    # 2x2 Haar analysis filters, NOT flipped before use
    ll = np.array([[0.5, 0.5], [0.5, 0.5]])
    lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
    hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    hh = np.array([[0.5, -0.5], [-0.5, 0.5]])

    # Filter bank of shape (4, 1, 2, 2): four subbands per input channel
    filts = np.stack([ll[None,], lh[None,],
                      hl[None,], hh[None,]],
                     axis=0)
    filts = np.copy(filts)
    weight = nn.Parameter(
        torch.tensor(filts).to(torch.get_default_dtype()),
        requires_grad=False)
    C = x.shape[1]
    # Repeat the bank for every channel and apply it per channel (groups=C)
    filters = torch.cat([weight, ] * C, dim=0)
    xs = torch.from_numpy(x).to(torch.float)
    # F.conv2d computes cross-correlation; stride=2 does the downsampling
    y = F.conv2d(xs, filters, groups=C, stride=2)

    return y.numpy()

def idwt(y):
    # Same unflipped 2x2 filters as in dwt()
    ll = np.array([[0.5, 0.5], [0.5, 0.5]])
    lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
    hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
    filts = np.stack([ll[None,], lh[None,],
                      hl[None,], hh[None,]],
                     axis=0)

    filts = np.copy(filts)
    weight = nn.Parameter(
        torch.tensor(filts).to(torch.get_default_dtype()),
        requires_grad=False)

    # Four subband channels per output channel
    C = y.shape[1] // 4
    filters = torch.cat([weight, ] * C, dim=0)
    ys = torch.from_numpy(y).to(torch.float)
    # Transposed convolution with stride=2 upsamples back to full size
    b = F.conv_transpose2d(ys, filters, groups=C, stride=2)
    return b.numpy()

x = [[1,2,3,4] , [5,6,7,8] , [9,10,11,12] , [13,14,15,16]]
x = np.array(x)
x = np.expand_dims(x, 0)
x = np.expand_dims(x, 0)


d1 = dwt(x)
d2 = dwt(d1)
i2 = idwt(d2)
i1 = idwt(i2)
print(i1)

Output:
[[[[ 1.  2.  3.  4.]
   [ 5.  6.  7.  8.]
   [ 9. 10. 11. 12.]
   [13. 14. 15. 16.]]]]

I adapted your code for this test, using the Haar wavelet:

import numpy as np
import pywt

w = pywt.Wavelet('haar')
ll = np.outer(w.dec_lo, w.dec_lo)
lh = np.outer(w.dec_hi, w.dec_lo)
hl = np.outer(w.dec_lo, w.dec_hi)
hh = np.outer(w.dec_hi, w.dec_hi)
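As a quick check in plain NumPy (the Haar taps are written out instead of importing pywt, and are assumed to match pywt.Wavelet('haar').dec_lo and .dec_hi), the outer products reproduce the hard-coded 2x2 filters from the test script, and the four filters form an orthonormal set for a 2x2 block:

```python
import numpy as np

s = 1 / np.sqrt(2)
dec_lo = np.array([s, s])    # assumed pywt.Wavelet('haar').dec_lo
dec_hi = np.array([-s, s])   # assumed pywt.Wavelet('haar').dec_hi

ll = np.outer(dec_lo, dec_lo)
lh = np.outer(dec_hi, dec_lo)
hl = np.outer(dec_lo, dec_hi)
hh = np.outer(dec_hi, dec_hi)

# Hard-coded filters from the dwt()/idwt() test functions
hard_coded = {
    "ll": np.array([[0.5, 0.5], [0.5, 0.5]]),
    "lh": np.array([[-0.5, -0.5], [0.5, 0.5]]),
    "hl": np.array([[-0.5, 0.5], [-0.5, 0.5]]),
    "hh": np.array([[0.5, -0.5], [-0.5, 0.5]]),
}
from_outer = {"ll": ll, "lh": lh, "hl": hl, "hh": hh}
for name in hard_coded:
    print(name, np.allclose(hard_coded[name], from_outer[name]))

# Flatten each 2x2 filter into a row; an identity Gram matrix means the
# four filters are an orthonormal basis for a 2x2 block
basis = np.stack([ll, lh, hl, hh]).reshape(4, 4)
gram = basis @ basis.T
print(np.allclose(gram, np.eye(4)))  # True
```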

If DWT and IDWT either both flip the filters or both leave them unflipped, the output is the same as the input. Can you tell me why? Thank you.

fbcotter commented on May 18, 2024

I don't flip the filters in the inverse because I use conv_transpose2d, which does true convolution. The end result being that both the forward and inverse do proper convolution with non-flipped filters.
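The point about conv_transpose2d can be illustrated in 1-D with NumPy (a sketch, not the library's code; the helper names are hypothetical). Without padding, a stride-s transposed convolution can be written as a scatter-add, out[s*i + k] += y[i] * w[k]. The sketch checks that this equals inserting s - 1 zeros between the coefficients and then applying a true (full) convolution with the unflipped filter.

```python
import numpy as np

def conv_transpose1d(y, w, stride):
    """Scatter-add transposed convolution (no padding, single channel):
    out[stride*i + k] += y[i] * w[k]."""
    out = np.zeros(stride * (len(y) - 1) + len(w))
    for i, yi in enumerate(y):
        out[stride * i: stride * i + len(w)] += yi * w
    return out

def upsample_then_convolve(y, w, stride):
    """Insert (stride - 1) zeros between samples, then apply a true
    (full) convolution with the unflipped filter."""
    up = np.zeros(stride * (len(y) - 1) + 1)
    up[::stride] = y
    return np.convolve(up, w, mode="full")

y = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 2.0, 3.0])  # asymmetric, so a flip would be visible

a = conv_transpose1d(y, w, stride=2)
b = upsample_then_convolve(y, w, stride=2)
print(a)                  # [1. 2. 5. 4. 9. 6. 9.]
print(np.allclose(a, b))  # True
```

Using the flipped filter w[::-1] in upsample_then_convolve gives a different result, which is why the inverse does not need the extra flip.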

Your code example works because you've effectively swapped the analysis and synthesis filters. If you run dwt(x) and compare the output to pywt.dwt2(x, 'haar') you will see that your wavelet coefficients are wrong.

Here's an extra bit of info which might explain it a bit more:
Note that for the Haar, like all orthogonal wavelets, analysis = flipped(synthesis).

Your dwt: the filters are correct, but you use correlation rather than convolution, i.e. convolution with the flipped analysis filters, i.e. convolution with the synthesis filters.
Your idwt: you use the analysis filters with true convolution.
Result: the analysis and synthesis filters are swapped.
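This can be checked numerically for the one-level Haar case. The sketch below is a NumPy stand-in, not pytorch_wavelets code, and the helpers analysis and synthesis are hypothetical. Because the kernel size equals the stride (2), F.conv2d(..., stride=2) reduces to one inner product per non-overlapping 2x2 block, and F.conv_transpose2d(..., stride=2) reduces to writing each coefficient times its filter back into its block. Both the flipped and the unflipped filter sets are orthonormal and the transposed step is the adjoint of the forward step, so reconstruction is exact either way. Only correlation with the flipped set equals true convolution with the analysis filters; for the Haar, the unflipped set changes the sign of the LH and HL subbands.

```python
import numpy as np

s = 1 / np.sqrt(2)
dec_lo, dec_hi = np.array([s, s]), np.array([-s, s])
# Unflipped analysis filters, shape (4, 2, 2): ll, lh, hl, hh
filts = np.stack([np.outer(a, b) for a, b in
                  [(dec_lo, dec_lo), (dec_hi, dec_lo),
                   (dec_lo, dec_hi), (dec_hi, dec_hi)]])
flipped = filts[:, ::-1, ::-1]

def analysis(x, f):
    """Stride-2 cross-correlation with 2x2 filters (what F.conv2d does
    here): one inner product per non-overlapping 2x2 block."""
    H, W = x.shape
    blocks = x.reshape(H // 2, 2, W // 2, 2).transpose(0, 2, 1, 3)
    return np.einsum('ijab,kab->kij', blocks, f)

def synthesis(y, f):
    """Adjoint of analysis (what F.conv_transpose2d does here): each
    coefficient times its filter is written back into its 2x2 block."""
    blocks = np.einsum('kij,kab->ijab', y, f)
    h2, w2 = blocks.shape[:2]
    return blocks.transpose(0, 2, 1, 3).reshape(2 * h2, 2 * w2)

x = np.arange(1.0, 17.0).reshape(4, 4)

# Perfect reconstruction either way, as long as both steps use the same set
print(np.allclose(synthesis(analysis(x, filts), filts), x))      # True
print(np.allclose(synthesis(analysis(x, flipped), flipped), x))  # True

# Correlating with the flipped filters = true convolution with the
# analysis filters (sampled every other position); the unflipped version
# differs only in the sign of the LH and HL subbands
c_true = analysis(x, flipped)
c_corr = analysis(x, filts)
signs = np.array([1, -1, -1, 1])[:, None, None]
print(np.allclose(c_corr * signs, c_true))  # True
```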

happycaoyue commented on May 18, 2024

Thanks for your reply
I want to build a DWT-CNN-IDWT network, that is, a CNN that operates in the Haar wavelet domain.
Have I made any mistakes?

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWTForward(nn.Module):

    def __init__(self):
        super().__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        # Flip the analysis filters so that F.conv2d (a cross-correlation)
        # computes a true convolution; np.stack copies the flipped views
        filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1],
                          hl[None,::-1,::-1], hh[None,::-1,::-1]],
                         axis=0)
        # Fixed (non-learnable) filter bank of shape (4, 1, 2, 2)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        # x: (N, C, H, W) -> (N, 4*C, H/2, W/2), four subbands per channel
        C = x.shape[1]
        filters = torch.cat([self.weight,] * C, dim=0)

        y = F.conv2d(x, filters, groups=C, stride=2)

        return y


class DWTInverse(nn.Module):
    def __init__(self):
        super().__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        # Flipped analysis filters, which for the Haar equal the
        # synthesis filters
        filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                          hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        # x: (N, 4*C, H/2, W/2) -> (N, C, H, W)
        C = x.shape[1] // 4
        filters = torch.cat([self.weight, ] * C, dim=0)
        y = F.conv_transpose2d(x, filters, groups=C, stride=2)
        return y
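Both classes above flip the analysis filters. Following fbcotter's note that for an orthogonal wavelet analysis = flipped(synthesis), a small NumPy check shows that flipping each 2x2 Haar analysis filter gives exactly the corresponding 2x2 synthesis filter, so the DWTInverse above ends up passing synthesis filters to conv_transpose2d. The taps are written out by hand and are assumed to match pywt.Wavelet('haar')'s dec_lo/dec_hi and rec_lo/rec_hi values.

```python
import numpy as np

s = 1 / np.sqrt(2)
dec = (np.array([s, s]), np.array([-s, s]))  # (dec_lo, dec_hi), assumed pywt values
rec = (np.array([s, s]), np.array([s, -s]))  # (rec_lo, rec_hi), assumed pywt values

# (filter along axis 0, filter along axis 1) for each subband; 0 = low, 1 = high
subbands = {"ll": (0, 0), "lh": (1, 0), "hl": (0, 1), "hh": (1, 1)}

matches = {}
for name, (r, c) in subbands.items():
    analysis_2d = np.outer(dec[r], dec[c])
    synthesis_2d = np.outer(rec[r], rec[c])
    # Flipping both axes of the 2x2 analysis filter should give the
    # 2x2 synthesis filter for an orthogonal wavelet
    matches[name] = np.allclose(analysis_2d[::-1, ::-1], synthesis_2d)

print(matches)  # {'ll': True, 'lh': True, 'hl': True, 'hh': True}
```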

fbcotter commented on May 18, 2024

That's great to hear. So you don't need to rewrite the DWTForward and DWTInverse functions. If you want to do something like that, you could try something like the following:

import torch
import torch.nn as nn
from pytorch_wavelets import DWT, IDWT

class Layer(nn.Module):
    def __init__(self, C, F):
        super().__init__()
        # One-level Haar DWT and its inverse
        self.dwt = DWT(J=1, wave='haar')
        # A separate 3x3 convolution (C -> F channels) for each subband
        self.ll_gain = nn.Conv2d(C, F, 3, padding=1)
        self.lh_gain = nn.Conv2d(C, F, 3, padding=1)
        self.hl_gain = nn.Conv2d(C, F, 3, padding=1)
        self.hh_gain = nn.Conv2d(C, F, 3, padding=1)
        self.idwt = IDWT(wave='haar')

    def forward(self, x):
        # yl: lowpass; yh[0]: the three highpass subbands stacked on dim 2
        yl, yh = self.dwt(x)
        yl = self.ll_gain(yl)
        lh = self.lh_gain(yh[0][:,:,0])
        hl = self.hl_gain(yh[0][:,:,1])
        hh = self.hh_gain(yh[0][:,:,2])
        # Restack the processed highpasses in the layout IDWT expects
        yh = (torch.stack((lh, hl, hh), dim=2), )
        y = self.idwt((yl, yh))
        return y

I wrote a paper recently about learning in the wavelet space, although I used a Dual Tree Complex Wavelet transform rather than the DWT with a Haar wavelet. You can see the paper behind it and the code here.

happycaoyue commented on May 18, 2024

Thanks for your reply. I use deep learning to solve low-level vision problems in the wavelet domain. Your work has given me a lot of ideas. Thank you!

varun19299 commented on May 18, 2024

@fbcotter, to confirm: the DWT and IDWT modules are differentiable and support .backward() etc.? (It seems to be true from a simple test I ran.)

varun19299 commented on May 18, 2024

Okay, oops, the docs state this already.

Is there a functional API available for the wavelet transforms?
