Comments (6)

hycis commented on August 23, 2024

For now, the default is to return the full sequence of outputs: if the input is x_1, x_2, ..., x_k, the output is y_1, y_2, ..., y_k. If you only want the last output y_k, then in get_output you can take the last element of the sequence: y = forward + backward; return y[:,-1,:]

from bidirectional_rnn.

cjmcmurtrie commented on August 23, 2024

Like this?

# Assumes the module already has: import theano.tensor as T
def get_output(self, train):
    forward = self.get_forward_output(train)
    backward = self.get_backward_output(train)
    # Compare strings with ==; `is` tests object identity, not equality.
    if self.output_mode == 'sum':
        output = forward + backward
    elif self.output_mode == 'concat':
        output = T.concatenate([forward, backward], axis=2)
    else:
        raise Exception('output mode is not sum or concat')
    # Return only the last element along axis 1 unless the full sequence is requested.
    if not self.return_sequences:
        return output[:, -1, :]
    return output
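To check the shapes this merge step produces, here is a minimal NumPy sketch. It is a stand-in, not the layer itself: `merge_outputs` is a hypothetical helper, and the inputs are assumed to be laid out as (batch, time, features), as the slicing in the snippet above implies.

```python
import numpy as np

def merge_outputs(forward, backward, output_mode="sum", return_sequences=True):
    """Hypothetical NumPy stand-in for the merge step in get_output.

    Both inputs are assumed to have shape (batch, time, features).
    """
    if output_mode == "sum":
        output = forward + backward  # feature size unchanged
    elif output_mode == "concat":
        output = np.concatenate([forward, backward], axis=2)  # feature size doubles
    else:
        raise ValueError("output_mode must be 'sum' or 'concat'")
    # Keep every time step, or only the last one along axis 1.
    return output if return_sequences else output[:, -1, :]

forward = np.ones((4, 5, 3))
backward = np.full((4, 5, 3), 2.0)
print(merge_outputs(forward, backward, "sum").shape)                             # (4, 5, 3)
print(merge_outputs(forward, backward, "concat").shape)                          # (4, 5, 6)
print(merge_outputs(forward, backward, "concat", return_sequences=False).shape)  # (4, 6)
```

Note that 'concat' doubles the last dimension, so any layer that follows would need an input size of twice the hidden size in that mode.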

cjmcmurtrie commented on August 23, 2024

In the Keras LSTM, return_sequences is handled this way:

    if self.return_sequences:
        return outputs.dimshuffle((1, 0, 2))
    return outputs[-1]

Are you sure the tensor doesn't need to be transposed to plug into the other Keras modules?

Edit: Sorry, I just realized you already handle the transposition in get_forward_output and get_backward_output.

hycis commented on August 23, 2024

What you did is correct. After the dimension shuffle in get_forward_output and get_backward_output, output has shape (num_examples, num_seq, seq_len), so output[:,-1,:] indexes the second axis and takes the last output of the sequence.
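The equivalence between the two indexing styles can be checked with a small NumPy sketch. It assumes the pre-shuffle tensor is time-major, (time, batch, features), as in the Keras LSTM snippet quoted earlier, and uses `transpose` as the NumPy counterpart of `dimshuffle((1, 0, 2))`. The sizes are arbitrary.

```python
import numpy as np

time_steps, batch, features = 5, 4, 3
# Time-major layout: (time, batch, features)
outputs = np.arange(time_steps * batch * features).reshape(time_steps, batch, features)

# NumPy counterpart of outputs.dimshuffle((1, 0, 2)): (batch, time, features)
batch_major = outputs.transpose(1, 0, 2)

last_time_major = outputs[-1]              # last step before the shuffle
last_batch_major = batch_major[:, -1, :]   # last step after the shuffle

print(batch_major.shape)                                   # (4, 5, 3)
print(np.array_equal(last_time_major, last_batch_major))   # True
```

So `outputs[-1]` before the shuffle and `output[:,-1,:]` after it select the same values.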

cjmcmurtrie commented on August 23, 2024

Ok, it seems to be working. I'm using two stacked LSTMs like this:

model = Sequential()
model.add(BiDirectionLSTM(embedding_size, hidden_size, init=initialize))
model.add(Dense(hidden_size, hidden_size, init=initialize))
model.add(Activation('relu'))
model.add(RepeatVector(maxlen))
model.add(BiDirectionLSTM(hidden_size, hidden_size, return_sequences=True, init=initialize))
model.add(TimeDistributedDense(hidden_size, output_size, activation='softmax', init=initialize))

And it seems to be training ok (I'm actually trying to decide if I need two bi-dir LSTMs or just one at the input layer side).
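For anyone following the shapes in that stack: the first bidirectional layer returns a single vector per example, and RepeatVector turns it back into a sequence for the second layer. Below is a rough NumPy sketch of just that step. The sizes are made up, and `encoded` is a random stand-in for the output of the Dense and relu layers, not the real activations.

```python
import numpy as np

batch, hidden_size, maxlen = 4, 3, 6
# Stand-in for the encoder output after Dense + relu: one vector per example.
encoded = np.random.rand(batch, hidden_size)

# Shape effect of RepeatVector(maxlen): copy each vector maxlen times along a new axis 1.
repeated = np.repeat(encoded[:, np.newaxis, :], maxlen, axis=1)

print(encoded.shape)   # (4, 3)
print(repeated.shape)  # (4, 6, 3)
```

Every time step of `repeated` holds the same encoded vector, which is what the second layer (with return_sequences=True) then consumes.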

I've forked the repo (sorry, I'm a bit new to GitHub), but I can add the changes here if you like.

hycis commented on August 23, 2024

Sure, feel free to add the changes and create a pull request if you can.
