
tensorflow-keras-multidimensional-rnn

Multidimensional Recurrent Neural Networks

Note: This code is offered without any warranty and was developed as a way to learn TensorFlow 2.0. Any contributions are welcome.

TL;DR: What is currently implemented/working?

  • two-dimensional loop with 3 states (left, top, diagonal)
  • Spatial-GRU or 2D-GRU (kinda works, needs more testing; article: arXiv:1604.04378)


Introduction

This repository aims to offer a multidimensional recurrent function, implemented in TensorFlow 2.0 with the Keras API, that can be used with multiple recurrent cells (RNN/GRU/LSTM).

To the best of my knowledge, this is the first publicly available repository that tries to implement this type of function in TensorFlow 2. Furthermore, I was able to find only one repository that tries to implement a multidimensional LSTM in TensorFlow 1.7.

It is also worth mentioning the RetuRNN framework, which offers a (GPU-only) multidimensional LSTM.

In terms of literature, these are some works that used or proposed this type of recurrence for "text" and "image" tasks:

How it works

In theory

A multidimensional RNN is similar to a one-dimensional recurrent neural network, but instead of using only one state (the output of the previous step), it uses multiple states, normally one per dimension.

The following image shows an example applied to two-dimensional data, where each entry has access to three previous states (left in blue, top in green and diagonal in red). However, some works use only the left and top states.

[Figure: basic MDRNN on two-dimensional data, with left, top and diagonal states]

Each state in an MDRNN is therefore computed by a recursive function of its previous states and its input.

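$$h_{i,j} = f\left(x_{i,j},\; h_{i,j-1},\; h_{i-1,j},\; h_{i-1,j-1}\right)$$

where $x_{i,j}$ is the input at row $i$ and column $j$, and $h_{i,j-1}$, $h_{i-1,j}$ and $h_{i-1,j-1}$ are the left, top and diagonal states.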
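As one concrete choice for this recursive function, below is a minimal pure-Python sketch of a single Spatial-GRU (2D-GRU) step, loosely following the gating described in arXiv:1604.04378: one reset gate per neighbouring state, and four update gates (input, left, top, diagonal) normalised with a softmax so the new state is a convex combination of the candidate and the three neighbours. It is not the tfmd implementation: it assumes a scalar input and a scalar hidden state, and the weights returned by the hypothetical `make_params` are fixed illustrative values, not trained ones.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def softmax(logits):
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(w, q):
    return sum(wi * qi for wi, qi in zip(w, q))


def make_params(scale=0.1):
    # Hypothetical fixed weights for illustration only (nothing is trained here).
    return {
        "W_r": [[scale] * 4 for _ in range(3)],            # reset gates: left, top, diag
        "b_r": [0.0] * 3,
        "W_z": [[scale * (k + 1)] * 4 for k in range(4)],  # update gates: input, left, top, diag
        "b_z": [0.0] * 4,
        "w_x": 0.5,                                        # input weight of the candidate state
        "U": [0.3, 0.3, 0.3],                              # weights of the reset-gated neighbours
        "b_h": 0.0,
    }


def spatial_gru_step(x, h_left, h_top, h_diag, p):
    """One Spatial-GRU update with scalar input x and a scalar hidden state."""
    q = [h_left, h_top, h_diag, x]
    # One sigmoid reset gate per neighbouring state.
    r = [sigmoid(dot(w, q) + b) for w, b in zip(p["W_r"], p["b_r"])]
    gated = [r[0] * h_left, r[1] * h_top, r[2] * h_diag]
    # Candidate state built from the input and the reset-gated neighbours.
    h_cand = math.tanh(p["w_x"] * x + dot(p["U"], gated) + p["b_h"])
    # Four update gates, normalised by a softmax so they sum to 1.
    z_i, z_l, z_t, z_d = softmax([dot(w, q) + b for w, b in zip(p["W_z"], p["b_z"])])
    # New state: convex combination of the candidate and the three neighbours.
    return z_l * h_left + z_t * h_top + z_d * h_diag + z_i * h_cand


print(spatial_gru_step(1.0, 0.0, 0.0, 0.0, make_params()))
```

Because the update gates sum to 1, the new state always stays within the range spanned by the candidate and the neighbouring states, which keeps the recursion bounded as it moves across the grid.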

In practice

The current implementation follows a naive approach that iterates sequentially over every 2D entry (first over the column dimension and then over the row dimension), feeding in the previously computed states (left, top, diagonal).

[Figure: basic MDRNN, 2D grid processed as a 1D sequence]
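A minimal pure-Python sketch of this naive scan is shown below. It is not the tfmd code: it assumes a scalar state per entry, uses a hypothetical `toy_cell` with fixed weights in place of a real cell, and treats states outside the grid (above the first row or left of the first column) as zero. The `use_diagonal` flag drops the diagonal state, as in the works that only use the left and top states.

```python
import math


def toy_cell(x, h_left, h_top, h_diag):
    # Hypothetical stand-in for a real cell (e.g. a 2D-GRU): fixed weights, scalar state.
    return math.tanh(0.5 * x + 0.3 * h_left + 0.3 * h_top + 0.2 * h_diag)


def mdrnn_2d_raster(grid, cell=toy_cell, use_diagonal=True):
    """Naive 2D scan: outer loop over rows, inner loop over columns.

    States outside the grid count as 0. Returns the full grid of hidden states;
    the last one, h[-1][-1], depends on every input entry.
    """
    rows, cols = len(grid), len(grid[0])
    h = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            left = h[i][j - 1] if j > 0 else 0.0
            top = h[i - 1][j] if i > 0 else 0.0
            diag = h[i - 1][j - 1] if (i > 0 and j > 0 and use_diagonal) else 0.0
            h[i][j] = cell(grid[i][j], left, top, diag)
    return h


grid = [[(5 * i + j) / 25 for j in range(5)] for i in range(5)]
states = mdrnn_2d_raster(grid)
print(states[-1][-1])
```

Every entry waits for the one before it, so an R x C grid needs R*C sequential cell calls.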

A GPU/CPU optimization could be achieved by computing the opposed diagonals (anti-diagonals, shown as black lines in the following image) in parallel. Entry (i, j) only depends on entries (i, j-1), (i-1, j) and (i-1, j-1), which all lie on earlier anti-diagonals, so the entries on the same anti-diagonal are independent of one another. However, there is still a sequential dependency between the black lines, which must be processed in order.

[Figure: GPU-friendly MDRNN schedule, with the anti-diagonals drawn as black lines]
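The anti-diagonal schedule can be sketched in pure Python as below, again with a hypothetical scalar `toy_cell` rather than the tfmd implementation. `antidiagonals` groups positions by d = i + j; the outer loop over d stays sequential, while the states inside one diagonal are computed without reading each other, which is the part a batched GPU/CPU kernel could parallelise. The raster scan is included only to check that both orders produce the same states.

```python
import math


def toy_cell(x, h_left, h_top, h_diag):
    # Hypothetical stand-in for a real cell: fixed weights, scalar state.
    return math.tanh(0.5 * x + 0.3 * h_left + 0.3 * h_top + 0.2 * h_diag)


def neighbours(h, i, j):
    # Left, top and diagonal states; positions outside the grid count as 0.
    left = h[i][j - 1] if j > 0 else 0.0
    top = h[i - 1][j] if i > 0 else 0.0
    diag = h[i - 1][j - 1] if i > 0 and j > 0 else 0.0
    return left, top, diag


def antidiagonals(rows, cols):
    # Group positions (i, j) by d = i + j, in increasing order of d.
    return [
        [(i, d - i) for i in range(max(0, d - cols + 1), min(rows, d + 1))]
        for d in range(rows + cols - 1)
    ]


def scan_raster(grid, cell=toy_cell):
    # Reference: the naive row-by-row, column-by-column loop.
    rows, cols = len(grid), len(grid[0])
    h = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            h[i][j] = cell(grid[i][j], *neighbours(h, i, j))
    return h


def scan_antidiagonal(grid, cell=toy_cell):
    rows, cols = len(grid), len(grid[0])
    h = [[0.0] * cols for _ in range(rows)]
    for diagonal in antidiagonals(rows, cols):  # diagonals must run in order
        # Entries on one anti-diagonal only read states from earlier diagonals,
        # so this list could be computed in parallel (e.g. as one batched op).
        new_states = [cell(grid[i][j], *neighbours(h, i, j)) for i, j in diagonal]
        for (i, j), s in zip(diagonal, new_states):
            h[i][j] = s
    return h


grid = [[(3 * i + j) / 10 for j in range(4)] for i in range(3)]
print(scan_antidiagonal(grid) == scan_raster(grid))  # True
```

For an R x C grid this reduces the number of sequential steps from R*C to R + C - 1, at the cost of gathering a variable-length batch of positions per step.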

Installation

(Work in progress)

  1. For now, clone the repository.
  2. (Optional) Install it as a Python package:
       1. python setup.py sdist
       2. pip install dist/tfmd-0.0.1.tar.gz

Usage

```python
# normal tensorflow keras imports
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

# multidimensional rnn imports
from tfmd.mdrnn import MultiDimensionalRNN
from tfmd.mdcells import MultiDimensinalGRUCell

gru_units = 4

model = Sequential()
model.add(MultiDimensionalRNN(MultiDimensinalGRUCell(gru_units, activation='tanh'), input_shape=(5,5,1)))
model.add(Dense(1))

model.summary()

# normal keras model :D
```

Tests

Currently missing

Future improvements

Contributions are welcome!!

  • More gates (LSTM)
  • CPU/GPU speed-up by processing the opposed diagonals (anti-diagonals) in parallel
  • Multidirectional recurrence
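For the multidirectional item, one common approach (used in the original MDRNN formulation, where a 2D network has four hidden layers, one per corner) is to run the scan from each of the four corners. The sketch below does this by flipping the input, scanning, and flipping the states back. It is a hypothetical pure-Python illustration with a scalar `toy_cell` standing in for a real cell, not a planned tfmd API; in practice each direction usually has its own weights, which the optional `cells` mapping allows.

```python
import math


def toy_cell(x, h_left, h_top, h_diag):
    # Hypothetical stand-in for a real cell: fixed weights, scalar state.
    return math.tanh(0.5 * x + 0.3 * h_left + 0.3 * h_top + 0.2 * h_diag)


def scan_2d(grid, cell=toy_cell):
    # Single-direction scan starting at the top-left corner (row by row).
    rows, cols = len(grid), len(grid[0])
    h = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            left = h[i][j - 1] if j > 0 else 0.0
            top = h[i - 1][j] if i > 0 else 0.0
            diag = h[i - 1][j - 1] if i > 0 and j > 0 else 0.0
            h[i][j] = cell(grid[i][j], left, top, diag)
    return h


def flip(grid, flip_rows, flip_cols):
    # Mirror the grid vertically and/or horizontally (flip is its own inverse).
    rows = grid[::-1] if flip_rows else grid
    return [row[::-1] if flip_cols else list(row) for row in rows]


def multidirectional_2d(grid, cells=None):
    """Scan from all four corners; each result is flipped back to the original layout.

    `cells` optionally maps (flip_rows, flip_cols) to a cell; missing keys use toy_cell.
    """
    outputs = {}
    for flip_rows in (False, True):
        for flip_cols in (False, True):
            key = (flip_rows, flip_cols)
            cell = (cells or {}).get(key, toy_cell)
            states = scan_2d(flip(grid, flip_rows, flip_cols), cell)
            outputs[key] = flip(states, flip_rows, flip_cols)
    return outputs


outs = multidirectional_2d([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
# The scan that starts at the bottom-right corner has seen the whole grid at (0, 0).
print(outs[(True, True)][0][0])
```

With four directions, every position receives context from the entire grid, and the four state grids can then be concatenated or summed per position before the next layer.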
