
RewriteNAT

This repo provides the code for reproducing RewriteNAT, proposed in our EMNLP 2021 paper "Learning to Rewrite for Non-Autoregressive Neural Machine Translation". RewriteNAT is an iterative NAT model that uses a locator component to explicitly learn to rewrite erroneous translation pieces during iterative decoding.

Dependencies
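The exact dependency versions are not pinned here; the commands below mirror the environment users report in the Issues section further down (Python 3.6, PyTorch 1.2.0, CUDA 10.0) and are a sketch rather than an official recipe:

conda create -n RewriteNAT python=3.6
conda activate RewriteNAT
conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
cd RewriteNAT
pip install --editable ./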

Preprocessing

All the datasets are tokenized using the scripts from Moses, except for Chinese, which is tokenized with Jieba, and split into subword units using BPE. The tokenized datasets are binarized using the script binaried.sh as follows:

python preprocess.py \
    --source-lang ${src} --target-lang ${tgt} \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/${dataset} --thresholdtgt 0 --thresholdsrc 0 \
    --workers 64 --joined-dictionary
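For reference, the tokenize-then-BPE pipeline that precedes binarization might look like the following. This is a sketch using the Moses tokenizer and subword-nmt; the merge count and file paths are assumptions, and since the binarization above uses a joined dictionary, the BPE codes are learned on the concatenated source and target:

# Moses tokenization (paths are illustrative)
perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l ${src} < train.${src} > train.tok.${src}
perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l ${tgt} < train.${tgt} > train.tok.${tgt}
# learn joint BPE codes (32K merges is an assumption) and apply them
cat train.tok.${src} train.tok.${tgt} | subword-nmt learn-bpe -s 32000 > bpe.codes
subword-nmt apply-bpe -c bpe.codes < train.tok.${src} > $TEXT/train.${src}
subword-nmt apply-bpe -c bpe.codes < train.tok.${tgt} > $TEXT/train.${tgt}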

Train

All the models are trained on 8 Tesla V100 GPUs for 300,000 updates with an effective batch size of 128,000 tokens (--max-tokens 4000 × --update-freq 4 × 8 GPUs), apart from En→Fr, where we make 500,000 updates to account for the data size. The training script train.rewrite.nat.sh is configured as follows:

python train.py \
    data-bin/${dataset} \
    --source-lang ${src} --target-lang ${tgt} \
    --save-dir ${save_dir} \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion rewrite_nat_loss \
    --arch rewrite_nonautoregressive_transformer \
    --noise full_mask \
    ${share_all_embeddings} \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 4000 \
    --save-interval-updates 10000 \
    --max-update ${step} \
    --update-freq 4 \
    --fp16 \
    --save-interval ${save_interval} \
    --discriminator-layers 6 \
    --train-max-iter ${max_iter} \
    --roll-in-g sample \
    --roll-in-d oracle \
    --imitation-g \
    --imitation-d \
    --discriminator-loss-factor ${discriminator_weight} \
    --no-share-discriminator \
    --generator-scale ${generator_scale} \
    --discriminator-scale ${discriminator_scale}
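The script leaves several shell variables to the caller. A plausible configuration for WMT'14 En→De is sketched below; the loss-weight and scale values are illustrative assumptions, not values taken from the paper:

dataset=wmt14.en-de
src=en
tgt=de
save_dir=checkpoints
share_all_embeddings="--share-all-embeddings"  # drop for pairs without a shared vocabulary
step=300000                 # 500,000 for En→Fr
max_iter=4                  # maximum rewriting iterations during training
save_interval=1
discriminator_weight=0.5    # assumption: weight of the locator (discriminator) loss
generator_scale=1.0         # assumption
discriminator_scale=1.0     # assumption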

Evaluation

We evaluate performance with BLEU for all language pairs, except for En→Zh, where we use SacreBLEU. The testing script test.rewrite.nat.sh is used to generate the translations, as follows:

python generate.py \
    data-bin/${dataset} \
    --source-lang ${src} --target-lang ${tgt} \
    --gen-subset ${subset} \
    --task translation_lev \
    --path ${save_dir}/${dataset}/checkpoint_average_${suffix}.pt \
    --iter-decode-max-iter ${max_iter} \
    --iter-decode-with-beam ${beam} \
    --iter-decode-p ${iter_p} \
    --beam 1 --remove-bpe \
    --batch-size 50 \
    --print-step \
    --quiet
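The --path argument above points to an averaged checkpoint (checkpoint_average_${suffix}.pt). fairseq ships scripts/average_checkpoints.py for producing one; below is a sketch that averages the last five epoch checkpoints (the checkpoint count and the top5 suffix are assumptions):

python scripts/average_checkpoints.py \
    --inputs ${save_dir}/${dataset} \
    --num-epoch-checkpoints 5 \
    --output ${save_dir}/${dataset}/checkpoint_average_top5.pt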

Citation

Please cite as:

@inproceedings{geng-etal-2021-learning,
    title = "Learning to Rewrite for Non-Autoregressive Neural Machine Translation",
    author = "Geng, Xinwei and Feng, Xiaocheng and Qin, Bing",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    month = nov,
    year = "2021",
    address = "Online and Punta Cana, Dominican Republic",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.emnlp-main.265",
    pages = "3297--3308",
}

rewritenat's Issues

How to install dependencies?

Hi all,

Thanks for your awesome work and codes!

I tried to run the code and used the following commands to build the environment:

conda create -n RewriteNAT python=3.6
conda activate RewriteNAT
conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
cd RewriteNAT
pip install --editable ./

When running pip install --editable ./, I got errors like this:

    ERROR: Command errored out with exit status 1:
     command: /home/azureuser/miniconda3/envs/rewrite/bin/python -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/home/azureuser/rewrite/setup.py'"'"'; __file__='"'"'/home/azureuser/rewrite/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' develop --no-deps
         cwd: /home/azureuser/rewrite/
    Complete output (14 lines):
    running develop
    running egg_info
    creating fairseq.egg-info
    writing fairseq.egg-info/PKG-INFO
    writing dependency_links to fairseq.egg-info/dependency_links.txt
    writing entry points to fairseq.egg-info/entry_points.txt
    writing requirements to fairseq.egg-info/requires.txt
    writing top-level names to fairseq.egg-info/top_level.txt
    writing manifest file 'fairseq.egg-info/SOURCES.txt'
    reading manifest file 'fairseq.egg-info/SOURCES.txt'
    writing manifest file 'fairseq.egg-info/SOURCES.txt'
    running build_ext
    cythoning fairseq/data/data_utils_fast.pyx to fairseq/data/data_utils_fast.cpp
    error: /home/azureuser/rewrite/fairseq/data/data_utils_fast.pyx
    ----------------------------------------
ERROR: Command errored out with exit status 1: /home/azureuser/miniconda3/envs/rewrite/bin/python -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/home/azureuser/rewrite/setup.py'"'"'; __file__='"'"'/home/azureuser/rewrite/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' develop --no-deps Check the logs for full command output.

Could you give me some instructions to build your environment?

Thanks,
hemingkx

About training

Hi,
When I train with the default hyperparameters on the distilled IWSLT'14 De-En dataset:

[screenshot of training configuration]

I got this:

[screenshot of training error]

I tried train_max_iter at both 2 and 4, but I always hit the problem above. Have I made a mistake somewhere, or could you give some advice?

ModuleNotFoundError: No module named 'fairseq.data.append_token_dataset'

Hi~ thank you for sharing the codes.

I installed the dependencies with the following commands:

conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
cd RewriteNAT
pip install --editable ./

following Issue 1, but I got this error when I tried to run preprocess.py.

(rewrite_nat) [wbxu@cu10 RewriteNAT]$ bash preprocess.sh
Traceback (most recent call last):
  File "preprocess.py", line 7, in <module>
    from fairseq_cli.preprocess import cli_main
  File "/data/wbxu/RewriteNAT/fairseq_cli/preprocess.py", line 18, in <module>
    from fairseq import options, tasks, utils
  File "/data/wbxu/RewriteNAT/fairseq/__init__.py", line 9, in <module>
    import fairseq.criterions  # noqa
  File "/data/wbxu/RewriteNAT/fairseq/criterions/__init__.py", line 10, in <module>
    from fairseq.criterions.fairseq_criterion import FairseqCriterion
  File "/data/wbxu/RewriteNAT/fairseq/criterions/fairseq_criterion.py", line 10, in <module>
    from fairseq import metrics, utils
  File "/data/wbxu/RewriteNAT/fairseq/utils.py", line 21, in <module>
    from fairseq.modules import gelu, gelu_accurate
  File "/data/wbxu/RewriteNAT/fairseq/modules/__init__.py", line 9, in <module>
    from .character_token_embedder import CharacterTokenEmbedder
  File "/data/wbxu/RewriteNAT/fairseq/modules/character_token_embedder.py", line 13, in <module>
    from fairseq.data import Dictionary
  File "/data/wbxu/RewriteNAT/fairseq/data/__init__.py", line 12, in <module>
    from .append_token_dataset import AppendTokenDataset
ModuleNotFoundError: No module named 'fairseq.data.append_token_dataset'

Could you tell how to deal with this error? Thank you very much~

About Performance

Hi,
Thanks for your awesome work!
When I train with the default hyperparameters on the WMT'14 En-De data, I got:

  • train-max-iter=2, 3 days on 8 × V100: SacreBLEU 18.0 (1 iter), 24.1 (2 iters), 25.5 (5 iters), and 25.6 (10 iters).
  • train-max-iter=4, 6 days on 8 × V100: SacreBLEU 17.0 (1 iter), 24.1 (2 iters), 25.9 (5 iters), and 26.0 (10 iters).

I wonder if there is something wrong with my training and testing scripts. Here are my scripts:

Training:

max_iter=4

src=en
tgt=de

step=300000

share_all_embeddings="--share-all-embeddings"

save_interval=1

python train.py \
    ${dataset} \
    --source-lang ${src} --target-lang ${tgt} \
    --save-dir /mnt/exp/project \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion rewrite_nat_loss \
    --arch rewrite_nonautoregressive_transformer \
    --noise full_mask \
    ${share_all_embeddings} \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 4000 \
    --save-interval-updates 10000 \
    --max-update ${step} \
    --update-freq 4 \
    --fp16 \
    --discriminator-layers 6 \
    --train-max-iter ${max_iter} \
    --roll-in-g sample \
    --roll-in-d oracle \
    --imitation-g \
    --imitation-d \
    --no-share-discriminator \
    --reset-optimizer \
    --reset-meters \
    --reset-dataloader \
    --reset-lr-scheduler

Testing:

src=en
tgt=de

subset=test

max_iter=1

beam=1

iter_p=0.5

python generate.py \
    ${dataset} \
    --source-lang ${src} --target-lang ${tgt} \
    --gen-subset ${subset} \
    --task translation_lev \
    --criterion rewrite_nat_loss \
    --path ${save_dir} \
    --iter-decode-max-iter ${max_iter} \
    --iter-decode-with-beam ${beam} \
    --iter-decode-p ${iter_p} \
    --beam 1 --remove-bpe \
    --batch-size 25 \
    --print-step

Thanks very much!
hemingkx
