
pytorch_wavelets's Introduction

2D Wavelet Transforms in Pytorch


The full documentation is also available here.

This package provides support for computing the 2D discrete wavelet transform and the 2D dual-tree complex wavelet transform, their inverses, and passing gradients through both using PyTorch.

The implementation is designed to be used with batches of multichannel images. We use PyTorch's standard 'NCHW' data format.
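If your images are stored channels-last, a plain permute converts them to NCHW before they are passed to the transforms. A minimal sketch in plain PyTorch (the tensor sizes are arbitrary):

```python
import torch

# A batch of channels-last images: N, H, W, C
x_nhwc = torch.randn(8, 64, 64, 3)

# Reorder the axes to N, C, H, W; contiguous() makes a compact copy in the new layout
x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()
print(x_nchw.shape)  # torch.Size([8, 3, 64, 64])

# A single H, W, C image also needs a leading batch axis
img = torch.randn(64, 64, 3)
img_nchw = img.permute(2, 0, 1).unsqueeze(0)
print(img_nchw.shape)  # torch.Size([1, 3, 64, 64])
```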

We have also added layers to compute the 2-D DTCWT-based scatternet. This is similar to the Morlet-based scatternet in Kymatio, but roughly 10 times faster.

If you use this repo, please cite my PhD thesis, chapter 3: https://doi.org/10.17863/CAM.53748.

New in version 1.3.0

  • Added 1D DWT support
import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse  # or simply DWT1D, IDWT1D
dwt = DWT1DForward(wave='db6', J=3)
X = torch.randn(10, 5, 100)
yl, yh = dwt(X)
print(yl.shape)
>>> torch.Size([10, 5, 22])
print(yh[0].shape)
>>> torch.Size([10, 5, 55])
print(yh[1].shape)
>>> torch.Size([10, 5, 33])
print(yh[2].shape)
>>> torch.Size([10, 5, 22])
idwt = DWT1DInverse(wave='db6')
x = idwt((yl, yh))
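The printed lengths follow from the filter length. Assuming pytorch_wavelets uses the same output-length rule as PyWavelets' pywt.dwt_coeff_len (the shapes printed in this README are consistent with it), a small hypothetical helper predicts them. db6 has 12 taps and db3 has 6:

```python
def dwt_lengths(n, filt_len, J, mode="zero"):
    """Expected coefficient length at each of J levels (finest first).

    Assumes the PyWavelets rule: floor((n + filt_len - 1) / 2) for padded
    modes such as 'zero', and ceil(n / 2) for 'periodization'.
    """
    lengths = []
    for _ in range(J):
        if mode == "periodization":
            n = (n + 1) // 2
        else:
            n = (n + filt_len - 1) // 2
        lengths.append(n)
    return lengths

# The 1D example: length-100 signals, db6 (12 taps), J=3
print(dwt_lengths(100, 12, 3))  # [55, 33, 22]
# The 2D example further down: 64x64 images, db3 (6 taps), J=3
print(dwt_lengths(64, 6, 3))    # [34, 19, 12]
```

The last entry is also the length of the lowpass output yl.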

New in version 1.2.0

  • Added a DTCWT based ScatterNet
import torch
from pytorch_wavelets import ScatLayer
scat = ScatLayer()
X = torch.randn(10,5,64,64)
# A first order scatternet with 6 orientations and one lowpass channel
# gives 7 times the input channel dimension
Z = scat(X)
print(Z.shape)
>>> torch.Size([10, 35, 32, 32])
# A second order scatternet with 6 orientations and one lowpass channel
# gives 7^2 times the input channel dimension
scat2 = torch.nn.Sequential(ScatLayer(), ScatLayer())
Z = scat2(X)
print(Z.shape)
>>> torch.Size([10, 245, 16, 16])
# We also have a slightly more specialized, but slower, second order scatternet
from pytorch_wavelets import ScatLayerj2
scat2a = ScatLayerj2()
Z = scat2a(X)
print(Z.shape)
>>> torch.Size([10, 245, 16, 16])
# These all of course work with cuda
scat2a.cuda()
Z = scat2a(X.cuda())
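The shape bookkeeping in these comments can be written down directly. This hypothetical helper (not part of pytorch_wavelets) assumes each ScatLayer multiplies the channel count by 7 and halves the height and width, which matches the shapes printed above; it does not describe ScatLayerj2's internals:

```python
def scat_output_shape(shape, order=1, n_orient=6):
    """Predict the NCHW output shape of `order` stacked first-order scattering layers.

    Assumes each layer keeps one lowpass channel plus n_orient bandpass
    channels per input channel, and halves the height and width.
    """
    N, C, H, W = shape
    growth = (n_orient + 1) ** order
    return (N, C * growth, H // 2 ** order, W // 2 ** order)

print(scat_output_shape((10, 5, 64, 64), order=1))  # (10, 35, 32, 32)
print(scat_output_shape((10, 5, 64, 64), order=2))  # (10, 245, 16, 16)
```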

New in version 1.1.0

  • Fixed memory problem with dwt
  • Fixed the backend code for the dtcwt calculation - much cleaner now but similar performance
  • Both dtcwt and dwt should be more memory efficient/aware now.
  • Removed need to specify number of scales for DTCWTInverse

New in version 1.0.0

Version 1.0.0 has now added support for separable DWT calculation, and more padding schemes, such as symmetric, zero and periodization.

Also, you no longer need to specify the number of channels when creating the wavelet transform classes.

Speed Tests

We compare computing the DTCWT with the dtcwt Python package and the DWT with PyWavelets against computing both in pytorch_wavelets on a GTX1080. The NumPy methods were run on a 14-core Xeon Phi machine using Intel's parallel Python. For the DTCWT we use the near_sym_a filters for the first scale and the qshift_a filters for subsequent scales. For the DWT we use the db4 filters.

For a fixed input size, but varying the number of scales (from 1 to 4) we have the following speeds (averaged over 5 runs):

For an input size with height and width 512 by 512, we also vary the batch size for a 3 scale transform. The resulting speeds were:

Installation

The easiest way to install pytorch_wavelets is to clone the repo and pip install it. Later versions will be released on PyPI, but the docs need to be updated first:

$ git clone https://github.com/fbcotter/pytorch_wavelets
$ cd pytorch_wavelets
$ pip install .

If you intend to make significant modifications to the library, a develop (editable) install may be more useful. A test suite is provided so that you can verify the code works on your system:

$ pip install -r tests/requirements.txt
$ pytest tests/

Example Use

For the DWT, note that the highpass output has an extra dimension, in which we stack the (lh, hl, hh) coefficients. Also note that the Yh output has the finest detail coefficients first and the coarsest last (the opposite of PyWavelets).

import torch
from pytorch_wavelets import DWTForward, DWTInverse
xfm = DWTForward(J=3, wave='db3', mode='zero')
X = torch.randn(10,5,64,64)
Yl, Yh = xfm(X) 
print(Yl.shape)
>>> torch.Size([10, 5, 12, 12])
print(Yh[0].shape) 
>>> torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)
>>> torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)
>>> torch.Size([10, 5, 3, 12, 12])
ifm = DWTInverse(wave='db3', mode='zero')
Y = ifm((Yl, Yh))
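Because the three orientations share the extra dimension at index 2, they can be separated with unbind, or folded into the channel axis for an ordinary convolution. A sketch using a random stand-in with the shape of Yh[0] above, so it runs without the library:

```python
import torch

# Stand-in for Yh[0] from DWTForward: (N, C, 3, H, W), orientations on dim 2
yh0 = torch.randn(10, 5, 3, 34, 34)

# Split into the (lh, hl, hh) subbands, each of shape (N, C, H, W)
lh, hl, hh = yh0.unbind(dim=2)
print(lh.shape)  # torch.Size([10, 5, 34, 34])

# Or merge channels and orientations into one axis: (N, C*3, H, W)
yh0_flat = yh0.reshape(10, 5 * 3, 34, 34)
print(yh0_flat.shape)  # torch.Size([10, 15, 34, 34])
```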

For the DTCWT:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
X = torch.randn(10,5,64,64)
Yl, Yh = xfm(X) 
print(Yl.shape)
>>> torch.Size([10, 5, 16, 16])
print(Yh[0].shape) 
>>> torch.Size([10, 5, 6, 32, 32, 2])
print(Yh[1].shape)
>>> torch.Size([10, 5, 6, 16, 16, 2])
print(Yh[2].shape)
>>> torch.Size([10, 5, 6, 8, 8, 2])
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
Y = ifm((Yl, Yh))

Some initial notes:

  • Yh is returned as a tuple. Each element has 2 extra dimensions. The first comes between the channel dimension of the input and the row dimension, and holds the 6 orientations of the DTCWT. The second is the final dimension, which holds the real and imaginary parts (complex numbers were not natively supported in PyTorch when this package was written).
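On PyTorch versions that provide complex tensors, the trailing real/imaginary dimension can be reinterpreted as a complex tensor without copying. A sketch with a random stand-in shaped like Yh[0] from the example above, so it runs without the library:

```python
import torch

# Stand-in for Yh[0] from DTCWTForward: (N, C, 6, H, W, 2), real/imag on the last dim
yh0 = torch.randn(10, 5, 6, 32, 32, 2)

# View the last dimension as one complex number (no copy is made)
z = torch.view_as_complex(yh0)
print(z.shape, z.dtype)  # torch.Size([10, 5, 6, 32, 32]) torch.complex64

# Complex operations are then available, e.g. the phase of each coefficient
phase = z.angle()

# The original real/imag layout can be recovered
back = torch.view_as_real(z)
```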

Running on the GPU

This should come as no surprise to PyTorch users: the DWT and DTCWT transforms can be moved to the GPU with .cuda(), like any other module:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(10,5,64,64).cuda()
Yl, Yh = xfm(X) 
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
Y = ifm((Yl, Yh))

The automated tests cannot test the GPU functionality, but they do check CPU execution. To test whether the repo works on your GPU, download the repo, make sure you have PyTorch with CUDA enabled (the tests check whether torch.cuda.is_available() returns True), and run the following from the base of the repo:

$ pip install -r tests/requirements.txt
$ pytest tests/

Backpropagation

It is possible to pass gradients through both the forward and the inverse transforms. All you need to do is ensure that the input to each has its requires_grad attribute set to True.
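To illustrate gradients flowing through a wavelet transform without depending on the library, here is a self-contained sketch of a one-level orthonormal 2D Haar DWT built from a strided grouped convolution. It is an illustration, not how pytorch_wavelets implements its transforms. Because the transform is orthonormal, the total energy of the coefficients equals that of the input, so the gradient of that energy with respect to x is exactly 2x:

```python
import torch
import torch.nn.functional as F

def haar_dwt2(x):
    """One-level orthonormal 2D Haar DWT of an NCHW tensor with even H and W.

    Returns the lowpass band (N, C, H/2, W/2) and the three detail bands
    stacked on dim 2: (N, C, 3, H/2, W/2).
    """
    N, C, H, W = x.shape
    filts = torch.tensor([
        [[1., 1.], [1., 1.]],     # lowpass (block average)
        [[1., 1.], [-1., -1.]],   # difference between rows
        [[1., -1.], [1., -1.]],   # difference between columns
        [[1., -1.], [-1., 1.]],   # diagonal difference
    ]) / 2
    # One copy of the 4 filters per input channel: weight shape (4C, 1, 2, 2)
    weight = filts.unsqueeze(1).repeat(C, 1, 1, 1).to(x)
    y = F.conv2d(x, weight, stride=2, groups=C)  # (N, 4C, H/2, W/2)
    y = y.view(N, C, 4, H // 2, W // 2)
    return y[:, :, 0], y[:, :, 1:]

x = torch.randn(2, 3, 8, 8, requires_grad=True)
yl, yh = haar_dwt2(x)
energy = yl.pow(2).sum() + yh.pow(2).sum()
energy.backward()
print(torch.allclose(x.grad, 2 * x.detach(), atol=1e-5))  # True
```

With pytorch_wavelets the pattern is the same: run the transform on an input that requires gradients and call backward() on a loss computed from the outputs.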

Provenance

Based on the Dual-Tree Complex Wavelet Transform Pack for MATLAB by Nick Kingsbury, Cambridge University. The original README can be found in ORIGINAL_README.txt. This file outlines the conditions of use of the original MATLAB toolbox.

Further information on the DT CWT can be obtained from papers downloadable from my website (given below). The best tutorial is in the 1999 Royal Society Paper. In particular this explains the conversion between 'real' quad-number subimages and pairs of complex subimages. The Q-shift filters are explained in the ICIP 2000 paper and in more detail in the May 2001 paper for the Journal on Applied and Computational Harmonic Analysis.

This code is copyright and is supplied free of charge for research purposes only. In return for supplying the code, all I ask is that, if you use the algorithms, you give due reference to this work in any papers that you write and that you let me know if you find any good applications for the DT CWT. If the applications are good, I would be very interested in collaboration. I accept no liability arising from use of these algorithms.

Nick Kingsbury, Cambridge University, June 2003.

Dr N G Kingsbury, Dept. of Engineering, University of Cambridge, Trumpington St., Cambridge CB2 1PZ, UK., or Trinity College, Cambridge CB2 1TQ, UK. Phone: (0 or +44) 1223 338514 / 332647; Home: 1954 211152; Fax: 1223 338564 / 332662; E-mail: [email protected] Web home page: http://www.eng.cam.ac.uk/~ngk/

pytorch_wavelets's People

Contributors

cfinlay, fbcotter, mclaughlin6464, mike9251, voletiv


pytorch_wavelets's Issues

ValueError: step must be greater than zero

Hi, great work! Recently I encountered the error in the issue title when using DTCWTInverse. The error is raised at pytorch_wavelets/dtcwt/transform2d.py, line 224.

I found that PyTorch tensors do not support slicing with step=-1. I wonder why nobody has reported this yet. Did it come with a new version of PyTorch?
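For reference, PyTorch raises exactly this error for negative slice steps, and torch.flip is the usual replacement. A minimal reproduction, independent of pytorch_wavelets (whether line 224 of transform2d.py should use torch.flip is an assumption, not something confirmed here):

```python
import torch

x = torch.arange(5)

# NumPy-style reversal with a negative step is rejected by PyTorch
try:
    x[::-1]
    negative_step_ok = True
except ValueError as err:
    negative_step_ok = False
    print(err)  # step must be greater than zero

# torch.flip reverses along the given dimensions instead (it returns a copy)
rev = torch.flip(x, dims=[0])
print(rev)  # tensor([4, 3, 2, 1, 0])
```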

How to change default tensor dtype in pytorch_wavelets

I want to use the torch.float16 data type with the DWT, but I get an error:

RuntimeError: expected scalar type Half but found Float

I think the reason is that the DWT's filters are torch.float32 while the input is torch.float16. Is there any way to change the DWT's default torch.float32 type to torch.float16?
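In PyTorch, calling .half() or .double() on a module casts its floating-point parameters and registered buffers. Assuming the transform keeps its filters as registered buffers (plausible for nn.Module-based transforms, but not verified here), casting the module to match the input is the first thing to try. The sketch below uses a hypothetical FixedFilter module and float64, so it also runs on CPU, where float16 convolutions may not be supported:

```python
import torch
import torch.nn.functional as F

class FixedFilter(torch.nn.Module):
    """Hypothetical module that stores a fixed filter as a registered buffer."""

    def __init__(self):
        super().__init__()
        # Created as float32, like most default filter banks
        self.register_buffer("h", torch.tensor([[[[0.5, 0.5]]]]))

    def forward(self, x):
        return F.conv2d(x, self.h)

x = torch.randn(1, 1, 4, 4, dtype=torch.float64)
m = FixedFilter().double()  # casts the buffer to float64 to match the input
y = m(x)
print(m.h.dtype, y.dtype)  # torch.float64 torch.float64
```

On a GPU, .half() applies the same idea to float16 inputs.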

How to implement DTCWT to 1-D signals?

Thanks for such a convenient framework! Is there any way to apply the DTCWT to 1-D signals? Also, does the DTCWT have any advantage over the DWT in signal processing?

no module

wave = pywt.Wavelet(wave)
AttributeError: module 'pywt' has no attribute 'Wavelet'

Hints on avoiding the scaled output

Hi,

First of all, congrats on the nice work.

I wanted to ask how I could prevent the height and width of the ScatLayer output from being halved at every scale.

Best

The order of the input parameters of the AFB2D function

Thanks for your great work.
I have a question about the order of the input parameters of the AFB2D function. In transform2d.py, lines 70-71, the function is called as

ll, high = lowlevel.AFB2D.apply(ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode)

However, AFB2D's forward in lowlevel.py (line 336) takes its inputs in the order h0_row, h1_row, h0_col, h1_col:

def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):

So the argument order at the call site ("col col row row") seems inconsistent with the function's definition ("row row col col"). Why is it passed this way?

Stationary wavelet transform

Hello,

First of all, thank you for the nice implementation of wavelet transform in PyTorch.

I tried to run your implementation of 2d stationary wavelet transform (or undecimated wavelet transform) but it throws an error. Is there any way to fix it?

Thank you!


Up/Downsampling Questions

Hello,

I don't have much of a signal processing background, so apologies if this is uninformed. I'm trying to replicate the recent Adaptive Discriminator Augmentation code from the StyleGAN team in PyTorch, and it seems they use the sym6 wavelet to upsample an image before applying a projective matrix to it, and then downsample using the same wavelet. They use TensorFlow's tf.nn.depthwise_conv2d_backprop_input function to achieve this. I can't find much literature on up- and downsampling with wavelets, but the most I could gather was that the DWT could be used to accomplish this.

I tried using the DWT in this library to upsample by first initializing a matrix of zeros with the upsample dimensions and filling in the upper left quadrant with the original image and then applying the DWT (according to the instructions from one paper I came across by Microsoft), but I'm not sure how to use the output to reconstruct the upsampled image.

The input image dimensions are [1, 3, 128, 128], and I want the upsample dimensions to be [1, 3, 256, 256]. After applying the DWT with J=3 and wave='sym6', I get yl with dimensions [1, 3, 41, 41], and yh with dimensions [1, 3, 3, 133, 133], [1, 3, 3, 72, 72], [1, 3, 3, 41, 41]. Is there a way to assemble that output into an upsampled image or am I doing something fundamentally wrong here? And how might I downsample the resulting upsampled image after applying a transformation to it? Would I just pass (yl2, yh) to the inverse transform?

Any help here is much appreciated. Thanks.

A problem about multiple GPUs training.

Hi fbcotter, I'm very interested in your dissertation and this GitHub repo.

I ran into a problem when I called DTCWTForward().

Can DTCWTForward() support training on multiple GPUs? I want to import pytorch_wavelets and then train a CNN on 8 GPUs.

Thx!!!

Padding for powers of 2

I have noticed that the DWT always applies padding, even when the input dimensions are exact powers of 2.
Also, when the input dimensions are odd, the inverse returns the wrong dimensions, as shown in #22 (though this seems fixable by simply cropping the result).
I think it would be nice if all of the padding and cropping were handled internally.
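The cropping workaround can be wrapped in a small helper. This is a sketch under the reporter's assumption that the extra rows and columns sit at the end of each spatial axis; crop_to is a hypothetical name, not a pytorch_wavelets function:

```python
import torch

def crop_to(y, ref):
    """Crop the last two (spatial) dims of y to the spatial size of ref.

    Assumes any extra rows/columns were appended at the bottom/right.
    """
    H, W = ref.shape[-2:]
    return y[..., :H, :W]

x = torch.randn(1, 3, 3, 3)   # odd-sized input
y = torch.randn(1, 3, 4, 4)   # stand-in for an inverse output that came back larger
print(crop_to(y, x).shape)    # torch.Size([1, 3, 3, 3])
```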

Installation error

Hi!
I can't install this package correctly.
When I execute "pip install .", I always get error messages "ERROR: Directory '.' is not installable. Neither 'setup.py' nor 'pyproject.toml' found.". Do you know the reason? Thanks!

Extension to 3D Wavelet Transform?

Hi, is there any future work planned toward allowing 3D DWT transforms? E.g. on an array of shape (batch, channels, height, width, depth).

I'd be very interested in working on this myself if you know of any good theoretical literature etc. Cheers!

Export DWTForward to onnx

HI, nice work!

I tried to export a PyTorch model that contains a DWTForward module and, as the log shows, an error occurs. After digging into the source code, I found that AFB2D is a Python function composed of a set of PyTorch ops, which apparently cannot be exported to ONNX.
Do you have any idea how to export AFB2D to an ONNX model? Thank you!

> RuntimeError: ONNX export failed: Couldn't export Python operator AFB2D

Is SWT available now?

Hi, @fbcotter, thanks for your wonderful work. I am trying to use the SWT, and I found it in transform2d.py. Is it available now?

Reconstructed wavelet

Hello, how can I display the 6 individual images reconstructed from the 6 subbands after n levels of decomposition?

about image size

Forgive me for being a newbie. If my input image size is odd, how can I get the same size back after the DWT and IDWT? Thank you very much for any answer.

Compatible with torch AMP?

It seems this module does not support mixed-precision training.

Do you plan on adding this feature?

Thank you,

Alternating the filters

Y = torch.stack((X[:, ch:], X[:, :ch]), dim=-1).view(batch, ch, r, c2)

Hi Fergal,

first of all: Amazing implementation you have done here.

Regarding the referenced line, shouldn't the highpass outputs remain the same while the lowpass outputs are interchanged, so that the other filters apply on the next iteration? Maybe I am missing something here.

Also, is the start from index 2 and 3 in this line

X = torch.cat((X[:,:,:,xe[2::2]], X[:,:,:,xe[3::2]]), dim=1)

due to the fact that

The delay differences must not be swapped, even when the filters are swapped, so an extra delay of one sample must be included as required ...

given in Selesnick's reference paper?

thank you in advance

precision problem of DWTForward&DWTInverse

I find an obvious difference between the input to pytorch_wavelets.DWTForward and the output of pytorch_wavelets.DWTInverse:

import torch
import pytorch_wavelets

DWTForward = pytorch_wavelets.DWTForward()
DWTInverse = pytorch_wavelets.DWTInverse()

x = torch.abs(torch.randn(17, 7, 32, 32))
yl, yh = DWTForward(x)
y = DWTInverse((yl, yh))
print(torch.sum(torch.abs(y - x)))

and the difference is about 0.0080
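To put the reported figure in perspective: it is a sum over every element of x, so dividing by the element count gives the average per-element error. A back-of-the-envelope check, assuming the 0.0080 total quoted above and float32 tensors (the default for torch.randn):

```python
n_elems = 17 * 7 * 32 * 32        # elements in x
mean_abs_err = 0.0080 / n_elems   # average absolute error per element
eps32 = 2.0 ** -23                # float32 machine epsilon, about 1.19e-07

print(n_elems)                    # 121856
print(f"{mean_abs_err:.2e}")      # 6.57e-08
print(mean_abs_err < eps32)       # True
```

An average error below float32 machine epsilon, on values of order 1, is consistent with ordinary rounding in the forward and inverse convolutions rather than with a reconstruction bug.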

Separable option gone from DWT/IDWT?

Hello,

Thanks a lot for this library, I was wondering if you had removed the separable flag from DWT/IDWT?

separable (bool): whether to do the filtering separably or not (the

The constructor does not seem to accept the separable flag:
def __init__(self, J=1, wave='db1', mode='zero'):

You mentioned it could sometimes be faster to not use the separable implementation so I was curious to try.
Thanks in advance!

Not preserving spatial dimensions

Good day!
Thanks a lot for your efforts with this lib!
Recently I encountered a problem with preserving the spatial dimensions of a tensor.

import torch
from pytorch_wavelets import DWTForward, DWTInverse

j = 3
wave = 'db1'
mode = 'symmetric'
layer0 = DWTForward(J=j, wave=wave, mode=mode)
layer1 = DWTInverse(wave=wave, mode=mode)
test_input = torch.arange(27).reshape(1, 3, 3, 3).to(torch.float32)
low, high = layer0(test_input)
test_output = layer1((low, high))
print(test_input.shape, test_output.shape)

Expected to get (1, 3, 3, 3) but got (1, 3, 4, 4)

Installation not working

Hello:-)

I installed the package using pip install . in an Anaconda Environment, but when I try to import pytorch_wavelets in a jupyter notebook running with the same environment I get an error.
ModuleNotFoundError Traceback (most recent call last)
in
17 import tqdm
18 #from PyTorchWavelets import *
---> 19 import pytorch_wavelets
20 from pytorch_wavelets import DWTForward
21

ModuleNotFoundError: No module named 'pytorch_wavelets'

Help would be appreciated:-)

Thank you,
Maria

Support for DoubleFloat

Thanks for such a handy toolbox! It would be great if the package supported double-precision data (not much practical use, but it can avoid some errors).

Also, I had a very(!) ugly implementation of the 'array_to_coeffs' here
Hope this would help.

I have a question to ask you

in pytorch_wavelets.dwt.transform2d.py you use filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1], hl[None,::-1,::-1], hh[None,::-1,::-1]], axis=0)

I don't know why ll, lh, hl and hh are reversed here. Why is that?
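A likely explanation (an assumption about the author's intent, not confirmed in the source): PyTorch's conv1d/conv2d compute cross-correlation, that is, they slide the kernel without flipping it. Reversing a filter with [::-1] before handing it to conv2d therefore turns the operation into a true convolution. A small 1D check in PyTorch:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1., 2., 3., 4., 5.]).view(1, 1, -1)  # (N, C, L)
h = torch.tensor([1., 0., -1.])                        # asymmetric filter

# conv1d applies the kernel as-is: out[i] = sum_k x[i + k] * h[k]
xcorr = F.conv1d(x, h.view(1, 1, -1)).flatten()

# Flipping the kernel first gives the valid part of true convolution,
# out[n] = sum_k h[k] * x[n - k]
conv = F.conv1d(x, h.flip(0).view(1, 1, -1)).flatten()

print(xcorr.tolist())  # [-2.0, -2.0, -2.0]
print(conv.tolist())   # [2.0, 2.0, 2.0]
```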

something wrong when we use encoder to process the high_frequency component

Thanks for sharing your great work!

After decomposing the image (feature map) using the wavelet transform method you provided. We input the low-frequency and high-frequency components into an encoder network separately. After getting the encoded information, we found more detailed information in the high frequencies. Then, we have the following problem after re-inverting the encoded processed low and high-frequency components back to the time domain. A very distinct grid appears in the image (feature map). We have repeatedly experimented and found that coding operations on the low-frequency components do not cause this phenomenon. At the same time, this issue occurs only when coding operations are performed on high frequencies.

The processing of our work:

Could you please help us analyze the reason for this "grid effect"? Is it because the encoder changes the value of the high-frequency information so that the phase information in the high-frequency information changes? Does it end up causing the grid effect when inverted? If so, is there any good way to solve this problem?

Memory Overuse in DWT

Currently in the calculation of the DWT forward pass for an input that requires gradient, too much memory is being used. I believe it's because pytorch is saving all the intermediate activations.

How to make the input and output consistent?

I found that if the input size is even, the output size matches the input; if it is odd, it does not.

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(2, 2048, 26, 32)
Yl, Yh = xfm(X.cuda())
ifm = DTCWTInverse().cuda()
Y = ifm((Yl, Yh))
print(Y.shape)

I got:
torch.Size([2, 2048, 26, 32])

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(2, 2048, 25, 31)
Yl, Yh = xfm(X.cuda())
ifm = DTCWTInverse().cuda()
Y = ifm((Yl, Yh))
print(Y.shape)

I got:
torch.Size([2, 2048, 26, 32])

DWTInverse is not the inverse of DWTForward for small image sizes.

For image sizes smaller than the wavelet support, DWTInverse is not the inverse operation of DWTForward (probably due to border effects).

How to reproduce

With pytorch_wavelets 1.3.0 and python 3.8.11:

import torch
from pytorch_wavelets import DWTForward, DWTInverse

def wavelet_check(L):
    n = lambda x: torch.norm(torch.flatten(x))
    rel_err = lambda true, other: (n(true - other) / n(true)).item()

    forward = DWTForward(wave="db4", mode="periodization")
    inverse = DWTInverse(wave="db4", mode="periodization")

    N = 10
    x = torch.rand((N, 4, L, L))
    y = inverse((x[:, :1], [x[:, None, 1:4]]))
    l, h = forward(y)
    x_rec = torch.concat((l, h[0][:, 0, ...]), dim=1)
    print(f"L={L}, relative error {rel_err(x, x_rec):.2}")

for j in range(5):
    wavelet_check(2 ** j)

produces the following output:

L=1, relative error 1.0
L=2, relative error 0.88
L=4, relative error 1.2e-07
L=8, relative error 1.2e-07
L=16, relative error 1.2e-07

License

What is the license of this repo?
Thanks

using dtcwt as a high freq feature extractor

I would like to extract high frequency components of an image and I am wondering how I can extract this using dtcwt. So, I looked at the examples, and I presume I can do something like this:

import torch, sys
from pytorch_wavelets import DTCWTForward
xfm1 = DTCWTForward(J=3)
x = torch.randn(1, 3, 64, 64) # 3 channel 64x64 img example
yl, yh = xfm1(x)
print(yl.shape)
print(len(yh),yh[0].shape)

I presume here yl and yh are low and high frequency components respectively. However, I see that the HF components are complex :(

I guess my question is, how do I (in a sensible way) get a 1D flattened real vector (amplitude?) from the yh components.
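One common choice is the modulus of each complex coefficient, which is real, non-negative and relatively stable under small shifts of the input. Assuming the default output layout shown earlier in the README, shape (N, C, 6, H, W, 2) with real and imaginary parts on the last axis, a sketch of a hypothetical helper that turns the list yh into one real vector per sample:

```python
import torch

def highpass_magnitudes(yh, eps=1e-12):
    """Flatten DTCWT highpass outputs into one real feature vector per sample.

    Assumes each element of yh has shape (N, C, 6, H, W, 2) with the real
    and imaginary parts on the last axis.
    """
    feats = []
    for band in yh:
        # Modulus of each complex coefficient; eps keeps the gradient finite at 0
        mag = torch.sqrt(band[..., 0] ** 2 + band[..., 1] ** 2 + eps)
        feats.append(mag.flatten(start_dim=1))  # (N, C * 6 * H * W)
    return torch.cat(feats, dim=1)

# Random stand-ins with the shapes of a J=3 DTCWT on a 1x3x64x64 input
yh = [torch.randn(1, 3, 6, s, s, 2) for s in (32, 16, 8)]
v = highpass_magnitudes(yh)
print(v.shape)  # torch.Size([1, 24192])
```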

Apologies if this is daft question and thank you ever so much for this awesome package.

export to onnx

I built a neural network that uses ScatLayer as one of its layers.
When I export the model from PyTorch to ONNX format, an error occurs:

pytorch_wavelets\utils.py", line 162, in reflect
out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
TypeError: '>=' not supported between instances of 'numpy.ndarray' and 'Tensor'

py37_cu102/fused/fused.so

Thanks for sharing your work with the research community.

I tried to run cifar10 after creating a new environment according to your instructions, but it gives an error. It seems to be a CUDA or PyTorch problem, but I am using the same environment you provided.

DataParallel on GPUs not supported?

I tried your implementation on multiple GPUs and it failed.

The code is as follows:

import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

xfm = DWTForward(J=3, wave='db3', mode='periodization')

if torch.cuda.device_count() > 1:
    xfm = nn.DataParallel(xfm)
xfm.to(device)

X = torch.randn(10,3,256,256).to(device)
Yl, Yh = xfm(X)

The error is reported as follows:

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    Yl, Yh = xfm(X)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 123, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 133, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 77, in parallel_apply
    raise output
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 53, in _worker
    output = module(*input, **kwargs)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/opt/lhp/DWT/pytorch_wavelets/dwt/transform2d.py", line 77, in forward
    y = lowlevel.afb2d(ll, self.h, self.mode)
  File "/home/opt/lhp/DWT/pytorch_wavelets/dwt/lowlevel.py", line 229, in afb2d
    lohi = afb1d(x, h0_row, h1_row, mode=mode, dim=3)
  File "/home/opt/lhp/DWT/pytorch_wavelets/dwt/lowlevel.py", line 107, in afb1d
    lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
RuntimeError: Assertion `THCTensor_(checkGPU)(state, 3, input, output, weight)' failed. Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one. at /pytorch/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu:16

KeyError: 'near_sym_a'

Hello, I am trying this code, but strangely it works on one of my GPUs while the other one reports an error:

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pytorch_wavelets/dtcwt/coeffs.py", line 20, in _load_from_file
    mat = COEFF_CACHE[basename]
KeyError: 'near_sym_a'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pkg_resources/__init__.py", line 343, in get_provider
    module = sys.modules[moduleOrReq]
KeyError: 'pytorch_wavelets.dtcwt.data'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "t.py", line 3, in <module>
    scat = ScatLayer().cuda()
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pytorch_wavelets/scatternet/layers.py", line 47, in __init__
    h0o, _, h1o, _ = _biort(biort)
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pytorch_wavelets/dtcwt/coeffs.py", line 38, in biort
    return level1(name, compact=True)
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pytorch_wavelets/dtcwt/coeffs.py", line 74, in level1
    return _load_from_file(name, ('h0o', 'g0o', 'h1o', 'g1o'))
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pytorch_wavelets/dtcwt/coeffs.py", line 22, in _load_from_file
    with resource_stream('pytorch_wavelets.dtcwt.data', basename + '.npz') as f:
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pkg_resources/__init__.py", line 1134, in resource_stream
    return get_provider(package_or_requirement).get_resource_stream(
  File "/home/ubuntu/anaconda3/envs/d2/lib/python3.6/site-packages/pkg_resources/__init__.py", line 345, in get_provider
    __import__(moduleOrReq)
ModuleNotFoundError: No module named 'pytorch_wavelets.dtcwt.data'

Support for double precision

Thank you for publishing this package, I have found it very useful. In my application I use double precision floating point numbers, so I found that I had to change some of the dtypes in your code to make it work for me. Would you consider adding an optional dtype argument to your module constructors to accommodate using other dtypes more easily?

How to use wavelets for channel last

Hi,

I'm trying to use wavelets in my CNN, replacing max-pooling with a wavelet transform for downsampling, and I stumbled upon your code. My image tensors are in (H, W, channel) order, but your code is channels-first. Can I still use your code?

1d discrete wavelet transform

I want to apply a 1D discrete wavelet transform to a 4D input tensor along a specific axis (e.g. the channel axis). Is there an API for this, or should I implement it myself?
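A generic way to do this (a sketch, not a pytorch_wavelets API; the helper names are hypothetical) is to move the chosen axis to the end and fold every other axis into the batch, giving a (B, 1, L) tensor of the kind a 1D transform such as DWT1DForward expects, then undo the reshaping afterwards:

```python
import torch

def to_signal_batch(x, dim):
    """Move `dim` to the end and fold all other axes into the batch: (B, 1, L)."""
    moved = torch.movedim(x, dim, -1)
    lead_shape = moved.shape[:-1]
    return moved.reshape(-1, 1, moved.shape[-1]), lead_shape

def from_signal_batch(y, lead_shape, dim):
    """Undo to_signal_batch for a (B, 1, L') result; L' may differ from L."""
    y = y.reshape(*lead_shape, y.shape[-1])
    return torch.movedim(y, -1, dim)

x = torch.randn(4, 16, 8, 8)  # N, C, H, W; transform along C
sig, lead = to_signal_batch(x, dim=1)
print(sig.shape)  # torch.Size([256, 1, 16])

# ... apply a 1D transform to `sig` here ...
back = from_signal_batch(sig, lead, dim=1)
print(torch.equal(back, x))  # True
```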

How to cite?

Is there any way to cite this library? Do you have any preference?
