The code in question is here: https://github.com/GokuMohandas/the-neural-perspective/blob/master/recurrent-neural-networks/text_classification/model.py
It is in model.py, line 85.
I think the index should be 1. For an LSTM state, index 0 returns the cell state c, whereas index 1 returns the hidden state h. In an RNN we want the last hidden state h, not the cell state c.
Here is some sample code:
import tensorflow as tf
import numpy as np
def rnn_cell(flags, unit_nums, layer_nums):
    """Build a stack of ``layer_nums`` RNN cells of the requested type.

    Args:
        flags: cell type, one of ``'rnn'``, ``'gru'`` or ``'lstm'``.
        unit_nums: number of units in each layer's cell.
        layer_nums: number of cells to stack in the ``MultiRNNCell``.

    Returns:
        A ``tf.nn.rnn_cell.MultiRNNCell`` built with ``state_is_tuple=True``.
        Its state is a tuple with one entry per layer, so ``state[-1]`` is
        the top layer's final state.

    Raises:
        ValueError: if ``flags`` is not one of the supported cell types.
    """
    if flags == 'rnn':
        rnn_cell_type = tf.nn.rnn_cell.BasicRNNCell
    elif flags == 'gru':
        rnn_cell_type = tf.nn.rnn_cell.GRUCell
    elif flags == 'lstm':
        rnn_cell_type = tf.nn.rnn_cell.BasicLSTMCell
    else:
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers still catch it.
        raise ValueError('Choose a valid RNN unit type')
    # A separate cell instance per layer, stacked bottom to top.
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
        [rnn_cell_type(unit_nums) for _ in range(layer_nums)],
        state_is_tuple=True)
    return stacked_cell


# Demo dimensions: 2 time steps of 3 features, 5 units per layer,
# 2 stacked layers (so `states` below is a 2-tuple, one entry per layer).
n_steps = 2
n_inputs = 3
n_neurons = 5
n_layers = 2

if __name__ == '__main__':
    X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
    seq_length = tf.placeholder(tf.int32, [None])
    basic_cell = rnn_cell('gru', n_neurons, n_layers)
    outputs, states = tf.nn.dynamic_rnn(
        basic_cell, X, sequence_length=seq_length, dtype=tf.float32)
    # `states` holds one final state per layer; take the top (last) layer.
    # For a GRU that entry is the hidden state itself. For an LSTM it would
    # be a (c, h) state tuple, and the hidden state is its element at index 1.
    test_out = states[-1]

    X_batch = np.array([
        # t = 0      t = 1
        [[0, 1, 2], [9, 8, 7]],  # instance 0
        [[3, 4, 5], [0, 0, 0]],  # instance 1 (padded: only 1 valid step)
        [[6, 7, 8], [6, 5, 4]],  # instance 2
        [[9, 0, 1], [3, 2, 1]],  # instance 3
    ])
    seq_length_batch = np.array([2, 1, 2, 2])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_val, states_val, test_out_val = sess.run(
            [outputs, states, test_out],
            feed_dict={X: X_batch, seq_length: seq_length_batch})
        print('output_val:')
        print(outputs_val)
        print('test_out_val:')
        print(test_out_val)
        print('states output:')
        print(states_val)
The output is as follows:
output_val:
[[[-0.04201178 0.05825585 -0.10475425 -0.0574191 -0.02787135]
[-0.19379006 0.3098646 -0.45543492 -0.12587619 0.0240692 ]]
[[-0.17244668 0.23054264 -0.35517666 -0.08038893 -0.03418409]
[ 0. 0. 0. 0. 0. ]]
[[-0.2036431 0.3009616 -0.4299594 -0.07195221 -0.01479668]
[-0.34811032 0.41498578 -0.609179 -0.13371195 -0.04363498]]
[[-0.07408727 0.08317428 -0.16693498 -0.05484099 0.03070312]
[-0.17517757 0.16493565 -0.3365708 -0.11573118 0.06112424]]]
test_out_val:
[[-0.19379006 0.3098646 -0.45543492 -0.12587619 0.0240692 ]
[-0.17244668 0.23054264 -0.35517666 -0.08038893 -0.03418409]
[-0.34811032 0.41498578 -0.609179 -0.13371195 -0.04363498]
[-0.17517757 0.16493565 -0.3365708 -0.11573118 0.06112424]]
states output:
(array([[ 0.03890764, -0.68590057, -0.9278017 , 0.9249147 , -0.9152604 ],
[ 0.02587428, -0.61686724, -0.60513765, 0.7774262 , -0.7873352 ],
[ 0.02819343, -0.7372245 , -0.61572564, 0.9879686 , -0.96763486],
[-0.0758315 , -0.05249108, -0.32755852, 0.62991834, -0.5825802 ]],
dtype=float32), array([[-0.19379006, 0.3098646 , -0.45543492, -0.12587619, 0.0240692 ],
[-0.17244668, 0.23054264, -0.35517666, -0.08038893, -0.03418409],
[-0.34811032, 0.41498578, -0.609179 , -0.13371195, -0.04363498],
[-0.17517757, 0.16493565, -0.3365708 , -0.11573118, 0.06112424]],
dtype=float32))
If we use index 0, we get the first array:
[[ 0.03890764, -0.68590057, -0.9278017 , 0.9249147 , -0.9152604 ],
[ 0.02587428, -0.61686724, -0.60513765, 0.7774262 , -0.7873352 ],
[ 0.02819343, -0.7372245 , -0.61572564, 0.9879686 , -0.96763486],
[-0.0758315 , -0.05249108, -0.32755852, 0.62991834, -0.5825802 ]]
If we use index 1, we get the second array:
[[-0.19379006 0.3098646 -0.45543492 -0.12587619 0.0240692 ]
[-0.17244668 0.23054264 -0.35517666 -0.08038893 -0.03418409]
[-0.34811032 0.41498578 -0.609179 -0.13371195 -0.04363498]
[-0.17517757 0.16493565 -0.3365708 -0.11573118 0.06112424]]
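I think the second array is what we want: for every instance it equals the last valid step of `outputs`. Instance 1, whose sequence length is 1, matches its t = 0 output. Note that in this sample `states` is the per-layer tuple of a two-layer `MultiRNNCell` of GRU cells, so index 1 selects the top layer's state. With LSTM cells, each layer's state is itself a (c, h) tuple, and h is at index 1.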
Looking forward to your reply.