

Coupled Oscillatory Recurrent Neural Network (coRNN)
[ICLR 2021 Oral]

This repository contains the code to reproduce the numerical experiments of the International Conference on Learning Representations (ICLR) 2021 [oral] paper "Coupled Oscillatory Recurrent Neural Network (coRNN): An accurate and (gradient) stable architecture for learning long time dependencies".

Requirements

pytorch 1.3+
torchvision 0.4+
torchtext 0.6+
numpy 1.17+
spacy v2.2+

To run the experiments on a GPU, make sure the matching CUDA packages are installed.

Example
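
Read off the forward method of the cell below (not quoted from the paper), one step computes the following update, where $y$ and $z$ correspond to `hy` and `hz`, and $W_x, W_z, W_y, b$ are our own labels for the blocks of the single `nn.Linear` layer `i2h` applied to the concatenation $[x_n, z_{n-1}, y_{n-1}]$. It is a first-order time discretization, with step $\Delta t$ = `dt`, of the second-order ODE $y'' = \tanh(W_y y + W_z y' + W_x u + b) - \gamma y - \epsilon y'$ with $z = y'$, i.e. a network of damped, driven oscillators. Note that $y$ is advanced with the freshly updated $z_n$:

```math
\begin{aligned}
z_n &= z_{n-1} + \Delta t\,\big(\tanh(W_x x_n + W_z z_{n-1} + W_y y_{n-1} + b) - \gamma\, y_{n-1} - \epsilon\, z_{n-1}\big),\\
y_n &= y_{n-1} + \Delta t\, z_n .
\end{aligned}
```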

The coRNN cell can be implemented in PyTorch in just a few lines:

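from torch import nn
import torch

class coRNNCell(nn.Module):
    def __init__(self, n_inp, n_hid, dt, gamma=1., epsilon=1.):
        super(coRNNCell, self).__init__()
        self.dt = dt              # time step of the discretization
        self.gamma = gamma        # weight of the restoring force -gamma * hy
        self.epsilon = epsilon    # damping coefficient on the velocity hz
        # One affine map applied to the concatenation [x, hz, hy]
        self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)

    def forward(self, x, hy, hz):
        # x: (batch, n_inp); hy (position) and hz (velocity): (batch, n_hid)
        # Velocity update: driving term minus restoring force and damping
        hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
                             - self.gamma * hy - self.epsilon * hz)
        # Position update uses the freshly updated velocity
        hy = hy + self.dt * hz

        return hy, hz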
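
To see the recurrence without PyTorch, here is a minimal NumPy sketch of the same update, unrolled over a whole sequence from zero initial states. The helper names `cornn_step` and `run_cornn`, the weight matrix `W` of shape `(n_hid, n_inp + 2 * n_hid)` standing in for `i2h`, and the hyperparameter values are illustrative assumptions, not code or settings from this repository.

```python
import numpy as np


def cornn_step(x, hy, hz, W, b, dt, gamma=1.0, epsilon=1.0):
    """One coRNN update mirroring coRNNCell.forward (hypothetical NumPy helper).

    x: (batch, n_inp); hy, hz: (batch, n_hid)
    W: (n_hid, n_inp + 2 * n_hid) and b: (n_hid,) stand in for the i2h layer.
    """
    # Same input ordering as torch.cat((x, hz, hy), 1) in the PyTorch cell.
    pre = np.concatenate((x, hz, hy), axis=1) @ W.T + b
    hz = hz + dt * (np.tanh(pre) - gamma * hy - epsilon * hz)
    hy = hy + dt * hz  # uses the freshly updated hz
    return hy, hz


def run_cornn(inputs, W, b, dt, gamma=1.0, epsilon=1.0):
    """Unroll the cell over inputs of shape (seq_len, batch, n_inp).

    Starts from zero hidden states and returns the final (hy, hz).
    """
    seq_len, batch, _ = inputs.shape
    n_hid = b.shape[0]
    hy = np.zeros((batch, n_hid))
    hz = np.zeros((batch, n_hid))
    for t in range(seq_len):
        hy, hz = cornn_step(inputs[t], hy, hz, W, b, dt, gamma, epsilon)
    return hy, hz


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    seq_len, batch, n_inp, n_hid = 784, 4, 1, 32  # e.g. one pixel per time step
    # Illustrative values only, not the tuned hyperparameters from the paper.
    dt, gamma, epsilon = 0.05, 1.0, 1.0
    fan_in = n_inp + 2 * n_hid
    W = rng.normal(scale=1.0 / np.sqrt(fan_in), size=(n_hid, fan_in))
    b = np.zeros(n_hid)
    inputs = rng.random((seq_len, batch, n_inp))
    hy, hz = run_cornn(inputs, W, b, dt, gamma, epsilon)
    print(hy.shape)  # (4, 32)
```

In a real model, the final `hy` would typically be fed to a linear readout layer; that layer is not part of the cell shown above.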

Datasets

This repository contains the code to reproduce the results of the following experiments for the proposed coRNN:

  • The Adding Problem
  • Sequential MNIST
  • Permuted Sequential MNIST
  • Noise padded CIFAR-10
  • HAR-2
  • IMDB
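
The adding problem is synthetic, so it needs no download. The sketch below generates one batch of a common variant of the task; it is an assumption that this matches the sampling used in this repository's scripts. Each input step has two channels: a uniform random value in [0, 1) and a marker that equals 1 at exactly two positions, one in each half of the sequence. The target is the sum of the two marked values. The helper name `adding_problem_batch` is hypothetical.

```python
import numpy as np


def adding_problem_batch(batch_size, seq_len, rng):
    """Generate one batch of the adding problem (hypothetical helper, common variant).

    Returns inputs of shape (seq_len, batch_size, 2) and targets of shape (batch_size, 1).
    """
    values = rng.random((seq_len, batch_size))    # channel 0: uniform in [0, 1)
    markers = np.zeros((seq_len, batch_size))     # channel 1: two ones per sequence
    half = seq_len // 2
    first = rng.integers(0, half, size=batch_size)          # marker in the first half
    second = rng.integers(half, seq_len, size=batch_size)   # marker in the second half
    cols = np.arange(batch_size)
    markers[first, cols] = 1.0
    markers[second, cols] = 1.0
    inputs = np.stack((values, markers), axis=2)
    targets = (values[first, cols] + values[second, cols])[:, None]
    return inputs, targets


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x, y = adding_problem_batch(batch_size=8, seq_len=500, rng=rng)
    print(x.shape, y.shape)  # (500, 8, 2) (8, 1)
```

Placing one marker in each half forces the network to carry information across at least part of the sequence, which is what makes long sequence lengths hard.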

The data sets for the MNIST/CIFAR-10 tasks and the IMDB task are downloaded automatically through torchvision and torchtext, respectively. The HAR-2 data set has to be downloaded and preprocessed manually, following the instructions in the paper.
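
In sequential MNIST, each 28x28 image is presented to the RNN one pixel per time step (784 steps); permuted sequential MNIST applies one fixed random permutation to the pixel order of every image. The sketch below shows this reshaping on a random stand-in batch rather than on data loaded through torchvision; the helper name `to_pixel_sequence` is hypothetical, and the repository's scripts may differ in pixel order and normalization.

```python
import numpy as np


def to_pixel_sequence(images, perm=None):
    """Turn images of shape (batch, 28, 28) into sequences of shape (784, batch, 1).

    If perm (a permutation of range(784)) is given, the pixel order is shuffled with it,
    as in permuted sequential MNIST. The same perm must be reused for every batch.
    """
    batch = images.shape[0]
    seq = images.reshape(batch, -1)   # (batch, 784), row-by-row scan
    if perm is not None:
        seq = seq[:, perm]
    return seq.T[:, :, None]          # (784, batch, 1): one pixel per time step


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    perm = rng.permutation(28 * 28)      # drawn once, then kept fixed
    images = rng.random((2, 28, 28))     # stand-in for a batch of MNIST images
    s = to_pixel_sequence(images)
    p = to_pixel_sequence(images, perm)
    print(s.shape, p.shape)  # (784, 2, 1) (784, 2, 1)
```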

Results

The coRNN achieves the following test accuracies:

| Experiment | Result |
|---|---|
| sMNIST | 99.4% test accuracy |
| psMNIST | 97.3% test accuracy |
| Noise padded CIFAR-10 | 59.0% test accuracy |
| HAR-2 | 97.2% test accuracy |
| IMDB | 87.4% test accuracy |

Citation

If you find this work useful, please consider citing:

@inproceedings{rusch2021coupled,
  title={Coupled Oscillatory Recurrent Neural Network (coRNN): An accurate and (gradient) stable architecture for learning long time dependencies},
  author={Rusch, T. Konstantin and Mishra, Siddhartha},
  booktitle={International Conference on Learning Representations},
  year={2021}
}
