
Trainable Joint Bilateral Filter Layer (PyTorch)

This repository implements a GPU-accelerated trainable joint bilateral filter layer (guidance image + three spatial and one range filter dimension) that can be included directly in any PyTorch graph, just like any conventional layer (fully connected, convolutional, ...). Because the analytical derivatives of the joint bilateral filter with respect to its parameters, the guidance image, and the input are computed, the otherwise hand-tuned hyperparameters can be optimized automatically via backpropagation for any calculated loss.
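As a rough illustration of what the layer computes, the following pure-Python sketch applies a joint bilateral filter to a 1D signal. This is only a conceptual sketch: the actual layer filters three spatial dimensions plus one range dimension with CUDA kernels and supplies analytical gradients.

```python
import math

def joint_bilateral_1d(x, g, sigma_spatial, sigma_range):
    """Joint bilateral filter on a 1D signal x, guided by g.

    The spatial kernel compares pixel positions, while the range kernel
    compares intensities of the *guidance* signal g, so edges present
    in g are preserved in the filtered output.
    """
    out = []
    n = len(x)
    for i in range(n):
        num, den = 0.0, 0.0
        for j in range(n):
            w = math.exp(-((i - j) ** 2) / (2 * sigma_spatial ** 2)
                         - ((g[i] - g[j]) ** 2) / (2 * sigma_range ** 2))
            num += w * x[j]
            den += w
        out.append(num / den)
    return out
```

With a clean guidance signal and a small range sigma, averaging across edges of g is suppressed, which is exactly what makes the joint variant attractive for guided denoising.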

Our associated paper, Trainable Joint Bilateral Filters for Enhanced Prediction Stability in Low-dose CT, can be found in Scientific Reports (open access) or on arXiv (preprint).

Setup:

The forward and backward functions, implemented in C++/CUDA, are compiled via the setup.py script using setuptools:

  1. Create and activate a Python environment (Python >= 3.7).
  2. Install PyTorch (tested versions: 1.7.1, 1.9.0).
  3. Install the joint bilateral filter layer via pip:
pip install jointbilateralfilter_torch

If you encounter problems with step 3, install the layer directly from our GitHub repository:

  1. Download the repository.
  2. Navigate into the extracted repo.
  3. Compile/install the joint bilateral filter layer by calling
python setup.py install

Example scripts:

  • Can be found in our GitHub repository
  • Try out the forward pass by running example_filter.py (requires Matplotlib and scikit-image).
  • Run the gradcheck.py script to verify the correct gradient implementation.
  • Run example_optimization.py to optimize the parameters of a joint bilateral filter layer to automatically denoise an image.
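Conceptually, the optimization example treats the filter widths as learnable parameters and minimizes a loss against a target. The following hedged pure-Python analogue uses a plain Gaussian smoother and finite-difference gradients in place of the layer's analytical backward pass; all names here are illustrative, not the repository's API.

```python
import math
import random

def gauss_smooth(x, sigma):
    """Gaussian smoothing of a 1D list (stand-in for the filter layer)."""
    out = []
    for i in range(len(x)):
        num = den = 0.0
        for j in range(len(x)):
            w = math.exp(-((i - j) ** 2) / (2 * sigma ** 2))
            num += w * x[j]
            den += w
        out.append(num / den)
    return out

def mse(a, b):
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)

random.seed(0)
clean = [math.sin(i / 3.0) for i in range(40)]
noisy = [v + random.gauss(0.0, 0.3) for v in clean]

# Gradient descent on the (single) width parameter via central finite
# differences; the real layer instead provides analytical gradients, so
# the sigmas can be updated by any standard PyTorch optimizer.
sigma, lr, eps = 0.5, 0.5, 1e-4
for _ in range(50):
    grad = (mse(gauss_smooth(noisy, sigma + eps), clean)
            - mse(gauss_smooth(noisy, sigma - eps), clean)) / (2 * eps)
    sigma -= lr * grad

denoised_mse = mse(gauss_smooth(noisy, sigma), clean)
noisy_mse = mse(noisy, clean)
```

After the loop, the learned width yields a lower reconstruction error than the raw noisy signal, mirroring what the trainable layer achieves with its sigma parameters.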

Optimized joint bilateral filter prediction:

https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/out/example_optimization.png?raw=true

Citation:

If you find our code useful, please cite our work:

@article{wagner2022trainable,
  title={Trainable Joint Bilateral Filters for Enhanced Prediction Stability in Low-dose CT},
  author={Wagner, Fabian and Thies, Mareike and Denzinger, Felix and Gu, Mingxuan and Patwari, Mayank and Ploner, Stefan and Maul, Noah and Pfaff, Laura and Huang, Yixing and Maier, Andreas},
  journal={Scientific Reports},
  volume={12},
  number={1},
  pages={1--9},
  year={2022},
  publisher={Nature Publishing Group},
  doi={10.1038/s41598-022-22530-4}
}

Implementation:

The general structure of the implementation follows the PyTorch documentation for creating custom C++ and CUDA extensions. The forward pass of the layer is based on code from the Project MONAI framework, originally published under the Apache License, Version 2.0. The correctness of the analytical forward and backward passes can be verified by running the gradcheck.py script, which compares numerical gradients with the derived analytical gradients using PyTorch's built-in gradcheck function.
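The idea behind such a gradient check is to compare the hand-derived analytical derivative against a central finite-difference approximation. A minimal illustration on a toy scalar function (torch.autograd.gradcheck performs the same comparison on tensors, in double precision):

```python
import math

# Toy stand-in for a layer output and its hand-derived gradient.
def f(x):
    return math.exp(-x * x)

def df(x):
    return -2.0 * x * math.exp(-x * x)

# Central finite differences: (f(x+eps) - f(x-eps)) / (2*eps) should
# agree with the analytical derivative up to O(eps^2).
eps = 1e-5
max_err = 0.0
for x in [-1.0, -0.3, 0.0, 0.7, 2.0]:
    numerical = (f(x + eps) - f(x - eps)) / (2 * eps)
    max_err = max(max_err, abs(numerical - df(x)))
```

If the analytical derivative were wrong in any term, the discrepancy would show up immediately as a large max_err; this is the failure mode gradcheck.py is designed to catch for the filter layer.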

Troubleshooting

nvcc-related errors:

  1. Compiling the filter layers requires the NVIDIA CUDA toolkit. Check its installed version with

    nvcc --version

    or install it via, e.g.,

    sudo apt update
    sudo apt install nvidia-cuda-toolkit
  2. The NVIDIA CUDA toolkit 11.6 caused problems on a Windows machine in combination with pybind. Downgrading the toolkit to version 11.3 fixed the problem (see this discussion).

Windows-related problems:

  1. Make sure the path to cl.exe is correctly set in your environment variables.


trainable-joint-bilateral-filter-source's Issues

Inference of JBF

Hello! I am using this trainable JBF for my project and it is working wonderfully, so thank you for sharing!
I have a question regarding the usage of the JBF layer on unseen data: does the standard load_state_dict method load the final weights (sigmas) learned during training (if saved with torch.save), or should I insert them manually when defining the layer?
