arel's Introduction

No Metrics Are Perfect: Adversarial REward Learning for Visual Storytelling

This repo is the implementation of our paper "No Metrics Are Perfect: Adversarial Reward Learning for Visual Storytelling", which also provides a codebase for the task of visual storytelling.

In the AREL paper, we not only introduce a novel adversarial reward learning algorithm to generate more human-like stories given image sequences, but also empirically analyze the limitations of the automatic metrics for story evaluation. For more details, please check the latest version of the paper: https://arxiv.org/abs/1804.09160.

Prerequisites

  • Python 2.7
  • PyTorch 0.3
  • TensorFlow (optional; only needed for TensorBoard)
  • CUDA and cuDNN

Usage

1. Setup

Clone this GitHub repository recursively:

git clone --recursive https://github.com/eric-xw/AREL.git ./

Download the preprocessed ResNet-152 features here and unzip them into DATADIR/resnet_features.

2. Supervised Learning

First, warm-start the model with a cross-entropy loss:

python train.py --id XE --data_dir DATADIR --start_rl -1

Check opts.py for more options and other settings you can experiment with.

3. AREL Learning

To train an AREL model, run

python train_AREL.py --id AREL --start_from_model PRETRAINED_MODEL

Note that PRETRAINED_MODEL can be data/save/XE/model.pth or any other saved model. Check opts.py for more information.

4. Monitor your training

TensorBoard is used to monitor the training process. If you set the checkpoint_path option to data/save, run

tensorboard --logdir data/save/tensorboard

Then open [IP address]:6006 in your browser (6006 is TensorBoard's default port).

5. Testing

To test the model's performance, run

python train.py --option test --beam_size 3 --start_from_model data/save/XE/model.pth

or

python train_AREL.py --option test --beam_size 3 --start_from_model data/save/AREL/model.pth

Reproducing our results

We uploaded our checkpoints and meta files to the IRL-ini-iter100-* folders. Load a model from one of these folders by running

python train.py --option test --beam_size 3 --start_from_model [best_model_path]

If you find this code useful, please cite the paper:

@InProceedings{xwang-2018-AREL,
  author = 	"Wang, Xin and Chen, Wenhu and Wang, Yuan-Fang and Wang, William Yang",
  title = 	"No Metrics Are Perfect: Adversarial Reward Learning for Visual Storytelling",
  booktitle = 	"Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
  year = 	"2018",
  publisher = 	"Association for Computational Linguistics",
  pages = 	"899--909",
  location = 	"Melbourne, Australia",
  url = 	"http://aclweb.org/anthology/P18-1083"
}

Acknowledgement

arel's People

Contributors

eric-xw

arel's Issues

The pretrained resnet152 from torchvision doesn't work

Sorry to bother you. I trained the model as you did and then tried to test it on some images from the Internet. I used the pretrained ResNet-152 model from torchvision, but the model doesn't work. I would appreciate it if you could tell me more about the ResNet-152 you used. Thanks for your generosity in sharing the code; it really helps. With best wishes!

Resuming doesn't work

Hello,

In train_AREL.py, the resume_from option doesn't work because disc_optimizer is not saved in the checkpoint during training.

Though I guess this is trivial to fix!
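A minimal sketch of the fix the issue suggests: save the state of every optimizer, including disc_optimizer, in the same checkpoint and restore them all together. It uses pickle and duck typing (any object with state_dict() and load_state_dict()) so it runs without PyTorch; with PyTorch, the same dictionary could be written with torch.save and read back with torch.load. The helper names, and every keyword name except disc_optimizer, are hypothetical.

```python
import pickle


def save_checkpoint(path, **components):
    """Save the state_dict() of every named component in one file (hypothetical helper)."""
    state = {name: obj.state_dict() for name, obj in components.items()}
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(path, **components):
    """Restore every named component; fail loudly if one was never saved."""
    with open(path, "rb") as f:
        state = pickle.load(f)
    missing = sorted(name for name in components if name not in state)
    if missing:
        # This is the situation the issue describes: disc_optimizer absent from the file.
        raise KeyError("checkpoint has no state for: " + ", ".join(missing))
    for name, obj in components.items():
        obj.load_state_dict(state[name])
```

For example, saving with save_checkpoint(path, model=model, optimizer=optimizer, disc=disc, disc_optimizer=disc_optimizer) and loading with the same keywords would let resume_from pick up both optimizers. Failing on a missing entry, rather than silently skipping it, makes an incomplete checkpoint obvious at resume time.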

KeyError: "tensor(159, device='cuda:0')"

Hi,

In step 2 (Supervised Learning), I warm-started the model by running:
python train.py --id XE --data_dir DATADIR --start_rl -1

Then I got this error:

INFO     Epoch 1 - Iter 369 / 627, loss = 4.03909, time used = 0.135s
INFO     Epoch 1 - Iter 370 / 627, loss = 4.01009, time used = 0.135s
INFO     Epoch 1 - Iter 371 / 627, loss = 3.95547, time used = 0.135s
INFO     Epoch 1 - Iter 372 / 627, loss = 4.04058, time used = 0.135s
INFO     Evaluating...
Traceback (most recent call last):
  File "train.py", line 193, in <module>
    train(opt)
  File "train.py", line 152, in train
    val_loss, predictions, metrics = evaluator.eval_story(model, crit, dataset, val_loader, opt)
  File "/data/lihaozheng/AREL/eval_utils.py", line 151, in eval_story
    stories = utils.decode_story(dataset.get_vocab(), results)
  File "/data/lihaozheng/AREL/misc/utils.py", line 59, in decode_story
    txt = txt + ' ' + id2word[str(vocab_id)]
KeyError: "tensor(159, device='cuda:0')"

What can I do to fix it?

Thanks a lot!
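The key that fails is the printed form of a tensor, "tensor(159, device='cuda:0')", rather than '159'. This likely means vocab_id is a zero-dimensional tensor instead of a Python int, which is what indexing a tensor returns in PyTorch versions newer than the 0.3 listed under Prerequisites; the later KeyError showing '[torch.cuda.LongTensor of size () (GPU 0)]' looks like the same problem. The usual workaround is to convert the id before building the key, e.g. id2word[str(int(vocab_id))] or id2word[str(vocab_id.item())]. A minimal sketch of a decoder written that way; decode_ids and the end_id=0 convention are assumptions, not the repo's decode_story.

```python
def decode_ids(id2word, ids, end_id=0):
    """Map token ids to words, stopping at end_id (an assumed end/padding id).

    id2word uses string keys, as the failing line id2word[str(vocab_id)] implies.
    """
    words = []
    for vocab_id in ids:
        # int() accepts Python ints, NumPy integers and one-element PyTorch tensors,
        # so the key becomes '159' rather than "tensor(159, device='cuda:0')".
        idx = int(vocab_id)
        if idx == end_id:
            break
        words.append(id2word[str(idx)])
    return " ".join(words)
```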

About the detailed options

I followed the instructions and trained the model successfully, but got scores a bit lower than those reported in the paper. Everything is set to the defaults in 'opts.py'.
I got CIDEr: 7.8, B-4: 11.7, ROUGE: 29.6, METEOR: 34.5.
Could you tell me how to adjust the parameters to achieve the performance reported in the paper?
Thank you.

OSError: [Errno 22] Invalid argument

Hello,
When I run train.py or train_AREL.py, I always hit an error during evaluation, at the step where the scorers are set up. Could you give me any advice?

setting up scorers...
INFO Evaluate iter 75/78 96.15%. Time used: 0.435925960541
INFO Evaluate iter 76/78 97.44%. Time used: 0.435518026352
INFO Evaluate iter 77/78 98.72%. Time used: 0.423863887787
Traceback (most recent call last):
  File "/Users/ray/Desktop/VIST_Coding/AREL/train.py", line 193, in <module>
    train(opt)
  File "/Users/ray/Desktop/VIST_Coding/AREL/train.py", line 152, in train
    val_loss, predictions, metrics = evaluator.eval_story(model, crit, dataset, val_loader, opt)
  File "/Users/ray/Desktop/VIST_Coding/AREL/eval_utils.py", line 164, in eval_story
    metrics = self.measure() # compute all the language metrics
  File "/Users/ray/Desktop/VIST_Coding/AREL/eval_utils.py", line 91, in measure
    self.eval.evaluate(self.reference, predictions)
  File "/Users/ray/Desktop/VIST_Coding/AREL/vist_eval/album_eval.py", line 31, in evaluate
    (Meteor(), "METEOR"),
  File "/Users/ray/Desktop/VIST_Coding/AREL/vist_eval/meteor/meteor.py", line 27, in __init__
    stderr=subprocess.PIPE)
  File "/Users/ray/anaconda2/envs/pytorch_gpu_0.3/lib/python2.7/subprocess.py", line 390, in __init__
    errread, errwrite)
  File "/Users/ray/anaconda2/envs/pytorch_gpu_0.3/lib/python2.7/subprocess.py", line 1000, in _execute_child
    data = _eintr_retry_call(os.read, errpipe_read, 1048576)
  File "/Users/ray/anaconda2/envs/pytorch_gpu_0.3/lib/python2.7/subprocess.py", line 121, in _eintr_retry_call
    return func(*args)
OSError: [Errno 22] Invalid argument

The results of AREL (best) in Table 2

Thank you for sharing the code.
The results of AREL (best) in Table 2 are 63.8, 39.1, 23.2, 14.1, 35.0, 29.5, and 9.4, but the results I reproduced are 60.0, 36.7, 22.0, 13.5, 35.2, 29.7, and 8.2. Can you tell me how I can get the results shown in Table 2? Thank you very much!

About the RL training option

Hello,
When I use the XE-ss-init you uploaded as the base model to train RL models (CIDEr and METEOR), I can't reach the scores reported in the paper. Could you tell me the detailed settings you used when training the RL models (not AREL), such as the learning rate and rl_weight?
One more question: as shown in your paper, why does using CIDEr as the reward fail to improve any metric, even CIDEr itself?

Checkpoint after AREL training

Could you please release the checkpoint after RL fine-tuning? As far as I can tell, you have only released the checkpoint after the CE training.

Some questions about the "sample_results" files

(1) It seems that you only provided a single human rating for each Turing test and pairwise comparison. Or did you combine the results of the 5 workers?

(2) What do the 0 and 1 in the CSV files represent? Your paper reports three outcomes (Win, Tie, Lose) for the comparison, but only 0 and 1 appear here. Is there a mistake?

(3) Could you provide your full human evaluation data? I would like to do further analysis.

Could you share the trained Reward Model?

Hello,
Thanks for your work. When I train with train_AREL.py, I can't reach the scores reported in your paper. Could you share the trained reward model so that I can use it to train the policy model on my own?

I use the "IRL-XE-ss-init" you shared as PRETRAINED_MODEL for AREL training.
The settings are the defaults in opts.py. After training for 50 epochs, I ran testing with beam size = 3 and got the following scores, which are even worse than IRL-XE-ss-init:

computing Bleu score ...
{'testlen': 42401, 'reflen': 43988, 'guess': [42401, 41391, 40381, 39371], 'correct': [27334, 9991, 3414, 1354]}
ratio: 0.9639219787214477
Bleu_1: 0.621
Bleu_2: 0.380
Bleu_3: 0.227
Bleu_4: 0.140
computing METEOR score ...
METEOR: 0.348
computing Rouge score ...
ROUGE_L: 0.292
computing CIDEr score ...
CIDEr: 0.083
Test finished. Time used: 504.9425232410431
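The counts in this log are enough to recompute the Bleu_n values. A minimal sketch using the standard corpus-BLEU formula, BLEU_n = BP * (p_1 * ... * p_n)^(1/n), where p_k = correct[k] / guess[k] and the brevity penalty is BP = exp(1 - reflen / testlen) when testlen < reflen. It ignores the tiny smoothing constants some scorer implementations add, and the printed ratio is simply testlen / reflen. With the numbers above it reproduces Bleu_1 to Bleu_4 to three decimals.

```python
import math


def bleu_from_counts(testlen, reflen, guess, correct):
    """Recompute corpus BLEU_1..BLEU_N from clipped n-gram match counts.

    guess[k] / correct[k]: number of candidate (k+1)-grams and how many match a reference.
    """
    # Brevity penalty: only applied when the candidates are shorter than the references.
    bp = math.exp(1.0 - float(reflen) / testlen) if testlen < reflen else 1.0
    scores = []
    log_precision_sum = 0.0
    for n, (c, g) in enumerate(zip(correct, guess), start=1):
        log_precision_sum += math.log(float(c) / g)
        # Geometric mean of p_1..p_n, scaled by the brevity penalty.
        scores.append(bp * math.exp(log_precision_sum / n))
    return scores


scores = bleu_from_counts(42401, 43988,
                          [42401, 41391, 40381, 39371],
                          [27334, 9991, 3414, 1354])
```

For this run BP is about 0.963: the generated stories are about 3.6% shorter than the references, so every BLEU_n is scaled down by roughly 3.7%.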

What is the model-best.pth in the repo?

Hello,

Firstly, thank you for open-sourcing the code!

I was wondering what the model-best.pth included under data/save/IRL-XE-ss-init/ is. In the publication, you mention the following:

strong baseline model (XE-ss)

and

use the XE-ss model to initialize our policy model and further train it with AREL

So I presume it is the already pre-trained policy model? If so, can I skip directly to Step 3 of the Usage guide in the README?

Thank you!

GAN vs AREL?

Hello,

I am trying to understand how the AREL loss calculation differs from that of GANs.

For the policy model, the loss would still be:
rl_weight * loss_rl(generated story) + (1 - rl_weight) * loss_ce(generated story)

For the reward model, the loss would be:
loss = -true_story_score + generated_story_score

How is this different from the AREL objectives?

Thank you!
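The two losses exactly as written in the question can be sketched in plain Python for a batch of scores. The reward-model loss falls when human-written stories score higher and generated stories score lower, while the policy loss blends the RL and cross-entropy terms by rl_weight. The function names are hypothetical, and summing over the batch mirrors the torch.sum calls in the train_AREL.py snippet quoted in a later issue; this restates the question's formulas, not the paper's derivation.

```python
def policy_loss(loss_rl, loss_ce, rl_weight):
    """rl_weight * loss_rl + (1 - rl_weight) * loss_ce, as written in the question."""
    if not 0.0 <= rl_weight <= 1.0:
        raise ValueError("rl_weight must lie in [0, 1]")
    return rl_weight * loss_rl + (1.0 - rl_weight) * loss_ce


def reward_model_loss(true_scores, generated_scores):
    """-sum(scores of human stories) + sum(scores of generated stories)."""
    return -sum(true_scores) + sum(generated_scores)
```

For example, policy_loss(2.0, 4.0, 0.25) gives 3.5, and reward_model_loss([1.0, 3.0], [0.5, 0.5]) gives -3.0.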

how to get a new pretrained embedding.npy

Hello, I want to know how to get a pretrained embedding.npy if I expand id2words and words2id. Since I now have more words, can VIST/embedding.npy still be used to run the code? Thanks!
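Assuming row i of embedding.npy holds the vector for the word with id i, one common approach is to keep the existing rows, take vectors for new words from a pretrained table (for example GloVe) when one is available, and initialize the rest randomly. A minimal pure-Python sketch under those assumptions; the function name, the word2id-style dictionaries, and the uniform(-0.1, 0.1) range are hypothetical. The model's vocabulary size and embedding layer would also have to match the new table.

```python
import random


def extend_embeddings(old_rows, old_word2id, new_word2id, dim, pretrained=None, seed=0):
    """Build an embedding table (list of rows, one per new id) for an expanded vocabulary."""
    if sorted(new_word2id.values()) != list(range(len(new_word2id))):
        raise ValueError("new_word2id ids must run from 0 to V-1 without gaps")
    rng = random.Random(seed)
    pretrained = pretrained or {}
    rows = [None] * len(new_word2id)
    for word, new_id in new_word2id.items():
        if word in old_word2id:
            vec = list(old_rows[old_word2id[word]])  # reuse the existing vector
        elif word in pretrained:
            vec = list(pretrained[word])  # e.g. an external pretrained vector
        else:
            vec = [rng.uniform(-0.1, 0.1) for _ in range(dim)]  # small random init
        if len(vec) != dim:
            raise ValueError("vector for %r has length %d, expected %d" % (word, len(vec), dim))
        rows[new_id] = vec
    return rows
```

The result can then be written with numpy.save("embedding.npy", numpy.asarray(rows, dtype="float32")).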

Unzip resnet features error

Hi, @eric-xw. When I unzipped the ResNet features file on my Mac, it showed an error: "error2: no such file or directory". I suspect the ResNet features file may be damaged. Can you help me with this? Thank you!

About the operation in updating policy net

Can you explain the purpose of computing the value loss and the action loss when you update the policy net? I don't think the way the net is updated is consistent with the formula in your paper.

Or am I misunderstanding something?

Reward calculated for training Generator?

In train_AREL.py, the reward for training the generator is calculated as:

rewards = Variable(gen_score.data - 0 * normed_seq_log_probs.data)

Why do you subtract 0 * normed_seq_log_probs.data? In the commit history, I noticed that you used 0.0001 * normed_seq_log_probs.data.

I think this corresponds to Eq. (9) in the original paper, where normed_seq_log_probs might be log π(W), so the coefficient should be 1. Could you explain your reasoning?
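To make the question concrete: the line computes a per-story reward r = gen_score - c * normed_seq_log_probs with c = 0 (0.0001 in the earlier commit), and wrapping the result in Variable(... .data) keeps gradients from flowing through the reward. A minimal pure-Python sketch of that reward and the usual REINFORCE surrogate loss; coeff and both function names are hypothetical. With coeff = 0 the reward is just the reward model's score; with coeff > 0, the -coeff * log π(W) term favors less likely samples, which acts like an entropy bonus because the expected value of -log π(W) under the policy is its entropy.

```python
def shaped_rewards(gen_scores, seq_log_probs, coeff):
    """r = score - coeff * log pi(W) for each sampled story (treated as constants)."""
    return [score - coeff * lp for score, lp in zip(gen_scores, seq_log_probs)]


def reinforce_loss(rewards, seq_log_probs):
    """Surrogate loss -mean(r * log pi(W)); minimizing it raises log pi for high-reward stories."""
    return -sum(r * lp for r, lp in zip(rewards, seq_log_probs)) / len(rewards)
```

For example, with scores [1.0, 2.0] and log-probabilities [-0.5, -1.0], coeff = 0 leaves the rewards at [1.0, 2.0], while coeff = 1 gives [1.5, 3.0] and a surrogate loss of 1.875.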

NameError: global name 'gt_score' is not defined ?

When I run AREL learning, I hit this error in the following part of train_AREL.py:

if flag.flag == "Disc":
    gt_prob = disc(target.view(-1, target.size(2)), feature_fc.view(-1, feature_fc.size(2)))
    loss = -torch.sum(gt_prob) + torch.sum(gen_score)

    avg_pos_prob = torch.mean(gt_score)
    avg_neg_prob = torch.mean(gen_score)

Should this be gt_prob instead of gt_score? Or is gt_score assigned in some other file?

Number of epochs?

Hello,

Could you disclose the number of epochs used for AREL training?

I could not find this information in the paper: https://arxiv.org/pdf/1804.09160.pdf

Though I could assume it is 100 (max_epochs), it would be more reliable to hear it from you.

the purpose of mask in train_AREL.py

Hello, could you tell me the purpose of this statement in train_AREL.py?

mask = to_contiguous( torch.cat([Variable(mask.data.new(mask.size(0), mask.size(1), 1).fill_(1)), mask[:, :, :-1]], 2))
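A small illustration of what the statement appears to do, assuming mask marks real (non-padding) tokens with 1 along its last dimension: concatenating a column of ones in front of mask[:, :, :-1] shifts the mask one step to the right, so the first position is always kept and the first position after the last real word (presumably where the end token is generated) is also counted in the loss. to_contiguous only makes the tensor's memory layout contiguous. This pure-Python version on nested lists is a hypothetical stand-in for the tensor code.

```python
def shift_mask_right(mask):
    """Prepend a 1 and drop the last entry along the innermost dimension.

    Works on nested lists of any depth, e.g. [batch][story][word] like the 3-D tensor.
    """
    if not mask:
        return []
    if isinstance(mask[0], list):
        return [shift_mask_right(m) for m in mask]
    return [1] + mask[:-1]
```

For a sentence of two words followed by padding, [1, 1, 0, 0] becomes [1, 1, 1, 0]: the third position, where the end token would sit, now contributes to the loss.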

EOFException happened when unzip the ResNet-152 features file

Hello, @eric-xw
When I downloaded the ResNet-152 features file (VIST_resnet_features.zip) to my Mac, the file size was 14.37 GB, and the "unzip" command could not extract it. So I tried "jar xvf VIST_resnet_features.zip" instead, but an error occurred:

java.io.EOFException: Unexpected end of ZLIB input stream
    at java.util.zip.InflaterInputStream.fill(InflaterInputStream.java:240)
    at java.util.zip.InflaterInputStream.read(InflaterInputStream.java:158)
    at java.util.zip.ZipInputStream.read(ZipInputStream.java:194)
    at java.util.zip.ZipInputStream.closeEntry(ZipInputStream.java:140)
    at sun.tools.jar.Main.extractFile(Main.java:1072)
    at sun.tools.jar.Main.extract(Main.java:981)
    at sun.tools.jar.Main.run(Main.java:311)
    at sun.tools.jar.Main.main(Main.java:1288)

I still got a "resnet_features" folder containing 6.69 GB of files.

I downloaded VIST_resnet_features.zip again, but this time I got a 5.56 GB zip file. I ran jar again, the same java.io.EOFException occurred, and the extracted folder again contained 6.69 GB of files.

I have two questions about this error. Is the java.io.EOFException expected? What are the actual sizes of VIST_resnet_features.zip and the extracted folder?

By the way, the java.io.EOFException also happened on my Ubuntu server.
This question has been haunting me for a long time, and I would appreciate it if someone could answer it.
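Getting different file sizes from two downloads, plus "Unexpected end of ZLIB input stream", is consistent with a truncated download. One way to check an archive before extracting it is Python 3's standard zipfile module: opening fails with BadZipFile when the directory at the end of the archive is missing, and ZipFile.testzip() reads every member and reports the first one whose CRC does not match. A minimal sketch; check_zip is a hypothetical helper, not part of the repo.

```python
import zipfile
import zlib


def check_zip(path):
    """Return (ok, message) after opening the archive and CRC-checking every member.

    Reads the whole archive (slow for a file of many GB) but writes nothing to disk.
    """
    try:
        with zipfile.ZipFile(path) as zf:
            first_bad = zf.testzip()  # name of the first corrupt member, or None
    except (zipfile.BadZipFile, EOFError, zlib.error, OSError) as exc:
        # A truncated download usually loses the central directory stored at the
        # end of the file, so opening fails before any member is read.
        return False, "unreadable archive: %s" % exc
    if first_bad is not None:
        return False, "CRC check failed for member %r" % first_bad
    return True, "all members passed the CRC check"
```

The same test is available from the shell as python3 -m zipfile -t VIST_resnet_features.zip. If the archive passes, a failure of the unzip command may instead point to a tool limitation, such as an unzip build without ZIP64 support for archives larger than 4 GB.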

KeyError: '\n 3156\n[torch.cuda.LongTensor of size () (GPU 0)]\n'

Hello,
When I run train.py and train_AREL.py, I have a problem. Could you please help me solve it?

INFO Initialize the parameters of the model
INFO pos reward -0.0206863172352 neg reward -0.0734111517668
Traceback (most recent call last):
  File "/Users/ray/Downloads/AREL/train_AREL.py", line 230, in <module>
    train(opt)
  File "/Users/ray/Downloads/AREL/train_AREL.py", line 145, in train
    print("PREDICTION: ", utils.decode_story(dataset.get_vocab(), seq[:1].data)[0])
  File "/Users/ray/Downloads/AREL/misc/utils.py", line 58, in decode_story
    txt = txt + ' ' + id2word[str(vocab_id)]
KeyError: '\n 3156\n[torch.cuda.LongTensor of size () (GPU 0)]\n'
