2D Wavelet Transforms in Pytorch

The full documentation is also available here.

This package provides support for computing the 2D discrete wavelet and the 2D dual-tree complex wavelet transforms, their inverses, and passing gradients through both using PyTorch.

The implementation is designed for batches of multichannel images and uses PyTorch's standard 'NCHW' data format (batch, channels, height, width).
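
As a minimal sketch of that layout, assume a single image stored as a height x width x channels tensor (a common convention in image libraries, not something this package requires). It can be rearranged into NCHW as follows; the variable names are illustrative only.

```python
import torch

# Hypothetical 64x64 RGB image in HWC (height, width, channels) layout
img_hwc = torch.rand(64, 64, 3)

# Move channels first, then add a leading batch dimension: HWC -> CHW -> NCHW
x = img_hwc.permute(2, 0, 1).unsqueeze(0).contiguous()

print(x.shape)  # torch.Size([1, 3, 64, 64])
```

A tensor shaped like `x` is what the transform classes in the examples further down take as input.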

New in version 1.0.0

Version 1.0.0 adds support for separable DWT calculation and for more padding schemes, such as symmetric, zero and periodization.

Also, you no longer need to specify the number of channels when creating the wavelet transform classes.

Speed Tests

We compare running the DTCWT with the python package and the DWT with PyWavelets against running both in pytorch_wavelets on a GTX1080. The numpy methods were run on a 14 core Xeon Phi machine using Intel's parallel python. For the DTCWT we use the near_sym_a filters for the first scale and the qshift_a filters for subsequent scales. For the DWT we use the db4 filters.

For a fixed input size, but varying the number of scales (from 1 to 4) we have the following speeds (averaged over 5 runs):

For an input size with height and width 512 by 512, we also vary the batch size for a 3 scale transform. The resulting speeds were:

Installation

The easiest way to install pytorch_wavelets is to clone the repo and pip install it. Later versions will be released on PyPI, but the docs need to be updated first:

$ git clone https://github.com/fbcotter/pytorch_wavelets
$ cd pytorch_wavelets
$ pip install .

(Although the develop command may be more useful if you intend to perform any significant modification to the library.) A test suite is provided so that you may verify the code works on your system:

$ pip install -r tests/requirements.txt
$ pytest tests/

Example Use

For the DWT - note that the highpass output has an extra dimension, in which we stack the (lh, hl, hh) coefficients. Also note that the Yh output has the finest detail coefficients first, and the coarsest last (the opposite to PyWavelets).

import torch
from pytorch_wavelets import DWTForward, DWTInverse

# 3-level DWT with db3 filters and zero padding
xfm = DWTForward(J=3, wave='db3', mode='zero')
X = torch.randn(10, 5, 64, 64)
Yl, Yh = xfm(X)
print(Yl.shape)     # torch.Size([10, 5, 12, 12])
print(Yh[0].shape)  # torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)  # torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)  # torch.Size([10, 5, 3, 12, 12])

# Inverse transform from the lowpass and highpass coefficients
ifm = DWTInverse(wave='db3', mode='zero')
Y = ifm((Yl, Yh))
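
The coefficient sizes in the DWT example follow from the filter length. The sketch below assumes that the non-periodized padding modes give floor((N + L - 1) / 2) coefficients per level for an input of length N and a filter of length L (the convention PyWavelets uses), and that db3 has 6 taps. The helper name `dwt_sizes` is hypothetical and is not part of pytorch_wavelets.

```python
def dwt_sizes(n, filt_len, levels):
    """Return the coefficient length at each level, finest level first."""
    sizes = []
    for _ in range(levels):
        # Each level filters and downsamples by 2; the padding adds filt_len - 1 samples
        n = (n + filt_len - 1) // 2
        sizes.append(n)
    return sizes

sizes = dwt_sizes(64, filt_len=6, levels=3)
print(sizes)  # [34, 19, 12]
```

The three values match the spatial sizes of Yh[0], Yh[1] and Yh[2], and the last one is also the size of Yl.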

For the DTCWT:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

# 3-level DTCWT with near_sym_b filters at the first scale and qshift_b after
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
X = torch.randn(10, 5, 64, 64)
Yl, Yh = xfm(X)
print(Yl.shape)     # torch.Size([10, 5, 16, 16])
print(Yh[0].shape)  # torch.Size([10, 5, 6, 32, 32, 2])
print(Yh[1].shape)  # torch.Size([10, 5, 6, 16, 16, 2])
print(Yh[2].shape)  # torch.Size([10, 5, 6, 8, 8, 2])

# Inverse transform from the lowpass and highpass coefficients
ifm = DTCWTInverse(J=3, biort='near_sym_b', qshift='qshift_b')
Y = ifm((Yl, Yh))

Some initial notes:

  • Yh returned is a tuple. Each entry has 2 extra dimensions. The first sits between the channel dimension of the input and the row dimension, and holds the 6 orientations of the DTCWT. The second is the final dimension, which holds the real and imaginary parts (complex numbers are not native to pytorch).
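
To illustrate that layout, here is a minimal sketch that computes the magnitude and phase of the complex coefficients with ordinary tensor operations. It uses a random tensor as a stand-in for Yh[0], so it needs only torch; the names `yh0`, `mag`, `phase` and `band` are illustrative.

```python
import torch

# Stand-in for Yh[0] from the DTCWT example: (N, C, orientations, H, W, real/imag)
yh0 = torch.randn(10, 5, 6, 32, 32, 2)

# The final dimension holds the real and imaginary parts
real, imag = yh0[..., 0], yh0[..., 1]
mag = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag, real)

# Dimension 2 indexes the 6 orientations; pick the first one
band = mag[:, :, 0]

print(mag.shape)   # torch.Size([10, 5, 6, 32, 32])
print(band.shape)  # torch.Size([10, 5, 32, 32])
```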

Running on the GPU

This should come as no surprise to pytorch users. The DWT and DTCWT transforms support cuda calling:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(10,5,64,64).cuda()
Yl, Yh = xfm(X) 
ifm = DTCWTInverse(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
Y = ifm((Yl, Yh))

The automated tests cannot test the GPU functionality, but they do check that the code runs on the CPU. To test whether the repo works on your GPU, download the repo, make sure you have pytorch with cuda enabled (the tests check whether torch.cuda.is_available() returns True), and run the following from the base of the repo:

$ pip install -r tests/requirements.txt
$ pytest tests/

Backpropagation

It is possible to pass gradients through the forward and inverse transforms. All you need to do is ensure that the input to each has the requires_grad attribute set to True.

Provenance

Based on the Dual-Tree Complex Wavelet Transform Pack for MATLAB by Nick Kingsbury, Cambridge University. The original README can be found in ORIGINAL_README.txt. This file outlines the conditions of use of the original MATLAB toolbox.
