
the-neural-perspective's Issues

A problem with the code

The code is here: https://github.com/GokuMohandas/the-neural-perspective/blob/master/recurrent-neural-networks/text_classification/model.py

The issue is at line 85 of model.py.

I think the index should be 1. With index 0 you get the c state, while index 1 gives the h state (for an LSTM, the state tuple is ordered (c, h)). In an RNN we want the last h state, not the c state.
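For an LSTM built with state_is_tuple=True, TensorFlow 1.x returns the state as an LSTMStateTuple whose fields are ordered (c, h), and the cell's output is h. The pure-Python sketch below mimics that ordering with a simplified single-unit LSTM step and made-up weights. LSTMState, lstm_step and the weight values are illustrative assumptions, not TensorFlow code. It shows that index 1 reproduces the last output while index 0 does not.

```python
import math
from typing import NamedTuple


class LSTMState(NamedTuple):
    """Mirrors the (c, h) field order of tf.nn.rnn_cell.LSTMStateTuple."""
    c: float  # index 0: cell (memory) state
    h: float  # index 1: hidden state, which is also the cell's output


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x: float, state: LSTMState, w: dict) -> tuple:
    """One step of a hypothetical single-unit LSTM with scalar weights `w`."""
    i = sigmoid(w["wi"] * x + w["ui"] * state.h + w["bi"])    # input gate
    f = sigmoid(w["wf"] * x + w["uf"] * state.h + w["bf"])    # forget gate
    o = sigmoid(w["wo"] * x + w["uo"] * state.h + w["bo"])    # output gate
    g = math.tanh(w["wg"] * x + w["ug"] * state.h + w["bg"])  # candidate
    c = f * state.c + i * g
    h = o * math.tanh(c)
    return h, LSTMState(c, h)  # (output, new_state), like an RNN cell call


# Hypothetical weights, chosen only for illustration.
weights = {k: 0.5 for k in ["wi", "ui", "bi", "wf", "uf", "bf",
                            "wo", "uo", "bo", "wg", "ug", "bg"]}

state = LSTMState(c=0.0, h=0.0)
for x in [0.0, 1.0, 2.0]:
    output, state = lstm_step(x, state, weights)

print(state[1] == output)  # True: index 1 is h, the last output
print(state[0] == output)  # False: index 0 is c, the memory cell
```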

Here is some sample code:

import tensorflow as tf
import numpy as np

def rnn_cell(flags, unit_nums, layer_nums):
    # Pick the cell class used for every layer.
    if flags == 'rnn':
        rnn_cell_type = tf.nn.rnn_cell.BasicRNNCell
    elif flags == 'gru':
        rnn_cell_type = tf.nn.rnn_cell.GRUCell
    elif flags == 'lstm':
        rnn_cell_type = tf.nn.rnn_cell.BasicLSTMCell
    else:
        raise ValueError('Choose a valid RNN unit type')

    # Stack layer_nums cells; the resulting state is a tuple with one entry per layer.
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
        [rnn_cell_type(unit_nums) for _ in range(layer_nums)],
        state_is_tuple=True)

    return stacked_cell

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

# basic_cell = tf.nn.rnn_cell.GRUCell(num_units=n_neurons)
basic_cell = rnn_cell('gru', n_neurons, 2)  # two stacked GRU layers
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

test_out = states[1]

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val, test_out_val = sess.run([outputs, states, test_out],
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})
  print('output_val:')
  print(outputs_val)
  print('test_out_val:')
  print(test_out_val)
  print('states output:')
  print(states_val)
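
In this sample, basic_cell is a MultiRNNCell of GRU cells. A GRU's state is just h (there is no separate c), and MultiRNNCell returns its state as a tuple with one entry per layer, so states[0] and states[1] are the final states of the first and second layers. With sequence_length set, dynamic_rnn zeroes outputs past each sequence's length and copies the state through, so the top layer's final state lines up with each sequence's last valid output. The sketch below imitates those semantics in plain Python, assuming hypothetical one-unit tanh layers and made-up weights; stacked_dynamic_rnn and rnn_step are illustrative names, not TensorFlow APIs.

```python
import math
from typing import List, Tuple


def rnn_step(x: float, h: float, w: float, u: float) -> float:
    """Simple tanh RNN cell with one unit (hypothetical scalar weights)."""
    return math.tanh(w * x + u * h)


def stacked_dynamic_rnn(batch: List[List[float]], lengths: List[int],
                        layer_weights: List[Tuple[float, float]]):
    """Sketch of MultiRNNCell + dynamic_rnn(sequence_length=...) semantics.

    Returns (outputs, states): outputs[b][t] is the top layer's output,
    zeroed once t >= lengths[b]; states is a tuple with ONE entry per layer,
    each holding that layer's state at the sequence's last valid step.
    """
    num_layers = len(layer_weights)
    outputs = []
    final = [[0.0] * len(batch) for _ in range(num_layers)]
    for b, seq in enumerate(batch):
        hs = [0.0] * num_layers  # per-layer state for this sequence
        row = []
        for t, x in enumerate(seq):
            if t >= lengths[b]:
                row.append(0.0)  # past the end: zero output, state copied through
                continue
            inp = x
            for layer, (w, u) in enumerate(layer_weights):
                hs[layer] = rnn_step(inp, hs[layer], w, u)
                inp = hs[layer]  # this layer's output feeds the next layer
            row.append(inp)
        outputs.append(row)
        for layer in range(num_layers):
            final[layer][b] = hs[layer]
    return outputs, tuple(final)


batch = [[0.1, 0.9], [0.3, 0.0], [0.6, 0.5]]  # hypothetical scalar inputs
lengths = [2, 1, 2]
outputs, states = stacked_dynamic_rnn(batch, lengths, [(0.8, 0.5), (1.2, -0.3)])

last_valid = [outputs[b][lengths[b] - 1] for b in range(len(batch))]
print(len(states))              # 2: one state per layer
print(states[1] == last_valid)  # True: index 1 is the top (last) layer
```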

The output is as follows:

output_val:
[[[-0.04201178  0.05825585 -0.10475425 -0.0574191  -0.02787135]
  [-0.19379006  0.3098646  -0.45543492 -0.12587619  0.0240692 ]]

 [[-0.17244668  0.23054264 -0.35517666 -0.08038893 -0.03418409]
  [ 0.          0.          0.          0.          0.        ]]

 [[-0.2036431   0.3009616  -0.4299594  -0.07195221 -0.01479668]
  [-0.34811032  0.41498578 -0.609179   -0.13371195 -0.04363498]]

 [[-0.07408727  0.08317428 -0.16693498 -0.05484099  0.03070312]
  [-0.17517757  0.16493565 -0.3365708  -0.11573118  0.06112424]]]
test_out_val:
[[-0.19379006  0.3098646  -0.45543492 -0.12587619  0.0240692 ]
 [-0.17244668  0.23054264 -0.35517666 -0.08038893 -0.03418409]
 [-0.34811032  0.41498578 -0.609179   -0.13371195 -0.04363498]
 [-0.17517757  0.16493565 -0.3365708  -0.11573118  0.06112424]]
states output:
(array([[ 0.03890764, -0.68590057, -0.9278017 ,  0.9249147 , -0.9152604 ],
       [ 0.02587428, -0.61686724, -0.60513765,  0.7774262 , -0.7873352 ],
       [ 0.02819343, -0.7372245 , -0.61572564,  0.9879686 , -0.96763486],
       [-0.0758315 , -0.05249108, -0.32755852,  0.62991834, -0.5825802 ]],
      dtype=float32), array([[-0.19379006,  0.3098646 , -0.45543492, -0.12587619,  0.0240692 ],
       [-0.17244668,  0.23054264, -0.35517666, -0.08038893, -0.03418409],
       [-0.34811032,  0.41498578, -0.609179  , -0.13371195, -0.04363498],
       [-0.17517757,  0.16493565, -0.3365708 , -0.11573118,  0.06112424]],
      dtype=float32))

If we use index 0, we get the first array:

[[ 0.03890764, -0.68590057, -0.9278017 ,  0.9249147 , -0.9152604 ],
 [ 0.02587428, -0.61686724, -0.60513765,  0.7774262 , -0.7873352 ],
 [ 0.02819343, -0.7372245 , -0.61572564,  0.9879686 , -0.96763486],
 [-0.0758315 , -0.05249108, -0.32755852,  0.62991834, -0.5825802 ]]

If we use index 1, we get the second array:

[[-0.19379006  0.3098646  -0.45543492 -0.12587619  0.0240692 ]
 [-0.17244668  0.23054264 -0.35517666 -0.08038893 -0.03418409]
 [-0.34811032  0.41498578 -0.609179   -0.13371195 -0.04363498]
 [-0.17517757  0.16493565 -0.3365708  -0.11573118  0.06112424]]

I think the second array is what we want: each of its rows matches the last valid output of the corresponding sequence in output_val.
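The printed values can be checked directly. The hypothetical helper last_relevant below (plain Python, not a TensorFlow API) picks each sequence's output at step seq_length - 1. With the numbers copied from output_val, seq_length_batch and the second array above, it reproduces that second array exactly, while naively taking the last time step gives zeros for instance 1, whose sequence length is 1.

```python
# Values copied from the printed output_val and seq_length_batch above.
outputs_val = [
    [[-0.04201178, 0.05825585, -0.10475425, -0.0574191, -0.02787135],
     [-0.19379006, 0.3098646, -0.45543492, -0.12587619, 0.0240692]],
    [[-0.17244668, 0.23054264, -0.35517666, -0.08038893, -0.03418409],
     [0.0, 0.0, 0.0, 0.0, 0.0]],
    [[-0.2036431, 0.3009616, -0.4299594, -0.07195221, -0.01479668],
     [-0.34811032, 0.41498578, -0.609179, -0.13371195, -0.04363498]],
    [[-0.07408727, 0.08317428, -0.16693498, -0.05484099, 0.03070312],
     [-0.17517757, 0.16493565, -0.3365708, -0.11573118, 0.06112424]],
]
seq_length_batch = [2, 1, 2, 2]

# Second array of states_val (index 1), also copied from above.
states_index_1 = [
    [-0.19379006, 0.3098646, -0.45543492, -0.12587619, 0.0240692],
    [-0.17244668, 0.23054264, -0.35517666, -0.08038893, -0.03418409],
    [-0.34811032, 0.41498578, -0.609179, -0.13371195, -0.04363498],
    [-0.17517757, 0.16493565, -0.3365708, -0.11573118, 0.06112424],
]


def last_relevant(outputs, lengths):
    """Pick each sequence's output at its last valid step (t = length - 1)."""
    return [seq[n - 1] for seq, n in zip(outputs, lengths)]


print(last_relevant(outputs_val, seq_length_batch) == states_index_1)  # True

# Naively taking the final time step ignores sequence_length.
naive = [seq[-1] for seq in outputs_val]
print(naive[1])  # [0.0, 0.0, 0.0, 0.0, 0.0]: instance 1 was padded at t = 1
```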

Looking forward to your reply.
