
tf-layer-norm's Introduction

TensorFlow Layer Normalization and Hyper Networks
=================================

TensorFlow implementation of Layer Normalization and Hyper Networks.

This implementation contains:

  1. Layer Normalization for GRU

  2. Layer Normalization for LSTM

    • Normalizing the cell state c currently produces many NaNs in the model, so it is commented out for now.
  3. Hyper Networks for LSTM

  4. Layer Normalization and Hyper Networks (combined) for LSTM

[Image: model_demo]

Prerequisites

MNIST

To evaluate the new model, we train it on MNIST. Here are the model and results using the Layer Normalized GRU:

[Image: histogram]

[Image: scalar]

Usage

To train an MNIST model with different cell types:

$ python mnist.py --hidden 128 --summaries_dir log/ --cell_type LNGRU

To train an MNIST model with HyperNetworks:

$ python mnist.py --hidden 128 --summaries_dir log/ --cell_type HyperLnLSTMCell --layer_norm 0

To train an MNIST model with HyperNetworks and Layer Normalization:

$ python mnist.py --hidden 128 --summaries_dir log/ --cell_type HyperLnLSTMCell --layer_norm 1

cell_type = [LNGRU, LNLSTM, LSTM, GRU, BasicRNN, HyperLnLSTMCell]

To view the graph in TensorBoard:

$ tensorboard --logdir log/train/

Todo

  1. Add attention-based models (in progress).

tf-layer-norm's People

Contributors

pbhatia243


tf-layer-norm's Issues

regarding the model parameters specified by Nin, Nout and fsize

Hi Parminder,

I have a question regarding one statement made in the HyperNetworks paper. In Section 3.1, the authors state:

In this section we will describe how we construct a hypernetwork for the purpose of generating the
weights of a feedforward convolutional network (similar to Figure 2). In a typical deep convolutional
network, the majority of model parameters resides in the kernels within the convolutional layers.
Each kernel contain Nin × Nout filters and each filter has dimensions fsize × fsize.

I am not clear on the definitions of Nin, Nout, and fsize. Normally, each CNN layer has M filters, each with a receptive field of size X, giving M*X parameters for that layer. How should I map this onto the Nin, Nout, and fsize mentioned in the paper?

Thanks,
wenouyang
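
One common reading of the quoted passage, offered here as an assumption rather than as the paper's authoritative definition, is that Nin is the number of input channels, Nout is the number of output channels, and fsize is the spatial width of each filter. Under that reading, M corresponds to Nout and X corresponds to fsize × fsize × Nin. The sketch below checks that both views give the same parameter count; the function name and the layer sizes are hypothetical.

```python
def conv_kernel_param_count(n_in, n_out, f_size):
    """Number of weights in one conv layer's kernel, excluding biases."""
    return n_in * n_out * f_size * f_size

# Hypothetical layer: 16 input channels, 32 output channels, 3x3 filters.
n_in, n_out, f_size = 16, 32, 3
total = conv_kernel_param_count(n_in, n_out, f_size)

# "M filters with receptive field X" view of the same layer.
m = n_out                      # one filter bank per output channel
x = f_size * f_size * n_in     # each one spans every input channel

print(total, m * x)  # 4608 4608
```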

Different gain and bias at each time step

Maybe I'm misunderstanding something, but in the paper, the same gain and bias are used for all time steps; otherwise this creates a lot of parameters for long sequences.
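
To make the point concrete, here is a minimal sketch of layer normalization in which one gain vector and one bias vector are created once and reused at every time step. It is written in plain Python for illustration and is not taken from this repository; the names `layer_norm`, `gain`, and `bias`, and the sample values, are assumptions.

```python
import math

def layer_norm(values, gain, bias, eps=1e-5):
    """Normalize one hidden vector, then apply per-unit gain and bias."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return [g * (v - mean) / math.sqrt(var + eps) + b
            for v, g, b in zip(values, gain, bias)]

hidden_size = 4
gain = [1.0] * hidden_size   # created once, shared by all time steps
bias = [0.0] * hidden_size

# Hypothetical hidden vectors, one per time step.
sequence = [
    [0.5, -1.0, 2.0, 0.0],
    [3.0, 1.0, -2.0, 4.0],
    [1.0, 1.5, 2.0, 2.5],
]
normalized = [layer_norm(h_t, gain, bias) for h_t in sequence]

# The parameter count depends only on the hidden size, not on len(sequence).
num_params = len(gain) + len(bias)
```

With shared parameters, a longer sequence adds more calls to `layer_norm` but no new variables.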

There's a problem while saving the session

Thanks for your work!

File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1075, in save
{self.saver_def.filename_tensor_name: checkpoint_file})
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 710, in run
run_metadata_ptr)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 908, in _run
feed_dict_string, options, run_metadata)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 958, in _do_run
target_list, options, run_metadata)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 978, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.UnimplementedError: File system scheme models-bak/Sig_RNN.py/model-inv3A128_200_2048_100_5_0.01_0.8_200_2_1000_1_9_0.01_0.5_1.0_0.1_0.5_r_squre_0.3_False/05-02 12 not implemented
[[Node: save/save = SaveSlices[T=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/save/tensor_names, save/save/shapes_and_slices, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/BasicLSTMCell/Linear/Bias/Adam/_119, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/BasicLSTMCell/Linear/Bias/Adam_1/_121, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/BasicLSTMCell/Linear/Matrix/Adam/_123, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/BasicLSTMCell/Linear/Matrix/Adam_1/_125, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b1/Adam/_127, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b1/Adam_1/_129, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b2/Adam/_131, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b2/Adam_1/_133, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_h/hyper_halpha/Matrix/Adam/_135, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_h/hyper_halpha/Matrix/Adam_1/_137, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_h/hyper_hz/Matrix/Adam/_139, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_h/hyper_hz/Matrix/Adam_1/_141, 
OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_x/hyper_xalpha/Matrix/Adam/_143, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_x/hyper_xalpha/Matrix/Adam_1/_145, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_x/hyper_xz/Matrix/Adam/_147, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_x/hyper_xz/Matrix/Adam_1/_149, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/out_1/Matrix/Adam/_151, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/out_1/Matrix/Adam_1/_153, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/out_2/Matrix/Adam/_155, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/out_2/Matrix/Adam_1/_157, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s1/Adam/_159, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s1/Adam_1/_161, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s2/Adam/_163, OptimizeLoss/RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s2/Adam_1/_165, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/BasicLSTMCell/Linear/Bias/Adam/_167, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/BasicLSTMCell/Linear/Bias/Adam_1/_169, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/BasicLSTMCell/Linear/Matrix/Adam/_171, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/BasicLSTMCell/Linear/Matrix/Adam_1/_173, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b1/Adam/_175, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b1/Adam_1/_177, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b2/Adam/_179, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b2/Adam_1/_181, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_h/hyper_halpha/Matrix/Adam/_183, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_h/hyper_halpha/Matrix/Adam_1/_185, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_h/hyper_hz/Matrix/Adam/_187, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_h/hyper_hz/Matrix/Adam_1/_189, 
OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_x/hyper_xalpha/Matrix/Adam/_191, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_x/hyper_xalpha/Matrix/Adam_1/_193, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_x/hyper_xz/Matrix/Adam/_195, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_x/hyper_xz/Matrix/Adam_1/_197, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/out_1/Matrix/Adam/_199, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/out_1/Matrix/Adam_1/_201, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/out_2/Matrix/Adam/_203, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/out_2/Matrix/Adam_1/_205, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s1/Adam/_207, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s1/Adam_1/_209, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s2/Adam/_211, OptimizeLoss/RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s2/Adam_1/_213, OptimizeLoss/Variable/Adam/_215, OptimizeLoss/Variable/Adam_1/_217, OptimizeLoss/Variable_1/Adam/_219, OptimizeLoss/Variable_1/Adam_1/_221, OptimizeLoss/beta1_power/_223, OptimizeLoss/beta2_power/_225, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/BasicLSTMCell/Linear/Bias/_227, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/BasicLSTMCell/Linear/Matrix/_229, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b1/_231, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b2/_233, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/b3/_235, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_h/hyper_halpha/Matrix/_237, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_h/hyper_hz/Matrix/_239, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_x/hyper_xalpha/Matrix/_241, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/hyper_x/hyper_xz/Matrix/_243, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/out_1/Matrix/_245, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/out_2/Matrix/_247, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s1/_249, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s2/_251, RNN/MultiRNNCell/Cell0/HyperLnLSTMCell/s3/_253, 
RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/BasicLSTMCell/Linear/Bias/_255, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/BasicLSTMCell/Linear/Matrix/_257, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b1/_259, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b2/_261, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/b3/_263, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_h/hyper_halpha/Matrix/_265, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_h/hyper_hz/Matrix/_267, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_x/hyper_xalpha/Matrix/_269, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/hyper_x/hyper_xz/Matrix/_271, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/out_1/Matrix/_273, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/out_2/Matrix/_275, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s1/_277, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s2/_279, RNN/MultiRNNCell/Cell1/HyperLnLSTMCell/s3/_281, Variable/_283, Variable_1/_285)]]
Caused by op u'save/save', defined at:
File "/home/zhukangkang/dzy_temp/code/supercell/tf-layer-norm/inv3Ft-arousal_rnn.py", line 67, in <module>
train(data_train,args)
File "/home/zhukangkang/dzy_temp/code/supercell/tf-layer-norm/Sig_RNN.py", line 437, in train
mySaver = tf.train.Saver(max_to_keep=5)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 861, in __init__
restore_sequentially=restore_sequentially)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 517, in build
save_tensor = self._AddSaveOps(filename_tensor, vars_to_save)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 213, in _AddSaveOps
save = self.save_op(filename_tensor, vars_to_save)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 165, in save_op
tensor_slices=[vs.slice_spec for vs in vars_to_save])
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/ops/io_ops.py", line 179, in _save
tensors, name=name)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 438, in _save_slices
data=data, name=name)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 710, in apply_op
op_def=op_def)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2317, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/home/zhukangkang/dzy_temp/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1239, in __init__
self._traceback = _extract_stack()

The model runs well, but I could not use tf.train.Saver. Have you ever encountered this problem?
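
The scheme named in the error message ends at `05-02 12`, which looks like a date and hour cut off where a colon would follow. Assuming the checkpoint path embedded a timestamp such as `05-02 12:30:00` (a guess from the log, not something the report confirms), one possible workaround is to build the file name without colons. The sketch below uses only the Python standard library; `checkpoint_name`, its arguments, and the sample date are hypothetical.

```python
from datetime import datetime

def checkpoint_name(prefix, when):
    """Return a checkpoint file name whose timestamp contains no colons."""
    # Hyphens stand in for the colons that "%H:%M:%S" would produce.
    stamp = when.strftime("%m-%d_%H-%M-%S")
    return "{}_{}".format(prefix, stamp)

name = checkpoint_name("model", datetime(2017, 5, 2, 12, 30, 0))
print(name)  # model_05-02_12-30-00
```

A name built this way could then be used as the save path given to the saver.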

Please update to TensorFlow 1.0

The code as is doesn't work with TensorFlow 1.0. I tried fixing all the errors but got stuck when I needed the _linear() operator, which I can't find.
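
For anyone porting the code, it may help to recall what `_linear()` computed in the older RNN cell code: it concatenated its 2-D inputs along the feature axis and applied a single weight matrix, optionally adding a bias. The sketch below reproduces that arithmetic in plain Python on nested lists. The name `linear` and its explicit `weights` and `bias` arguments are illustrative assumptions (the TensorFlow helper created its variables internally), so this is not a drop-in replacement.

```python
def linear(args, weights, bias=None):
    """Concatenate 2-D inputs along the feature axis, then compute x @ W (+ b)."""
    batch_size = len(args[0])
    out_size = len(weights[0])
    outputs = []
    for i in range(batch_size):
        # Join row i of every input into one long feature vector.
        row = [value for arg in args for value in arg[i]]
        if len(row) != len(weights):
            raise ValueError("weights must have one row per concatenated feature")
        out_row = [sum(x * weights[k][j] for k, x in enumerate(row))
                   for j in range(out_size)]
        if bias is not None:
            out_row = [v + b for v, b in zip(out_row, bias)]
        outputs.append(out_row)
    return outputs

# Hypothetical input x (batch 1, 2 features) and state h (batch 1, 1 feature).
x = [[1.0, 2.0]]
h = [[3.0]]
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # shape (3, 2)
bias = [0.5, -0.5]

print(linear([x, h], weights, bias))  # [[4.5, 4.5]]
```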
