
redunet's Introduction

Deep Networks from the Principle of Rate Reduction

This repository is the official PyTorch implementation of the paper "ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction" (2021).

by Kwan Ho Ryan Chan* (UC Berkeley), Yaodong Yu* (UC Berkeley), Chong You* (UC Berkeley), Haozhi Qi (UC Berkeley), John Wright (Columbia University), and Yi Ma (UC Berkeley).

What is ReduNet?

ReduNet is a deep neural network constructed naturally by deriving the gradients of the Maximal Coding Rate Reduction (MCR2) [1] objective. Every layer of this network can be interpreted in terms of its mathematical operations, and the network as a whole is trained in a purely feed-forward manner. In addition, by imposing shift invariance on the network, the convolutional operator can be derived using only the data and the MCR2 objective function, which makes the network design principled and interpretable.
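The MCR2 objective that ReduNet's layers are derived from can be sketched as follows. This is a minimal, dependency-free illustration based on our reading of [1], not the repository's loss.py: R(Z, ε) = ½ logdet(I + d/(mε²) ZZᵀ) is the coding rate of all m features of dimension d, R_c is the same quantity computed per class and weighted by m_j/m, and ΔR = R − R_c. The function names and the default ε = 0.5 are our own choices; on real data one would use PyTorch tensors and torch.logdet rather than Python lists.

```python
import math


def logdet(a):
    """Log-determinant of a symmetric positive-definite matrix via Cholesky."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(low[i][k] * low[j][k] for k in range(j))
            low[i][j] = math.sqrt(s) if i == j else s / low[j][j]
    return 2.0 * sum(math.log(low[i][i]) for i in range(n))


def coding_rate(samples, eps):
    """R(Z, eps) = 1/2 logdet(I + d / (m eps^2) Z Z^T); samples are the columns of Z."""
    m, d = len(samples), len(samples[0])
    alpha = d / (m * eps ** 2)
    gram = [[float(i == j) + alpha * sum(z[i] * z[j] for z in samples) for j in range(d)]
            for i in range(d)]
    return 0.5 * logdet(gram)


def rate_reduction(samples, labels, eps=0.5):
    """Delta R = R(Z) - sum_j (m_j / m) R(Z_j): whole-set rate minus class-wise rates."""
    m = len(samples)
    r_c = 0.0
    for c in set(labels):
        members = [z for z, y in zip(samples, labels) if y == c]
        r_c += len(members) / m * coding_rate(members, eps)
    return coding_rate(samples, eps) - r_c
```

With ε = 0.5, two classes lying on orthogonal axes (samples [1, 0], [1, 0], [0, 1], [0, 1]) give ΔR = ln(5/3) ≈ 0.51, while four identical features split into two classes give ΔR = 0, which matches the intuition that ΔR rewards classes that span different subspaces.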


Figure: Weights and operations for one layer of ReduNet
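One layer of the vector (non-convolutional) version can be sketched as a projected gradient-ascent step on ΔR. Following our reading of the paper, the layer stores an expansion operator E = α(I + αZZᵀ)⁻¹ with α = d/(mε²) and, per class j, a compression operator C_j = α_j(I + α_j Z_j Z_jᵀ)⁻¹ with α_j = d/(m_j ε²); each feature is then updated as z ← normalize(z + η(Ez − γ_j C_j z)) with γ_j = m_j/m, using the known label j during construction. This is an illustration under those assumptions, not the repository's Vector layer: the function names, ε and the step size η are hypothetical, and the soft class assignment the paper uses at test time (when labels are unknown) is omitted.

```python
import math


def inverse(a):
    """Invert a small square matrix by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    aug = [list(map(float, row)) + [float(i == j) for j in range(n)] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                f = aug[r][col]
                aug[r] = [v - f * w for v, w in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def matvec(a, z):
    return [sum(aij * zj for aij, zj in zip(row, z)) for row in a]


def rate_operator(samples, eps):
    """alpha * (I + alpha * Z Z^T)^{-1} with alpha = d / (m * eps^2); samples are columns of Z."""
    m, d = len(samples), len(samples[0])
    alpha = d / (m * eps ** 2)
    gram = [[float(i == j) + alpha * sum(z[i] * z[j] for z in samples) for j in range(d)]
            for i in range(d)]
    return [[alpha * v for v in row] for row in inverse(gram)]


def redu_layer(samples, labels, eps=0.5, eta=0.5):
    """Construct one layer from labelled features (labels: a list) and apply it to them.

    Returns E, a dict of per-class operators C_j, and the updated unit-norm features.
    """
    m = len(samples)
    classes = sorted(set(labels))
    E = rate_operator(samples, eps)  # expansion operator from all samples
    C = {c: rate_operator([z for z, y in zip(samples, labels) if y == c], eps)
         for c in classes}  # one compression operator per class
    gamma = {c: labels.count(c) / m for c in classes}
    out = []
    for z, y in zip(samples, labels):
        # gradient-ascent step on Delta R: expand with E, compress with the own-class C_y
        ez, cz = matvec(E, z), matvec(C[y], z)
        step = [zi + eta * (e - gamma[y] * c) for zi, e, c in zip(z, ez, cz)]
        norm = math.sqrt(sum(v * v for v in step))
        out.append([v / norm for v in step])  # project back onto the unit sphere
    return E, C, out
```

Stacking layers in this sketch would mean calling redu_layer repeatedly on its own output, so that each layer's operators are computed from the features produced by the previous one.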

[1] Yu, Yaodong, Kwan Ho Ryan Chan, Chong You, Chaobing Song, and Yi Ma. "Learning diverse and discriminative representations via the principle of maximal coding rate reduction." Advances in Neural Information Processing Systems 33 (2020).

Requirements

This codebase is written for Python 3. To install the necessary Python packages, run:

$ conda create --name redunet_official --file requirements.txt

Demo

For a quick demonstration of ReduNet on 2D or 3D Gaussian data, open one of the notebooks by running either of the following commands:

$ jupyter notebook ./examples/gaussian2d.ipynb
$ jupyter notebook ./examples/gaussian3d.ipynb

Core Usage and Design

The design of this repository aims to be easy to use and easy to integrate into the framework of your current experiments, as long as they use PyTorch. The ReduNet object inherits from nn.Sequential, and its layers (ReduLayers), such as Vector, Fourier1D and Fourier2D, inherit from nn.Module. Loss functions are implemented in loss.py. Architecture and dataset options are located in load.py. Data objects and preset architectures are located in the dataset and architectures folders, respectively; feel free to add more based on the experiments you want to run. We provide basic experiment setups in train_<mode>.py and evaluate_<mode>.py, where <mode> is the type of experiment. For utility functions, see functional.py or utils.py. Feel free to email us if you have any issues or suggestions.
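The Fourier1D and Fourier2D layers relate to the shift-invariant case described earlier: when the operators are constrained to commute with cyclic shifts, they become circulant (convolution) matrices, and circulant matrices are diagonalized by the discrete Fourier transform. We assume, without having checked the source, that this is why those layers work in the Fourier domain. The sketch below only verifies the underlying identity on a small 1D signal: multiplying by a circulant matrix, computing a circular convolution, and multiplying DFTs pointwise all give the same result. All function names are our own.

```python
import cmath


def circulant(c):
    """Circulant matrix whose first column is c: M[i][j] = c[(i - j) % n]."""
    n = len(c)
    return [[c[(i - j) % n] for j in range(n)] for i in range(n)]


def circular_conv(c, x):
    """Circular convolution (c * x)[i] = sum_j c[(i - j) % n] x[j]."""
    n = len(c)
    return [sum(c[(i - j) % n] * x[j] for j in range(n)) for i in range(n)]


def dft(x):
    """Naive O(n^2) discrete Fourier transform."""
    n = len(x)
    return [sum(x[k] * cmath.exp(-2j * cmath.pi * f * k / n) for k in range(n)) for f in range(n)]


def idft(X):
    """Naive O(n^2) inverse discrete Fourier transform."""
    n = len(X)
    return [sum(X[f] * cmath.exp(2j * cmath.pi * f * k / n) for f in range(n)) / n for k in range(n)]


def conv_via_fourier(c, x):
    # convolution theorem: circular convolution is a pointwise product of DFTs
    return [v.real for v in idft([a * b for a, b in zip(dft(c), dft(x))])]
```

The naive DFT here is O(n²) and is only meant to check the identity; with an FFT, applying a circulant operator costs O(n log n) instead of the O(n²) of a dense matrix product.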

Example: Forward Construction

To train a ReduNet using forward construction, see train_forward.py; for evaluation, see evaluate_forward.py. For example, to train a 50-layer ReduNet on MNIST using 1000 samples per class, run:

$ python3 train_forward.py --data mnistvector --arch layers50 --samples 1000

After training, evaluate the model with evaluate_forward.py by running:

$ python3 evaluate_forward.py --model_dir ./saved_models/forward/mnistvector+layers50/samples1000 

This evaluates the model on all available training and testing samples. For more training and testing options, see train_forward.py and evaluate_forward.py.

Experiments in Paper

For the code used to generate the empirical results reported in our paper, please visit our other repository: https://github.com/ryanchankh/redunet_paper

Reference

For technical details and full experimental results, please see the paper. Please consider citing our work if you find it helpful:

@article{chan2021redunet,
  title={ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction},
  author={Chan, Kwan Ho Ryan and Yu, Yaodong and You, Chong and Qi, Haozhi and Wright, John and Ma, Yi},
  journal={arXiv preprint arXiv:2105.10446},
  year={2021}
}

License and Contributing

  • This README is formatted based on paperswithcode.
  • Feel free to post issues via GitHub.

Contact

Please contact [email protected] and [email protected] if you have any questions about the code.

redunet's People

Contributors

ryanchankh, yaodongyu

redunet's Issues

memory will be very large

If the input data is even slightly larger, memory usage becomes very large. Do you have any suggestions for optimization? Any suggestions are welcome, thank you!
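A back-of-the-envelope estimate shows why memory grows quickly with input size. Assuming (from the paper's vector formulation, not from measuring this code) that each layer stores one d×d expansion operator E plus one d×d compression operator C_j per class in float32, storage grows as L·(k+1)·d², i.e. quadratically in the feature dimension d; constructing each layer additionally requires inverting d×d matrices. The helper names below are hypothetical.

```python
def vector_redunet_param_bytes(d, num_classes, num_layers, bytes_per_float=4):
    """Rough storage for a vector-style ReduNet (an estimate under stated assumptions).

    Each layer is assumed to keep one d x d expansion operator E and one
    d x d compression operator C_j per class, stored as float32 by default.
    The d x m feature matrix and temporary buffers are ignored.
    """
    matrices_per_layer = num_classes + 1
    return num_layers * matrices_per_layer * d * d * bytes_per_float


def human_gb(num_bytes):
    """Format a byte count in decimal gigabytes."""
    return f"{num_bytes / 1e9:.2f} GB"
```

With d = 784 (a flattened 28×28 image, which we assume is what mnistvector uses), 10 classes and 50 layers, the estimate is 1,352,243,200 bytes, about 1.35 GB; for a flattened 32×32×3 image (d = 3072) it grows to about 20.76 GB.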

Uncertainty Estimation

Hi,

I am very impressed by your work. I have been wondering: since ReduNet is a white box, one should be able to write down the uncertainty of its predictions analytically. Say that in the test phase I feed an image that is half apple, half orange to a ReduNet trained to classify apples and oranges. Should I get the prediction uncertainty for free? In theory, I should also be able to trace back through every layer to see how the uncertainty propagates, right? Is uncertainty estimation on your roadmap?

Backprop training

Hi,

Thank you for publishing your code with the paper. It's very nice work! In Section 5.2, the authors discuss training ReduNet with backpropagation. Is the code for backprop training published in this repo?

Thanks,
Matt

Failed to install requirements on Windows 10 or Linux

Running conda create --name redunet_official --file requirements.txt on Windows 10 or in a GitHub Codespace fails with:

Collecting package metadata (current_repodata.json): done
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed

PackagesNotFoundError: The following packages are not available from current channels:

  - lz4-c==1.9.2=h79c402e_3
  - pysocks==1.7.1=py37hecd8cb5_0
  - openssl==1.1.1k=h9ed2024_0
  - tornado==6.0.4=py37h1de35cc_1
  - libopus==1.3.1=h1de35cc_0
  - nettle==3.4.1=h3018a27_0
  - brotlipy==0.7.0=py37h9ed2024_1003
  - torchaudio==0.8.0=py37
  - ninja==1.10.1=py37h879752b_0
  - libgfortran==3.0.1=h93005f0_2
  - mkl-service==2.3.0=py37hfbe908c_0
  - libtiff==4.1.0=hcb84e12_1
  - mkl_random==1.1.1=py37h959d312_0
  - pytorch==1.8.0=py3.7_0
  - libuv==1.40.0=haf1e3a3_0
  - opencv-python==4.4.0.44=pypi_0
  - lame==3.100=h1de35cc_0
  - scikit-learn==0.23.2=py37h959d312_0
  - llvm-openmp==10.0.0=h28b9765_0
  - gettext==0.19.8.1=hb0f4f8b_2
  - chardet==4.0.0=py37hecd8cb5_1003
  - intel-openmp==2019.4=233
  - lcms2==2.11=h92f6f08_0
  - bzip2==1.0.8=h1de35cc_0
  - libffi==3.3=hb1e8313_2
  - torchvision==0.9.0=py37_cpu
  - mkl==2019.4=233
  - ca-certificates==2021.1.19=hecd8cb5_1
  - x264==1!157.20191217=h1de35cc_0
  - libedit==3.1.20191231=h1de35cc_1
  - freetype==2.10.4=ha233b18_0
  - libiconv==1.16=h1de35cc_0
  - pillow==8.0.0=py37h1a82f1a_0
  - xz==5.2.5=h1de35cc_0
  - python==3.7.9=h26836e1_0
  - scipy==1.5.2=py37h912ce22_0
  - tk==8.6.10=hb0a8c7a_0
  - gnutls==3.6.5=h91ad68e_1002
  - pandas==1.1.3=py37hb1e8313_0
  - setuptools==50.3.0=py37h0dc7051_1
  - gmp==6.1.2=hb37e062_1
  - appnope==0.1.0=py37_0
  - ncurses==6.2=h0a44026_1
  - zeromq==4.3.3=hb1e8313_3
  - sqlite==3.33.0=hffcf06c_0
  - cffi==1.14.4=py37h2125817_0
  - zstd==1.4.5=h41d2c2f_0
  - numpy==1.19.1=py37h3b9f5b6_0
  - libvpx==1.7.0=h378b8a2_0
  - numpy-base==1.19.1=py37hcfb5961_0
  - zlib==1.2.11=h1de35cc_3
  - readline==8.0=h1de35cc_0
  - ffmpeg==4.2.2=h97e5cf8_0
  - pyzmq==19.0.2=py37hb1e8313_1
  - openh264==2.1.0=hd9629dc_0
  - kiwisolver==1.2.0=py37h04f5b5a_0
  - libpng==1.6.37=ha441bb4_0
  - jpeg==9b=he5867d9_2
  - matplotlib-base==3.3.2=py37h181983e_0
  - certifi==2020.12.5=py37hecd8cb5_0
  - mkl_fft==1.2.0=py37hc64f4ea_0
  - libsodium==1.0.18=h1de35cc_0
  - cryptography==3.3.1=py37hbcfaee0_0

Current channels:

  - https://repo.anaconda.com/pkgs/main/linux-64
  - https://repo.anaconda.com/pkgs/main/noarch
  - https://repo.anaconda.com/pkgs/r/linux-64
  - https://repo.anaconda.com/pkgs/r/noarch
  - https://conda.anaconda.org/conda-forge/linux-64
  - https://conda.anaconda.org/conda-forge/noarch

To search for alternate channels that may provide the conda package you're
looking for, navigate to

    https://anaconda.org

and use the search bar at the top of the page.
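The pinned build strings in this log (for example h1de35cc_0 and py37hecd8cb5_0), together with macOS-only packages such as appnope, suggest that requirements.txt was exported on macOS, which would explain why the linux-64 channels listed above cannot resolve them; this is our reading of the log, not something the issue confirms. A common workaround is to drop the build strings and let conda pick builds for the current platform. The sketch below does that for lines of the form name==version=build (or name=version=build) and sets aside pip-installed entries (build pypi_0) for pip. The function name is hypothetical, and platform-only packages, as well as pytorch, torchvision and torchaudio (which may need the pytorch channel, e.g. via -c pytorch), may still need manual handling.

```python
def strip_builds(lines):
    """Drop build strings from conda spec lines; set aside pip-installed (pypi_0) entries.

    Returns (conda_specs, pip_specs). Comments and blank lines are skipped.
    """
    conda_specs, pip_specs = [], []
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        # "name==version=build" and "name=version=build" both split into 3 non-empty parts
        parts = [p for p in line.split("=") if p]
        if len(parts) == 3:
            name, version, build = parts
            if build.startswith("pypi"):
                pip_specs.append(f"{name}=={version}")
            else:
                conda_specs.append(f"{name}={version}")
        else:
            conda_specs.append(line)  # already build-free
    return conda_specs, pip_specs
```

In use, one would open requirements.txt, pass the file object to strip_builds, write conda_specs to a new spec file for conda create --file, and install pip_specs with pip afterwards.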
