Seq2Seq models

This is a project to learn to implement different seq2seq models in TensorFlow.

This project is for learning only, which means it may contain many bugs. I suggest using the nmt project for experiments and for training seq2seq models; you can find it in the Reference section.

Experiments

I am experimenting with CopyNet and the pointer-generator model on the LCSTS dataset; you can find the code in the lcsts branch.

Issues and suggestions are welcome.

Models

The models I have implemented are as follows:

  • Basic seq2seq model
    • A model with a bidirectional RNN encoder and an attention mechanism
  • Seq2seq model
    • Same as the basic model, but using the tf.data pipeline to process input data
  • GNMT model
    • Residual connections and attention, same as in the GNMT model, to speed up training
    • Refer to GNMT for more details
  • Pointer-Generator model
  • CopyNet model
    • A model that also supports the copy mechanism
    • Refer to CopyNet for more details.

For implementation details, refer to the ReadMe in the model folder.

Structure

A typical sequence-to-sequence (seq2seq) model contains an encoder, a decoder, and an attention mechanism. TensorFlow provides many useful APIs for implementing a seq2seq model; you will usually need the following APIs (a minimal sketch of how they fit together follows the list):

  • tf.contrib.rnn
    • Different RNN cells
  • tf.contrib.seq2seq
    • Provides different attention mechanisms and a good implementation of beam search
  • tf.data
    • Data preprocessing pipeline APIs
  • Other APIs you need to build and train a model
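
To show how these pieces fit together, here is a minimal training-graph sketch in the TF 1.x style this project uses. The sizes, the placeholder inputs, and the single-layer LSTM encoder are illustrative assumptions, not code from this repo:

    import tensorflow as tf

    batch_size, vocab_size, num_units = 32, 1000, 256

    # Hypothetical inputs: integer token ids plus target lengths.
    # (In practice these would come from a tf.data iterator, not placeholders.)
    src_ids = tf.placeholder(tf.int32, [batch_size, None])
    tgt_in_ids = tf.placeholder(tf.int32, [batch_size, None])
    tgt_len = tf.placeholder(tf.int32, [batch_size])

    embedding = tf.get_variable("embedding", [vocab_size, num_units])
    src_emb = tf.nn.embedding_lookup(embedding, src_ids)
    tgt_emb = tf.nn.embedding_lookup(embedding, tgt_in_ids)

    # tf.contrib.rnn: the encoder cell.
    enc_cell = tf.contrib.rnn.LSTMCell(num_units)
    enc_out, enc_state = tf.nn.dynamic_rnn(enc_cell, src_emb, dtype=tf.float32)

    # tf.contrib.seq2seq: attention mechanism, wrapped decoder cell, decode loop.
    attention = tf.contrib.seq2seq.LuongAttention(num_units, memory=enc_out)
    dec_cell = tf.contrib.seq2seq.AttentionWrapper(
        tf.contrib.rnn.LSTMCell(num_units), attention,
        attention_layer_size=num_units)
    init_state = dec_cell.zero_state(batch_size, tf.float32).clone(
        cell_state=enc_state)
    helper = tf.contrib.seq2seq.TrainingHelper(tgt_emb, tgt_len)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        dec_cell, helper, init_state,
        output_layer=tf.layers.Dense(vocab_size))
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
    logits = outputs.rnn_output  # [batch, time, vocab], fed to the training loss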

Encoder

Use one of the following:

  • Multi-layer RNN
    • Use the last state of the last RNN layer as the initial decoder state
  • Bidirectional RNN
    • Use a Dense layer to convert the forward and backward states to the initial decoder state (see the sketch after this list)
  • GNMT encoder
    • A bidirectional RNN plus several RNN layers with residual connections
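
A sketch of the bidirectional case, reusing the hypothetical names (num_units, src_emb) from the sketch above and assuming GRU cells so that the states are plain tensors:

    fw_cell = tf.contrib.rnn.GRUCell(num_units)
    bw_cell = tf.contrib.rnn.GRUCell(num_units)
    (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
        fw_cell, bw_cell, src_emb, dtype=tf.float32)
    encoder_outputs = tf.concat([out_fw, out_bw], axis=-1)  # attention memory
    # The concatenated state is 2 * num_units wide; a Dense layer projects it
    # back down to num_units so it can initialize a unidirectional decoder.
    decoder_init_state = tf.layers.dense(
        tf.concat([state_fw, state_bw], axis=-1), num_units)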

Decoder

  • Multi-layer RNN, with the initial state of each layer set to the initial decoder state
  • GNMT decoder
    • Applies attention only to the bottom layer of the decoder, so we can utilize multiple GPUs during training (see the sketch below)
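
A simplified sketch of that layout, with hypothetical names; attention_mechanism is one of the mechanisms from the next section, and the real GNMT cell additionally feeds the attention vector to the upper layers, which a plain MultiRNNCell does not do:

    cells = [tf.contrib.rnn.LSTMCell(num_units) for _ in range(4)]
    # Only the bottom cell carries the attention computation ...
    cells[0] = tf.contrib.seq2seq.AttentionWrapper(
        cells[0], attention_mechanism, output_attention=False)
    # ... while the upper layers are plain residual cells, so each layer
    # depends only on the layer below and can be placed on its own GPU.
    cells[1:] = [tf.contrib.rnn.ResidualWrapper(c) for c in cells[1:]]
    decoder_cell = tf.contrib.rnn.MultiRNNCell(cells)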

Attention

  • Bahdanau
  • Luong
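
Both are constructed the same way in tf.contrib.seq2seq: Bahdanau attention scores with an additive feed-forward layer, Luong attention with a dot product. Reusing encoder_outputs from the encoder sketch and assuming a src_len vector of source lengths:

    bahdanau = tf.contrib.seq2seq.BahdanauAttention(
        num_units, memory=encoder_outputs, memory_sequence_length=src_len)
    luong = tf.contrib.seq2seq.LuongAttention(
        num_units, memory=encoder_outputs, memory_sequence_length=src_len)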

Metrics

Right now I only have the cross-entropy loss. I will add the following metrics:

  • BLEU
    • For translation problems (a sketch follows the list)
  • ROUGE
    • For summarization problems
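
For reference, sentence-level BLEU is easy to compute with NLTK (an external dependency, not used by this repo):

    from nltk.translate.bleu_score import sentence_bleu

    references = [["the", "cat", "sat", "on", "the", "mat"]]  # token lists
    hypothesis = ["the", "cat", "is", "on", "the", "mat"]
    # Bigram-level BLEU: equal weights on 1-gram and 2-gram precision.
    print(sentence_bleu(references, hypothesis, weights=(0.5, 0.5)))  # ~0.71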

Dependency

  • TensorFlow 1.4
  • Python 3

Run

Run the model on a toy dataset, i.e. reversing the input sequence.

train:

python -m bin.toy_train

inference:

python -m bin.toy_inference

You can also run on the en-vi dataset; refer to en_vietnam_train.py in bin for more details.

You can find more training scripts in the bin directory.

Reference

Thanks to the following resources:


my_seq2seq's Issues

Decoder after a bidirectional encoder

If a biRNN is used, shouldn't the decoder's hidden_size be twice the encoder's? I don't see such a setting in your code, yet the code still runs. Could you clarify?

Also, I found that an error occurs when num_layers > 2. In fact, the code only works when the encoder has 1 layer and the decoder has 2 layers, as the following code shows:

    fw_cell = get_rnn_cell('gru', self.dim_size, num_layers=1, train_phase=self.train_phase, keep_prob=self.keep_prob_config)
    bw_cell = get_rnn_cell('gru', self.dim_size, num_layers=1, train_phase=self.train_phase, keep_prob=self.keep_prob_config)
    dec_cell = get_rnn_cell('gru', self.dim_size, num_layers=self.num_layers, train_phase=self.train_phase,
                            keep_prob=self.keep_prob_config)

attention mechanism on the top layer

From this repository, https://github.com/JayParks/tf-seq2seq/blob/master/seq2seq_model.py, I noticed a note saying:

    # Note: We implement Attention mechanism only on the top decoder layer

Correspondingly, the code looks like:

    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

Why do you add attention to the zeroth decoder layer? I think the top-layer version sounds more reasonable. Thanks!

beam_size=1?

Has anyone else met the problem that the parameter 'beam_size' can only be set to 1, otherwise it results in errors?

Running the pointer model raises the following error

    File "D:\App\novel\FunText\src\models\pointernet\PointerNetGenerator.py", line 108, in _build_decoder_cell
      initial_state = cell.zero_state(self._batch_size, tf.float32).clone(cell_state=cell_state)
    File "D:\App\novel\FunText\src\models\pointernet\PointerNetHelper.py", line 202, in zero_state
      for _ in self._attention_mechanisms))
    TypeError: __new__() missing 1 required positional argument: 'attention_state'

Thank you very much

This is not an actual issue; I just want to thank you very much for such a great repository. I did fix some bugs related to deprecated APIs, but that is less than 1% of what you have already done. I also added Luong attention to your pointer-generator model.

I tested on a real dataset: X = body of a news article, Y = its title. Accuracy is based on sequential cross-entropy; I never tested with ROUGE. All results are after 10 epochs, with 80% of the data used for training and 20% for testing.

With Bahdanau attention (https://github.com/huseinzol05/NLP-Models-Tensorflow/blob/master/abstractive-summarization/7.xueyouluo-pointer-generator-bahdanau.ipynb):

    epoch: 9, avg loss: 6.818628, avg accuracy: 0.913240
    epoch: 9, avg loss test: 16.073720, avg accuracy test: 0.870706

With Luong attention (https://github.com/huseinzol05/NLP-Models-Tensorflow/blob/master/abstractive-summarization/9.xueyouluo-pointer-generator-luong.ipynb):

    epoch: 9, avg loss: 4.113721, avg accuracy: 0.988244
    epoch: 9, avg loss test: 14.028113, avg accuracy test: 0.915375

Thank you so much!
