
ccnn's Introduction

Modelling Long Range Dependencies in N-D: From Task-Specific to a General Purpose CNN

Code repository of the paper Modelling Long Range Dependencies in N-D: From Task-Specific to a General Purpose CNN.

Abstract

Performant Convolutional Neural Network (CNN) architectures must be tailored to specific tasks in order to account for the length, resolution, and dimensionality of the input data. In this work, we tackle the need for problem-specific CNN architectures. We present the Continuous Convolutional Neural Network (CCNN): a single CNN able to process data of arbitrary resolution, dimensionality, and length without any structural changes. Its key components are its continuous convolutional kernels, which model long-range dependencies at every layer and thus remove the need of current CNN architectures for task-dependent downsampling and depths. We showcase the generality of our method by using the same architecture for tasks on sequential (1D), visual (2D), and point-cloud (3D) data. Our CCNN matches and often outperforms the current state of the art across all tasks considered.

Installation

conda

We provide an environment file, environment.yml, containing the required dependencies. Clone the repository and run the following command from its root directory:

conda env create -f environment.yml

If you would like to install PyTorch without CUDA support, run instead:

conda env create -f environment-nocuda.yml

Repository Structure

This repository is organized as follows:

  • ckconv contains the main PyTorch library of our model.
  • models contains the model architectures used in our experiments.
  • datamodules contains the PyTorch Lightning datamodules used in our experiments.
  • cfg contains the configuration file in which to specify default arguments to be passed to the script.

Using the code

All experiments are run with main.py. Flags are handled by Hydra. See cfg/config.yaml for all available flags. Flags can be passed as xxx.yyy=value.
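For example, the following command (combining flags documented under "Useful flags" below) evaluates a model in test-only mode without logging to Weights & Biases:

python main.py train.do=False debug=True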

Useful flags

  • net.* describes settings for the models (model definition in models/resnet.py).
  • kernel.* describes settings for the MAGNet kernel generator networks.
  • mask.* describes settings for the FlexConv Gaussian mask.
  • conv.* describes settings for the convolution operation. It can be used to switch between FlexConv, CKConv, regular Conv, and their separable variants.
  • dataset.* specifies the dataset to be used, as well as variants, e.g., permuted, sequential.
  • train.* specifies the settings used for the Trainer of the models.
  • train.do=False: Only test the model. Useful in combination with pre-training.
  • optimizer.* specifies and configures the optimizer to be used.
  • debug=True: By default, all experiment scripts connect to Weights & Biases to log the experimental results. Use this flag to run without connecting to Weights & Biases.
  • pretrained.*: Use these to load checkpoints before training.

Reproducing experiments

Please see the experiments README for details on reproducing the paper's experiments.

Cite

If you found this work useful in your research, please consider citing:

@article{knigge2023modelling,
  title={Modelling Long Range Dependencies in N-D: From Task-Specific to a General Purpose CNN},
  author={Knigge, David M and Romero, David W and Gu, Albert and Bekkers, Erik J and Gavves, Efstratios and Tomczak, Jakub M and Hoogendoorn, Mark and Sonke, Jan-Jakob},
  journal={International Conference on Learning Representations},
  year={2023}
}

Acknowledgements

This work is supported by the Qualcomm Innovation Fellowship (2021) granted to David W. Romero. David W. Romero sincerely thanks Qualcomm for their support. David W. Romero is financed as part of the Efficient Deep Learning (EDL) programme (grant number P16-25), partly funded by the Dutch Research Council (NWO). David Knigge is partially funded by Elekta Oncology Systems AB and an RVO public-private partnership grant (PPS2102).

This work was carried out on the Dutch national infrastructure with the support of SURF Cooperative.


ccnn's Issues

how to input the irregular sampled data, such as sparse images

Thank you very much for sharing such interesting work. Could you show me how to input sparse images into your network? For example, inputting the sparse values according to the coordinates of the existing points in the images.

run error

from .causal_conv import conv1d, fftconv1d

File "/root/*****-tmp/ccnn-main/ckconv/nn/functional/causal_conv.py", line 11, in
) -> tuple[torch.Tensor, torch.Tensor]:
TypeError: 'type' object is not subscriptable

I don't know how to solve this. Could you send me an email ([[email protected]]) or post a solution here? Special thanks, and I will help others.
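This TypeError comes from using the built-in generic `tuple[...]` as an annotation, which is only subscriptable at runtime on Python >= 3.9. A minimal sketch of the workaround, using `typing.Tuple` instead (the function `split_pair` here is a toy stand-in, not the repository's actual code):

```python
from typing import Tuple

# On Python < 3.9, annotating with the built-in generic tuple[...] raises
# "TypeError: 'type' object is not subscriptable" when the annotation is
# evaluated at runtime. typing.Tuple works on all supported versions.
def split_pair(pair: Tuple[int, int]) -> Tuple[int, int]:
    first, second = pair
    return second, first
```

Alternatively, adding `from __future__ import annotations` at the top of causal_conv.py (Python >= 3.7), or upgrading to Python >= 3.9, should let the existing `tuple[torch.Tensor, torch.Tensor]` annotation parse unchanged.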

Initialized bug in MAGNets causing issues with model loading

Hello @david-knigge! I've found a bug in ckconv.py that causes a trained model loaded via state dict and checkpoint to have some of its weights re-initialized at inference time.

Specifically this line:

        self.register_buffer("initialized", torch.zeros(1).bool(), persistent=False)

And in the chang_initialization method:

        if not self.initialized[0] and self.chang_initialize:

Using register_buffer with persistent=False here is incorrect, as it means self.initialized is not saved in the state dict. A freshly created model object with its state dict loaded will therefore have self.initialized[0] = False, causing weights to be re-initialized when inference is run. Instead, this should be changed to persistent=True so that the flag is saved in the state dict.

This bug was really hard to figure out (you don't expect your weights to change when running inference 😆). I think this might also be the reason why in main.py you have somewhat weird code for loading in models. Hopefully this helps anyone else building on this repository!
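A minimal sketch of the fix described in this issue (the class `KernelModule` and method names are illustrative, not the repository's actual code): with persistent=True, the flag is part of the state dict and survives a checkpoint round trip.

```python
import torch
from torch import nn

class KernelModule(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=True makes the flag part of the state dict, so a
        # reloaded model keeps initialized=True and does not re-run
        # weight initialization at inference time.
        self.register_buffer("initialized", torch.zeros(1).bool(), persistent=True)

    def finish_initialization(self):
        self.initialized.fill_(True)

module = KernelModule()
module.finish_initialization()

# Simulate saving a checkpoint and loading it into a fresh model object.
reloaded = KernelModule()
reloaded.load_state_dict(module.state_dict())
```

After loading, reloaded.initialized[0] is True, so any `if not self.initialized[0]` guard correctly skips re-initialization.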

Error when running the code

\ccnn-main\dataset_constructor.py", line 31, in construct_datamodule
dataset = getattr(datamodules, dataset_name)
AttributeError: module 'datamodules' has no attribute 'DataModule'

How can I solve this problem?

Error when running the code

\ccnn-main\ckconv\nn\functional\causal_conv.py", line 11, in
) -> tuple[torch.Tensor, torch.Tensor]:
TypeError: 'type' object is not subscriptable

predicted results of test dataset

Through the code, I can only see test/acc and test/loss during testing. How can I see the predicted results on the test dataset, and what kind of code should I add?

Equivalent responses across input resolutions

Hello, I have 2 questions because I want to implement ccnn in my research paper.

  1. According to the paper, Equation 1 states that the model is resolution agnostic, but I have searched through the code and cannot find the implementation. Would it be possible to point out where that (r1/r2) factor is implemented?
  2. If I set the kernel size to 11 and input two images of size (56,56) and (28,28), the kernel size is still 11 for both images. Is that correct? From my understanding, if the resolution is changed by 2x, the kernel size should also increase by 2x, because the relative positions have also been scaled by 2x.
    When presenting an input at a different resolution, e.g., higher resolution, it is sufficient to pass a finer grid of coordinates through the kernel generator network in order to construct the same kernel at the corresponding resolution.

Thank you so much.
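As a rough intuition for the resolution-agnostic behavior discussed in this issue, a continuous kernel is a function over normalized positions, and changing resolution just means sampling that function on a finer or coarser grid. The sketch below uses a toy Gaussian in place of the actual MAGNet generator network (the function and names here are illustrative, not the repository's code):

```python
import math

# Toy stand-in for a continuous kernel generator: the kernel is a
# function over normalized positions in [-1, 1].
def kernel_value(x):
    return math.exp(-8.0 * x * x)

def sample_kernel(num_taps):
    # Evaluate the same continuous function on a grid of num_taps positions.
    return [kernel_value(-1 + 2 * i / (num_taps - 1)) for i in range(num_taps)]

# Doubling the input resolution means sampling the same kernel on a finer
# grid: its size in taps grows while its spatial extent stays fixed.
coarse = sample_kernel(11)  # e.g., for a lower-resolution input
fine = sample_kernel(21)    # the same kernel at double the resolution
```

Both samplings describe the same underlying kernel; the finer grid simply has more taps, which is why a fixed kernel-size flag does not by itself capture the resolution change.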
