
rnn.wgan's Introduction

Language Generation with Recurrent Generative Adversarial Networks without Pre-training

Code for training and evaluation of the model from "Language Generation with Recurrent Generative Adversarial Networks without Pre-training".

A short summary of the paper is available here.

Sample outputs (32 chars)

" There has been to be a place w
On Friday , the stories in Kapac
From should be taken to make it 
He is conference for the first t
For a lost good talks to ever ti

Training

To start training the CL+VL+TH model (curriculum learning, variable length, teacher helping), first download the 1 Billion Word Language Model Benchmark dataset from http://www.statmt.org/lm-benchmark/ and extract it into the ./data directory.

Then use the following command:

python curriculum_training.py

The following packages are required:

  • Python 2.7
  • TensorFlow 1.1
  • SciPy
  • Matplotlib

The following parameters can be configured:

LOGS_DIR: Path for saving model checkpoints and samples during training (defaults to './logs/')
DATA_DIR: Path to load the data from (defaults to './data/1-billion-word-language-modeling-benchmark-r13output/')
CKPT_PATH: Path to the checkpoint file when restoring a saved model
BATCH_SIZE: Batch size (defaults to 64)
CRITIC_ITERS: Number of iterations for the discriminator (defaults to 10)
GEN_ITERS: Number of iterations for the generator (defaults to 50)
MAX_N_EXAMPLES: Maximum number of examples to load from the dataset (defaults to 10000000)
GENERATOR_MODEL: Name of the generator model (currently only 'Generator_GRU_CL_VL_TH' is available)
DISCRIMINATOR_MODEL: Name of the discriminator model (currently only 'Discriminator_GRU' is available)
PICKLE_PATH: Path to the directory that holds cached pickle files (defaults to './pkl')
ITERATIONS_PER_SEQ_LENGTH: Number of iterations to run for each sequence length in the curriculum training (defaults to 15000)
NOISE_STDEV: Standard deviation of the noise vector (defaults to 10.0)
DISC_STATE_SIZE: Discriminator GRU state size (defaults to 512)
GEN_STATE_SIZE: Generator GRU state size (defaults to 512)
TRAIN_FROM_CKPT: Boolean; set to True to restore from a checkpoint (defaults to False)
GEN_GRU_LAYERS: Number of GRU layers in the generator (defaults to 1)
DISC_GRU_LAYERS: Number of GRU layers in the discriminator (defaults to 1)
START_SEQ: Sequence length to start the curriculum learning with (defaults to 1)
END_SEQ: Sequence length to end the curriculum learning with (defaults to 32)
SAVE_CHECKPOINTS_EVERY: Save a checkpoint every N steps (defaults to 25000)
LIMIT_BATCH: Boolean indicating whether to limit the batch size (defaults to True)
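As a rough guide to how the curriculum parameters interact, the sketch below lists the stages and the total number of iterations implied by START_SEQ, END_SEQ and ITERATIONS_PER_SEQ_LENGTH. It assumes the maximal sequence length grows by one per stage and that every stage runs the same number of iterations; curriculum_schedule is a hypothetical helper, not part of curriculum_training.py.

```python
def curriculum_schedule(start_seq=1, end_seq=32, iters_per_seq_length=15000):
    """Return (seq_len, first_iter, last_iter) for each curriculum stage.

    Hypothetical helper. Assumes the maximal sequence length grows by one per
    stage, from start_seq to end_seq inclusive, with a fixed iteration budget.
    """
    schedule = []
    first = 0
    for seq_len in range(start_seq, end_seq + 1):
        schedule.append((seq_len, first, first + iters_per_seq_length - 1))
        first += iters_per_seq_length
    return schedule


# With the default values: 32 stages of 15000 iterations each.
stages = curriculum_schedule()
total_iterations = stages[-1][2] + 1
print(len(stages), total_iterations)  # prints: 32 480000
```

Under these assumptions the total scales linearly with the number of stages, so a smaller END_SEQ shortens a run proportionally.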

Parameters can be set either by changing their values in the config file or by passing them as command-line flags:

python curriculum_training.py --START_SEQ=1 --END_SEQ=32
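The pattern is the usual one: values passed on the command line take precedence over the configured defaults. The sketch below illustrates it with Python's standard argparse purely as a stand-in; the repository's own flag handling is not shown in this README and may differ, for example in how booleans are parsed. DEFAULTS, str_to_bool and parse_flags are illustrative names, not part of the code base.

```python
import argparse

# Illustrative defaults mirroring a few of the documented parameters.
DEFAULTS = {"START_SEQ": 1, "END_SEQ": 32, "BATCH_SIZE": 64, "TRAIN_FROM_CKPT": False}


def str_to_bool(value):
    """Parse 'True'/'False' style strings (case-insensitive)."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def parse_flags(argv=None):
    """Return a dict of settings: defaults, overridden by any --NAME=value flags."""
    parser = argparse.ArgumentParser()
    for name, default in DEFAULTS.items():
        # bool("False") is True, so booleans need a dedicated parser.
        kind = str_to_bool if isinstance(default, bool) else type(default)
        parser.add_argument("--" + name, type=kind, default=default)
    return vars(parser.parse_args(argv))


# Equivalent of: python curriculum_training.py --START_SEQ=1 --END_SEQ=32
settings = parse_flags(["--START_SEQ=1", "--END_SEQ=32"])
```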

Generating text

The generate.py script generates BATCH_SIZE samples from a saved model. Run it with the same parameters that were used to train the model (wherever they differ from the default values). For example:

python generate.py --CKPT_PATH=/path/to/checkpoint/seq-32/ckp --DISC_GRU_LAYERS=2 --GEN_GRU_LAYERS=2

(If your model has not reached stage 32 of the curriculum, replace the '32' in the path above with the highest stage your model was trained on.)
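If several curriculum stages were saved, the highest one can be located programmatically. The sketch below assumes the <run_dir>/seq-N/ckp layout from the example command; highest_stage and latest_stage_checkpoint are hypothetical helpers. Stages are compared numerically, since a plain string sort would place seq-4 after seq-12.

```python
import os
import re


def highest_stage(entries):
    """Return the largest N among names of the form 'seq-N', or None."""
    stages = []
    for entry in entries:
        match = re.match(r"seq-(\d+)$", entry)
        if match:
            stages.append(int(match.group(1)))
    return max(stages) if stages else None


def latest_stage_checkpoint(run_dir, ckpt_name="ckp"):
    """Build <run_dir>/seq-N/<ckpt_name> for the highest stage found, or None."""
    stage = highest_stage(os.listdir(run_dir))
    if stage is None:
        return None
    return os.path.join(run_dir, "seq-%d" % stage, ckpt_name)
```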

Evaluating text

To evaluate samples with our %-IN-TEST-n metrics, run the following command, pointing it to a text file with one sample per line:

python evaluate.py --INPUT_SAMPLE=/path/to/samples.txt
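As an illustration of the kind of statistic %-IN-TEST-n reports, the sketch below computes the percentage of distinct word n-grams in a set of samples that also occur in a test corpus. It is an approximation written for this README, not evaluate.py: whitespace tokenization, counting distinct n-grams, and the absence of any normalization are all assumptions that may differ from the script.

```python
def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def ngram_set(lines, n):
    """Distinct word n-grams over all lines (whitespace tokenization)."""
    found = set()
    for line in lines:
        found.update(ngrams(line.split(), n))
    return found


def percent_in_test(samples, test_lines, n):
    """Percentage of distinct sample n-grams that also appear in test_lines."""
    sample_ngrams = ngram_set(samples, n)
    if not sample_ngrams:
        return 0.0
    hits = len(sample_ngrams & ngram_set(test_lines, n))
    return 100.0 * hits / len(sample_ngrams)
```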

Reference

If you found this code useful, please cite the following paper:

@article{press2017language,
  title={Language Generation with Recurrent Generative Adversarial Networks without Pre-training},
  author={Press, Ofir and Bar, Amir and Bogin, Ben and Berant, Jonathan and Wolf, Lior},
  journal={arXiv preprint arXiv:1706.01399},
  year={2017}
}

Acknowledgments

This repository is based on the code released with the paper Improved Training of Wasserstein GANs.

rnn.wgan's People

Contributors

amirbar, benbogin, nickshahml


rnn.wgan's Issues

generate errors

Hi, I used the default settings for curriculum_training, then used
--GENERATOR_MODEL=Generator_RNN_CL_VL_TH --DISCRIMINATOR_MODEL=Discriminator_RNN --ckpt_path=F:\NLP\rnn-wgan\logs\Generator_RNN_CL_VL_TH-Discriminator_RNN-50-10-512-512-1510731157.353556-\checkpoint
as the parameters for the generate script. However, an error occurs in model.py:
  File "F:\NLP\rnn-wgan\model.py", line 52, in Generator_RNN_CL_VL_TH
    cells.append(rnn_cell(num_neurons))
TypeError: 'NoneType' object is not callable
Could you please give me some suggestions?
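The traceback means rnn_cell was None when model.py tried to build its RNN cells. How model.py chooses rnn_cell is not shown here, so the following is only a hypothetical illustration: if a cell constructor is looked up by name and the lookup silently returns None for an unrecognized name, the failure surfaces far from its cause. Validating the name up front gives a clearer error. CELL_TYPES, get_rnn_cell, build_cells and StubGRUCell are invented for this sketch.

```python
class StubGRUCell(object):
    """Stand-in for a real RNN cell class; it only records its size."""

    def __init__(self, num_units):
        self.num_units = num_units


# Hypothetical registry from a cell-type name to a cell constructor.
CELL_TYPES = {"GRU": StubGRUCell}


def get_rnn_cell(name):
    """Return the constructor registered under name, failing loudly if unknown.

    CELL_TYPES.get(name) would instead return None for an unknown name, and the
    failure would only surface later as "'NoneType' object is not callable".
    """
    if name not in CELL_TYPES:
        raise ValueError("Unknown RNN cell type %r; expected one of %s"
                         % (name, sorted(CELL_TYPES)))
    return CELL_TYPES[name]


def build_cells(name, num_neurons, num_layers):
    """Build one cell per layer, as in a stacked RNN."""
    rnn_cell = get_rnn_cell(name)
    return [rnn_cell(num_neurons) for _ in range(num_layers)]
```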

Is there a way to generate variable-length text?

It seems the current code only supports fixed-length text generation, since the generator and discriminator are trained on fixed-length inputs and outputs. Is there a way to generate variable-length text?
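One common workaround, sketched below, keeps the fixed-length generator but trains it on lines right-padded with a reserved character, then cuts each generated sample at the first pad character. This is not something the repository is documented to do; the pad character (a backtick) and the helper names trim_sample and trim_samples are assumptions for illustration.

```python
# Assumption: short training lines were right-padded with this reserved
# character, so a model trained on them can learn to emit padding.
PAD_CHAR = "`"


def trim_sample(sample, pad_char=PAD_CHAR):
    """Cut a fixed-length sample at its first pad character."""
    end = sample.find(pad_char)
    return sample if end == -1 else sample[:end]


def trim_samples(samples, pad_char=PAD_CHAR):
    """Turn a batch of fixed-length samples into variable-length strings."""
    return [trim_sample(s, pad_char) for s in samples]
```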

NoneType object Error when generating

Hi, I'm interested in your RNN-GAN model, but I'm still new to this domain. I'm trying to run your code, but I've run into some problems.
The first one occurs when I run the generation command:
python generate.py --CKPT_PATH=/logs/Generator_RNN_CL_VL_TH-Discriminator_RNN-50-10-512-512-1509799522.05-/checkpoint --DISC_GRU_LAYERS=2 --GEN_GRU_LAYERS=2
The system outputs this error:

  File "generate.py", line 20, in <module>
    _, inference_op = Generator(BATCH_SIZE, charmap_len, seq_len=SEQ_LEN)
  File "/.../rnn.wgan/model.py", line 51, in Generator_RNN_CL_VL_TH
    cells.append(rnn_cell(num_neurons))
TypeError: 'NoneType' object is not callable

The second one concerns training: running python curriculum_training.py takes a lot of time, 3 days to train seq-1 through seq-4. I run this code on a server with 256 CPUs, 1536 cores, Intel Xeon E5-4655v3 [6 cores, 30M cache, 2.9GHz], and 32TB of RAM (DDR4).

Can you help me with this? Thank you in advance.

Fisher GAN and wider GRUs increase performance

@amirbar thanks for releasing this code. I've been doing some experimentation, and I have found the following results:

  1. Fisher GAN seems to achieve approximately the same performance, but speeds up computation. I personally think it helps the discriminator more, because its gradients don't vanish as much as with WGAN-GP.

  2. Simply increasing the GRU's state size to 1024 seems to improve performance as well.

  3. As confirmed by @ofirpress on Reddit, I can't get any temporal convolution discriminators to work well.

I don't know if your team is still pursuing this direction of research, but I just wanted to let you know in case you are.
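To gauge the cost of the wider GRU suggested above, the sketch below counts the parameters of a single GRU layer, assuming the standard formulation with one bias vector per gate, which is how TensorFlow's GRUCell lays out its weights. The input size of 100 is a hypothetical placeholder; the real input width depends on the model.

```python
def gru_param_count(input_size, state_size):
    """Trainable parameters in one GRU layer with one bias vector per gate.

    Reset and update gates: kernel (input_size + state_size) x (2 * state_size)
    plus a bias of length 2 * state_size. Candidate: kernel
    (input_size + state_size) x state_size plus a bias of length state_size.
    """
    fan_in = input_size + state_size
    gates = fan_in * 2 * state_size + 2 * state_size
    candidate = fan_in * state_size + state_size
    return gates + candidate


INPUT_SIZE = 100  # hypothetical input width; depends on the actual model
print(gru_param_count(INPUT_SIZE, 512))   # 941568
print(gru_param_count(INPUT_SIZE, 1024))  # 3456000
```

At this input size, going from a 512 to a 1024 state roughly multiplies the layer's parameters by 3.7, since the recurrent kernel grows quadratically in the state size.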

Adapt code for Lyrics Generation

Hi,
thanks for the repo; I think applying GANs this way is an original idea.

I would like some advice on adapting this code to lyrics. I have a large amount of lyrics available, but I don't know how to adapt the code to my task.
To be clearer: I would like to create a GAN that, at the end of training, is able to generate lyrics (in a general lyrics format, not specific to a particular artist), and at the same time use the trained discriminator to recognize lyrics that don't match the format (a sort of spam detection).

Do you think it's a good idea? Could you suggest an approach?
Thanks

Simone

WGAN-GP vs. FisherGAN on language generation in practice

Could you share your experience of WGAN-GP vs. FisherGAN for training a language generator?
Conceptually, WGAN-GP's 'sampling along the line between a real sequence and a generated sequence' does not seem natural for sequences, because they can have different lengths. But I would like to know which method works well in practice.

(If this is not an appropriate topic for Issues, I will delete this post.)
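For context on the question above: in WGAN-GP the gradient penalty is evaluated at points x_hat = eps * x + (1 - eps) * x_fake on straight lines between a real sample x and a generated sample x_fake, with eps drawn uniformly from [0, 1]. For character sequences this is commonly done on per-step probability vectors (one-hot real characters vs. generator softmax outputs), which only works when both sequences have the same length and vocabulary. The pure-Python sketch below shows just the interpolation step; the penalty itself needs gradients of the critic and an autodiff framework, so it is omitted. Function names are illustrative.

```python
import random


def interpolate(real, fake, eps):
    """Return eps * real + (1 - eps) * fake, computed step by step.

    real and fake are sequences of per-character probability vectors (one-hot
    for real text, softmax outputs for generated text). Both must have the
    same length and vocabulary size, or the line between them is undefined.
    """
    if len(real) != len(fake):
        raise ValueError("sequences must have equal length")
    mixed = []
    for real_step, fake_step in zip(real, fake):
        if len(real_step) != len(fake_step):
            raise ValueError("vocabulary sizes differ")
        mixed.append([eps * r + (1 - eps) * f for r, f in zip(real_step, fake_step)])
    return mixed


def random_interpolate(real, fake, rng=random):
    """WGAN-GP style: draw eps uniformly from [0, 1) once per example."""
    return interpolate(real, fake, rng.random())
```

Because each step is a convex combination of two probability vectors, every interpolated step is itself a valid distribution over the vocabulary.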

alpha_optimizer_op not actually run?

In the Fisher GAN implementation, alpha_optimizer_op is not returned and is not run anywhere. Isn't it necessary to add the op as a control_dependency?
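The question above concerns the Lagrange multiplier in Fisher GAN (Mroueh & Sercu, 2017). There, the critic maximizes an augmented Lagrangian in its weights, while the multiplier lam (presumably what alpha_optimizer_op updates; that mapping is an assumption) takes its own step lam <- lam - rho * (1 - omega). The toy sketch below uses a one-parameter linear critic on fixed 1-D data, not the repository's code, to show what happens when the multiplier step is skipped: lam stays at 0, only the quadratic penalty acts, and the constraint omega = 1 is not reached.

```python
def mean(values):
    values = list(values)
    return sum(values) / float(len(values))


def train_fisher_critic(real, fake, steps=5000, lr=0.1, rho=0.5, update_lambda=True):
    """Toy Fisher GAN critic f(x) = w * x trained on an augmented Lagrangian.

    L(w, lam) = (E_P f - E_Q f) + lam * (1 - omega) - rho / 2 * (1 - omega) ** 2
    with omega = 0.5 * (E_P f**2 + E_Q f**2). The critic ascends L in w; the
    multiplier takes the step lam <- lam - rho * (1 - omega).
    Returns the final (w, lam, omega).
    """
    d = mean(real) - mean(fake)  # derivative of E_P f - E_Q f w.r.t. w
    s = 0.5 * (mean(x * x for x in real) + mean(x * x for x in fake))  # omega = s * w**2
    w, lam = 0.5, 0.0
    for _ in range(steps):
        omega = s * w * w
        slack = 1.0 - omega
        grad_w = d - lam * 2.0 * s * w + rho * slack * 2.0 * s * w
        w += lr * grad_w  # critic step (gradient ascent on L)
        if update_lambda:
            lam -= rho * slack  # multiplier step, computed from the same omega
    return w, lam, s * w * w
```

On real = [0.0, 2.0] and fake = [0.0], the toy converges to omega = 1 and lam = 0.5 with the multiplier step, while with update_lambda=False omega settles well above 1 (about 1.75). In TensorFlow 1.x graph code, ensuring the multiplier update actually executes usually means running it together with the critic update, for example in the same sess.run call or grouped with tf.group.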

RNN with GANs vs Independent RNN Model

I am researching language models that can generate words.

Your results showed that RNN + GANs improves the quality of generated sequences compared to CNN+GANs.
But do you think that the combination of RNN and GANs performs better than an independent RNN model (LSTM, GRU)?
I ask because the sentences you got from RNN+GANs are not coherent, and a well-trained RNN model can do the same job. Also, when I was training your model, I found CL+VL+TH very time-consuming. So is it really worth training an RNN with GANs? Or is the purpose of this project just to show that RNNs can work well with GANs?

Thanks,
Sida Sun
