
SeqGAN in Tensorflow

As part of the implementation series of Joseph Lim's group at USC, our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers, and publicly share them with concise reports. Please visit our group github site for other projects.

This project is implemented by Shaofan Lai and reviewed by Shao-Hua Sun.

Descriptions

This project includes a Tensorflow implementation of SeqGAN, proposed in the paper "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" by Lantao Yu et al. at Shanghai Jiao Tong University and University College London.

SeqGAN adapts GANs to sequence generation. It treats the generator as a policy in reinforcement learning, and the discriminator is trained to provide the reward. To evaluate unfinished sequences, Monte Carlo search is applied: the remaining tokens are sampled to form complete sequences, which the discriminator can then score.
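The Monte Carlo reward described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in this repository: `mc_rollout_reward`, `uniform_policy` and `toy_discriminator` are hypothetical names, and the toy policy and discriminator stand in for the trained generator and discriminator. For a prefix shorter than the target length, the reward is the average discriminator score over several completions sampled by the rollout policy; a complete sequence is scored directly.

```python
import random
from typing import Callable, List, Sequence

Token = int
# Hypothetical interfaces: a rollout policy samples the next token given the
# tokens so far; a discriminator maps a complete sequence to P(real) in [0, 1].
RolloutPolicy = Callable[[List[Token], random.Random], Token]
Discriminator = Callable[[Sequence[Token]], float]


def mc_rollout_reward(prefix: Sequence[Token], seq_len: int,
                      rollout_policy: RolloutPolicy,
                      discriminator: Discriminator,
                      num_rollouts: int, rng: random.Random) -> float:
    """Estimate the reward of a (possibly unfinished) prefix via Monte Carlo search."""
    if len(prefix) >= seq_len:
        # A finished sequence needs no rollouts: score it directly.
        return discriminator(prefix)
    total = 0.0
    for _ in range(num_rollouts):
        seq = list(prefix)
        while len(seq) < seq_len:  # complete the sequence with sampled tokens
            seq.append(rollout_policy(seq, rng))
        total += discriminator(seq)
    return total / num_rollouts  # average score over all rollouts


# Toy stand-ins (hypothetical): a uniform policy over {0, 1} and a
# discriminator that scores a sequence by its fraction of 1s.
def uniform_policy(seq: List[Token], rng: random.Random) -> Token:
    return rng.randrange(2)


def toy_discriminator(seq: Sequence[Token]) -> float:
    return sum(seq) / len(seq)


rng = random.Random(0)
generated = [1, 1, 0]
# One reward per time step t, computed from the prefix generated[:t].
rewards = [mc_rollout_reward(generated[:t], len(generated), uniform_policy,
                             toy_discriminator, num_rollouts=16, rng=rng)
           for t in range(1, len(generated) + 1)]
print(rewards)
```

In SeqGAN these per-step rewards weight the log-probabilities of the generated tokens in the policy-gradient update; the paper uses the generator itself as the rollout policy, while a cheaper rollout policy is a possible trade-off for speed.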

We use the advanced (decoder) APIs provided by the tensorflow.contrib.seq2seq module to implement SeqGAN. Note that, as with GANs, it is extremely hard to select the hyper-parameters of SeqGAN. If the hyper-parameters are chosen at random, SeqGAN may perform much worse than the supervised learning (MLE) method on other tasks.

Prerequisites

Usage

Datasets

  • A randomly initialized LSTM is used to simulate a specific distribution.
  • A music dataset containing multiple Nottingham songs.

Training LSTM

  • Run
python2 main.py --pretrain_g_epochs 2000 --total_epochs 0 --log_dir logs/train/pure_pretrain --eval_log_dir logs/eval/pure_pretrain

to train a baseline with the pure MLE loss for 2000 iterations.

  • Run
python2 main.py --pretrain_g_epochs 1000 --total_epochs 1000 --log_dir logs/train/pretrain_n_seqgan  --eval_log_dir logs/eval/with_seqgan

to train the model first with the pretraining (MLE) loss and then with SeqGAN's loss.

  • Run
tensorboard --logdir logs/eval/

and open your browser to compare the two runs and see the improvement that SeqGAN provides.
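The first command trains the generator only with the MLE loss, while the second switches to SeqGAN's adversarial loss after pretraining. A minimal sketch of both losses, assuming a toy bigram generator parameterized by a table of logits (a hypothetical simplification; this project uses an LSTM decoder, and `mle_loss` and `seqgan_generator_loss` are illustrative names): the MLE loss is the average negative log-likelihood of real tokens, while SeqGAN's generator loss weights the log-probabilities of generated tokens by their Monte Carlo rewards, so its gradient is a REINFORCE estimate of the policy gradient.

```python
import math
from typing import List, Sequence

Logits = List[List[float]]  # logits[prev_token][next_token]


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def token_log_probs(logits: Logits, seq: Sequence[int], start_token: int) -> List[float]:
    """log p(y_t | y_{t-1}) for every position t of seq under the bigram generator."""
    prev, out = start_token, []
    for tok in seq:
        out.append(log_softmax(logits[prev])[tok])
        prev = tok
    return out


def mle_loss(logits: Logits, real_seq: Sequence[int], start_token: int) -> float:
    """Pretraining loss: average negative log-likelihood of a real sequence."""
    lps = token_log_probs(logits, real_seq, start_token)
    return -sum(lps) / len(lps)


def seqgan_generator_loss(logits: Logits, generated_seq: Sequence[int],
                          rewards: Sequence[float], start_token: int) -> float:
    """Adversarial loss: log-probabilities of generated tokens weighted by rewards.

    The rewards are treated as constants, so the gradient of this loss is a
    REINFORCE estimate of the negated policy gradient.
    """
    lps = token_log_probs(logits, generated_seq, start_token)
    return -sum(q * lp for q, lp in zip(rewards, lps)) / len(lps)


vocab_size = 4
uniform_logits = [[0.0] * vocab_size for _ in range(vocab_size)]
print(mle_loss(uniform_logits, [1, 2, 3], start_token=0))  # ≈ log(4) ≈ 1.386
print(seqgan_generator_loss(uniform_logits, [1, 2, 3], [1.0, 0.5, 0.0], start_token=0))
```

Averaging over tokens is a scaling choice of this sketch; summing instead only rescales the gradient.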

Music generation

  • Run bash train_nottingham.sh to train the model. Check data/Nottingham/*.mid for the generated songs, which are updated every 100 epochs.

Results

In this figure, the blue line is the Negative Log Likelihood (NLL) of the generator trained purely with supervised learning (MLE loss), while the orange line is the NLL of the generator first pretrained with MLE and then optimized with the adversarial loss. The two curves overlap at the beginning because the same random seed is used. After switching to SeqGAN's loss, the NLL drops and converges to a smaller value, which indicates that the generated sequences better match the distribution of the randomly initialized LSTM.
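In the SeqGAN paper's synthetic experiment, this NLL is computed by scoring sequences sampled from the generator under the oracle (the randomly initialized LSTM); lower values mean the oracle finds the samples more likely. A minimal sketch of that evaluation, assuming a random bigram table as a hypothetical stand-in for the oracle LSTM and another random table as an untrained generator; `random_bigram_model`, `sample` and `oracle_nll` are illustrative names, not functions from this repository.

```python
import math
import random
from typing import List, Sequence

Table = List[List[float]]  # table[prev_token][next_token] = probability


def random_bigram_model(vocab_size: int, rng: random.Random) -> Table:
    """Hypothetical stand-in for a randomly initialized model: random next-token distributions."""
    table = []
    for _ in range(vocab_size):
        weights = [rng.random() + 1e-3 for _ in range(vocab_size)]  # keep every probability > 0
        total = sum(weights)
        table.append([w / total for w in weights])
    return table


def sample(model: Table, length: int, start_token: int, rng: random.Random) -> List[int]:
    """Draw one sequence token by token from the bigram model."""
    seq, prev = [], start_token
    for _ in range(length):
        prev = rng.choices(range(len(model)), weights=model[prev])[0]
        seq.append(prev)
    return seq


def oracle_nll(oracle: Table, sequences: Sequence[Sequence[int]], start_token: int) -> float:
    """Average per-token negative log-likelihood of sequences under the oracle."""
    total, count = 0.0, 0
    for seq in sequences:
        prev = start_token
        for tok in seq:
            total -= math.log(oracle[prev][tok])
            count += 1
            prev = tok
    return total / count


rng = random.Random(0)
oracle = random_bigram_model(5, rng)     # plays the role of the random LSTM
generator = random_bigram_model(5, rng)  # an untrained generator
samples = [sample(generator, 20, 0, rng) for _ in range(100)]
print(oracle_nll(oracle, samples, start_token=0))
```

This metric only checks that samples are likely under the oracle; a generator that repeats a few high-probability sequences can also score well, so on its own it says little about diversity.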

Related works

Yu, Lantao, et al. "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient." AAAI. 2017.

Author

Shaofan Lai / @shaofanl @ Joseph Lim's research lab @ USC
