dni-tensorflow's Introduction

Disclaimer: the code currently contains a small bug, which I'll fix as soon as possible.

Image classification with synthetic gradients in TensorFlow

I implement Decoupled Neural Interfaces using Synthetic Gradients in TensorFlow. The paper uses synthetic gradients to decouple the layers of a network. This is pretty interesting, since the layers no longer suffer from update locking. I tested my model on CIFAR-10 and achieved results similar to those reported in the paper.

Requirement

TODO

  • use multi-threading on the GPU to analyze the speed
  • apply it to more complicated networks to see whether it generalizes

What are synthetic gradients?

We often optimize neural networks by backpropagation, which is usually implemented in well-known frameworks. However, is there another way for the layers in a network to communicate with each other? Here come synthetic gradients! They give neural networks a way to communicate, to learn to send messages between themselves, in a decoupled, scalable manner, paving the way for multiple neural networks to communicate with each other and for improving the long-term temporal dependency of recurrent networks.
Each layer receives an error signal (δa_head) produced by a synthetic layer and uses it for optimization. And how is that error signal learned? The network still performs backpropagation. However, the error signal (δa) from the objective function is not used to optimize the neurons in the network; it is used to optimize the error signal (δa_head) produced by the synthetic layer. The paper includes an illustration of this scheme.
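
The update rule described above can be sketched on a toy problem. This is a minimal, framework-free sketch and not the code in this repository: it assumes a two-layer scalar network (h = w1 * x, y = w2 * h), a squared-error loss, and a linear synthetic module g_hat = a * h + b. All names are hypothetical and chosen only for illustration.

```python
def dni_step(w1, w2, a, b, x, t, lr=0.1):
    """One training step of a toy scalar network with a synthetic gradient.

    Hypothetical setup (not the repository's model):
      layer 1: h = w1 * x
      layer 2: y = w2 * h, loss = 0.5 * (y - t) ** 2
      synthetic module: g_hat = a * h + b, a prediction of dL/dh
    """
    # Layer 1 forward pass.
    h = w1 * x

    # The synthetic module predicts the error signal for h (delta_a_head).
    g_hat = a * h + b

    # Layer 1 is updated from the prediction alone, so it does not have
    # to wait for the loss to be computed (no update locking).
    new_w1 = w1 - lr * g_hat * x

    # Layer 2 forward pass and loss.
    y = w2 * h
    loss = 0.5 * (y - t) ** 2

    # Ordinary backpropagation through layer 2 gives the true signal (delta_a).
    dl_dy = y - t
    g_true = dl_dy * w2
    new_w2 = w2 - lr * dl_dy * h

    # The true signal only trains the synthetic module: regress g_hat on g_true.
    err = g_hat - g_true
    sg_loss = 0.5 * err ** 2
    new_a = a - lr * err * h
    new_b = b - lr * err

    return new_w1, new_w2, new_a, new_b, loss, sg_loss


def run(steps=200):
    """Train on a single example (x=1, t=1) and return the loss history."""
    w1, w2, a, b = 0.5, 0.5, 0.0, 0.0
    history = []
    for _ in range(steps):
        w1, w2, a, b, loss, sg_loss = dni_step(w1, w2, a, b, x=1.0, t=1.0)
        history.append((loss, sg_loss))
    return history


if __name__ == "__main__":
    losses = run()
    print("first step (cls loss, synthetic_grad loss):", losses[0])
    print("last step  (cls loss, synthetic_grad loss):", losses[-1])
```

Note that w1 never sees the true gradient directly: it only improves as fast as the synthetic module learns to imitate it. How the two losses evolve depends on the learning rate and the initial values.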

Usage

Right now I have only implemented the FCN (fully connected) version, which is set as the default network structure.
You can set some variables on the command line, e.g.: python main.py --max_step 100000 --checkpoint_dir ./model

max_step = 50000
model_name = mlp                  # checkpoints are saved in $checkpoint_dir/$model_name/checkpoint-*
checkpoint_dir = './checkpoint'   # the checkpoint directory
gpu_fraction = 1/2                # fraction of GPU memory to use
batch_size = 256
hidden_size = 1000                # hidden size of the mlp
test_per_iter = 50
optim_type = adam
synthetic = False                 # use synthetic gradients or not
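
For readers who want the same options in their own script, here is a hedged sketch using Python's standard argparse module. It assumes the flags map one-to-one onto the variables listed above; the repository's main.py may define them differently, and the helper names str2bool and build_parser are hypothetical.

```python
import argparse


def str2bool(text):
    """Parse 'True'/'False'-style strings (bool('False') would be True)."""
    value = text.strip().lower()
    if value in ("true", "1", "yes"):
        return True
    if value in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % text)


def build_parser():
    """Build a parser whose defaults mirror the values listed above."""
    parser = argparse.ArgumentParser(description="DNI training options (illustrative)")
    parser.add_argument("--max_step", type=int, default=50000)
    parser.add_argument("--model_name", type=str, default="mlp")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoint")
    # The default of 1/2 is written as 0.5 here (an assumption about its meaning).
    parser.add_argument("--gpu_fraction", type=float, default=0.5)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--hidden_size", type=int, default=1000)
    parser.add_argument("--test_per_iter", type=int, default=50)
    parser.add_argument("--optim_type", type=str, default="adam")
    parser.add_argument("--synthetic", type=str2bool, default=False)
    return parser


# Same arguments as the command-line example above, passed as a list.
args = build_parser().parse_args(["--max_step", "100000", "--checkpoint_dir", "./model"])
print(args.max_step, args.checkpoint_dir, args.synthetic)
```

Passing the argument list explicitly keeps the sketch runnable anywhere; a real script would call parse_args() with no arguments to read the actual command line.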

Experiment Result

DNI-mlp tested on CIFAR-10

Plots: cls loss, synthetic_grad loss, test acc

DNI-cnn tested on CIFAR-10

Plots: cls loss, synthetic_grad loss, test acc

Open problem: the synthetic gradient loss increases in the CNN model.

Something Beautiful in TensorFlow

TensorFlow is known for the convenience of automatic differentiation, yet many people don't know how it performs or calculates the backprop. Compared to Torch, there's no obvious way to access gradOutput and gradInput. Actually, TensorFlow contains some beautiful functions that make this easier and more flexible.
Sometimes you might want to calculate the gradient dy/dx:
Use tf.gradients(y, x). It's very simple.
If you want to calculate the gradient given the gradient backpropagated from the loss, or something you've defined yourself (dL/dx = dL/dy * dy/dx, given dL/dy):
Use tf.gradients(y, x, grad_ys), where grad_ys holds dL/dy.
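
To make the role of the third argument concrete, the sketch below reproduces the arithmetic in plain Python for a linear map y = W x. It does not call TensorFlow, and the function name vjp and the example values are assumptions for illustration. Without an upstream gradient, tf.gradients sums the contributions of all outputs, which corresponds to weighting every output by 1.

```python
def vjp(W, grad_ys=None):
    """Gradient w.r.t. x of y = W x, weighted by an upstream gradient.

    W is a list of rows (len(y) rows, len(x) columns), so dy_i/dx_j = W[i][j].
    grad_ys plays the role of dL/dy; None means a weight of 1.0 for every
    output, mirroring the default behaviour of tf.gradients(y, x).
    """
    n_out, n_in = len(W), len(W[0])
    if grad_ys is None:
        grad_ys = [1.0] * n_out
    # Chain rule: dL/dx_j = sum_i dL/dy_i * dy_i/dx_j
    return [sum(grad_ys[i] * W[i][j] for i in range(n_out)) for j in range(n_in)]


W = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

# Counterpart of tf.gradients(y, x): the outputs are summed.
print(vjp(W))                # [5.0, 7.0, 9.0]

# Counterpart of tf.gradients(y, x, grad_ys) with dL/dy = [0.5, -1.0].
print(vjp(W, [0.5, -1.0]))   # [-3.5, -4.0, -4.5]
```

This is the hook a synthetic gradient needs: the predicted error signal for a layer's output can be supplied as the upstream gradient to obtain gradients for that layer's parameters, without backpropagating from the loss.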

Reference

  • DeepMind's post on Decoupled Neural Interfaces Using Synthetic Gradients

dni-tensorflow's People

Contributors

andrewliao11
