
disco's Introduction

DisCo Transformer for Non-autoregressive MT

Download trained DisCo transformers

All models are trained with distillation. See our paper.

Pretrained Models
WMT14 English-German, WMT14 German-English, WMT16 English-Romanian, WMT16 Romanian-English,
WMT17 English-Chinese, WMT17 Chinese-English, WMT14 English-French

We also provide our knowledge distillation data for WMT14 EN-DE.

Preprocess

text=PATH_YOUR_DATA

output_dir=PATH_YOUR_OUTPUT

src=source_language

tgt=target_language

model_path=PATH_TO_MASKPREDICT_MODEL_DIR

python preprocess.py --source-lang ${src} --target-lang ${tgt} --trainpref $text/train \
    --validpref $text/valid --testpref $text/test  --destdir ${output_dir}/data-bin \
    --workers 60 --srcdict ${model_path}/maskPredict_${src}_${tgt}/dict.${src}.txt \
    --tgtdict ${model_path}/maskPredict_${src}_${tgt}/dict.${tgt}.txt

Train

model_dir=PLACE_TO_SAVE_YOUR_MODEL

python train.py ${output_dir}/data-bin --arch disco_transformer \
    --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 --lr 5e-4 \
    --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
    --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --task translation_self \
    --max-tokens 8192 --weight-decay 0.01 --dropout 0.2 --encoder-layers 6 --encoder-embed-dim 512 \
    --decoder-layers 6 --decoder-embed-dim 512 --fp16 --max-source-positions 10000 \
    --max-target-positions 10000 --max-update 300000 --seed 1 \
    --save-dir ${model_dir} --dynamic-masking  --ignore-eos-loss --share-all-embeddings
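
For intuition about the learning-rate flags above, here is a minimal sketch of an inverse-square-root schedule with linear warmup. It assumes the usual fairseq-style rule (linear warmup from `--warmup-init-lr` to `--lr` over `--warmup-updates`, then decay proportional to the inverse square root of the update number). The function name `inverse_sqrt_lr` is hypothetical and not part of this repository.

```python
import math


def inverse_sqrt_lr(step, peak_lr=5e-4, warmup_init_lr=1e-7, warmup_updates=10000):
    """Approximate learning rate at a given update (assumed fairseq-style rule)."""
    if step < warmup_updates:
        # Linear warmup from warmup_init_lr up to peak_lr.
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # After warmup, decay proportionally to 1 / sqrt(step).
    return peak_lr * math.sqrt(warmup_updates / step)


for step in (0, 5000, 10000, 40000, 300000):
    print(step, inverse_sqrt_lr(step))
```

With the defaults taken from the command above, the rate peaks at 5e-4 at update 10000 and has halved by update 40000.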

Evaluation

We provide two inference methods:

  1. Parallel Easy-First (Kasai et al., 2020)
python generate_disco.py ${output_dir}/data-bin --path ${model_dir}/checkpoint_best.pt \
    --task translation_self --remove-bpe --max-sentences 20 --decoding-iterations 10 \
    --decoding-strategy easy_first --length-beam 5
  2. Mask-Predict (Ghazvininejad et al., 2019)
python generate_disco.py ${output_dir}/data-bin --path ${model_dir}/checkpoint_best.pt \
    --task translation_self --remove-bpe --max-sentences 20 --decoding-iterations 10 \
    --decoding-strategy mask_predict --length-beam 5

License

DisCo is licensed under CC-BY-NC 4.0. The license applies to the trained models as well.

Citation

Please cite as:

@inproceedings{Kasai2020DisCo,
  title = {Non-autoregressive Machine Translation with Disentangled Context Transformer},
  author = {Jungo Kasai and James Cross and Marjan Ghazvininejad and Jiatao Gu},
  booktitle = {Proc. of ICML},
  year = {2020},
  url = {https://arxiv.org/abs/2001.05136},
}

Note

We based this code heavily on the original mask-predict implementation.


disco's Issues

Need suggestions for EN-DE training

Hi, Kasai ~ Great work!

For EN-DE, I get ~27.3 BLEU with your pretrained model, but when I train from scratch on your provided distilled data, I get only ~24 BLEU and the loss curve looks weird. Can you give me some advice on how to properly reproduce your results?

I trained it on 8 V100 GPUs with the following script:

python train.py ./wmt16.en-de.disco.dist/ --arch disco_transformer \
--criterion label_smoothed_length_cross_entropy \
--label-smoothing 0.1 \
--lr 5e-4 --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 \
--task translation_self \
--max-tokens 16000 \
--weight-decay 0.01 \
--dropout 0.2 \
--encoder-layers 6 --encoder-embed-dim 512 --decoder-layers 6 --decoder-embed-dim 512 \
--max-source-positions 10000  --max-target-positions 10000 \
--max-update 100000 --seed 1 \
--save-dir checkpoints/wmt16.en-de.nat_authordata \
--dynamic-masking --ignore-eos-loss --share-all-embeddings \
--keep-last-epochs 20  \
--no-progress-bar --log-format simple --log-interval 100 --save-interval-updates 2000 \
--fp16 --ddp-backend=c10d --update-freq 4
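
As a side note on the script above, the flags `--max-tokens`, `--update-freq` and the GPU count together bound how many tokens contribute to one optimizer update. The sketch below only does that arithmetic; the helper name `tokens_per_update` is hypothetical, and the result is an upper bound because `--max-tokens` is a per-batch cap rather than an exact size.

```python
def tokens_per_update(max_tokens, num_gpus, update_freq):
    """Upper bound on tokens per optimizer update under data-parallel training."""
    # Each GPU processes up to max_tokens per step, and gradients are
    # accumulated over update_freq steps before the optimizer update.
    return max_tokens * num_gpus * update_freq


# Values taken from the script in this issue: 8 GPUs, --max-tokens 16000, --update-freq 4.
issue_run = tokens_per_update(max_tokens=16000, num_gpus=8, update_freq=4)
print(issue_run)  # 512000
```

The README training command lists `--max-tokens 8192` but does not state a GPU count or `--update-freq`, so the two setups cannot be compared from the flags alone.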

The validation loss during training looks weird:

Namespace(adam_betas='(0.9, 0.999)', adam_eps=1e-06, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='disco_transformer', at_only=False, at_rm=False, attention_dropout=0.0, best_checkpoint_metric='loss', bilm_add_bos=False, bilm_attention_dropout=0.0, bilm_mask_last_state=False, bilm_model_dropout=0.1, bilm_relu_dropout=0.0, bucket_cap_mb=25, clip_norm=25, cpu=False, criterion='label_smoothed_length_cross_entropy', curriculum=0, data=['wmt16.en-de.disco.dist'], dataset_impl=None, ddp_backend='c10d', decoder_attention_heads=8, decoder_embed_dim=512, decoder_embed_path=None, decoder_embed_scale=None, decoder_ffn_embed_dim=2048, decoder_input_dim=512, decoder_layers=6, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=512, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method='tcp://localhost:12097', distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=8, dropout=0.2, dynamic_length=False, dynamic_masking=True, embedding_only=False, encoder_attention_heads=8, encoder_embed_dim=512, encoder_embed_path=None, encoder_embed_scale=None, encoder_ffn_embed_dim=2048, encoder_layers=6, encoder_learned_pos=False, encoder_normalize_before=False, find_unused_parameters=False, fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, full_masking=False, ignore_eos_loss=True, keep_interval_updates=-1, keep_last_epochs=20, label_smoothing=0.1, left_pad_source='True', left_pad_target='False', log_format='simple', log_interval=100, lr=[0.0005], lr_scheduler='inverse_sqrt', mask_range=False, maskp=False, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=10000, max_target_positions=10000, max_tokens=16000, max_tokens_valid=16000, max_update=300000, maximize_best_checkpoint_metric=False, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, mix_masking=False, 
no_dec_token_positional_embeddings=False, no_enc_token_positional_embeddings=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_progress_bar=True, no_save=False, no_save_optimizer_state=False, num_workers=0, optimizer='adam', optimizer_overrides='{}', perm_only=False, raw_text=False, relu_dropout=0.0, required_batch_size_multiple=8, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='checkpoints/wmt16.en-de.nat_authordata', save_interval=1, save_interval_updates=2000, seed=1, self_target=False, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=False, share_layers=False, skip_eos=False, skip_invalid_size_inputs_valid_test=False, source_lang=None, target_lang=None, task='translation_self', tbmf_wrapper=False, tensorboard_logdir='', threshold_loss_scale=None, train_subset='train', update_freq=[4], upsample_primary=1, use_bmuf=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=10000, weight_decay=0.01)
| $path/wmt16.en-de.disco.dist/ valid 3000 examples
| epoch 001 | valid on 'valid' subset | loss 13.232 | nll_loss 12.570 | ppl 6079.24 | num_updates 275 | length_loss 11.3722
| epoch 002 | valid on 'valid' subset | loss 12.160 | nll_loss 11.408 | ppl 2717.66 | num_updates 555 | best_loss 12.1599 | length_loss 8.86187
| epoch 003 | valid on 'valid' subset | loss 11.816 | nll_loss 11.056 | ppl 2129.57 | num_updates 834 | best_loss 11.8161 | length_loss 8.1737
| epoch 004 | valid on 'valid' subset | loss 11.327 | nll_loss 10.515 | ppl 1463.34 | num_updates 1113 | best_loss 11.3275 | length_loss 7.40185
| epoch 005 | valid on 'valid' subset | loss 10.940 | nll_loss 10.061 | ppl 1068.39 | num_updates 1393 | best_loss 10.9401 | length_loss 6.81688
| epoch 006 | valid on 'valid' subset | loss 10.546 | nll_loss 9.588 | ppl 769.47 | num_updates 1671 | best_loss 10.5464 | length_loss 6.52779
| epoch 007 | valid on 'valid' subset | loss 10.192 | nll_loss 9.159 | ppl 571.58 | num_updates 1951 | best_loss 10.1923 | length_loss 6.09935
| epoch 008 | valid on 'valid' subset | loss 10.177 | nll_loss 9.126 | ppl 558.89 | num_updates 2000 | best_loss 10.177 | length_loss 6.11036
| epoch 008 | valid on 'valid' subset | loss 9.937 | nll_loss 8.814 | ppl 450.07 | num_updates 2231 | best_loss 9.93728 | length_loss 5.94912
| epoch 009 | valid on 'valid' subset | loss 9.726 | nll_loss 8.553 | ppl 375.56 | num_updates 2510 | best_loss 9.72613 | length_loss 5.94019
| epoch 010 | valid on 'valid' subset | loss 9.608 | nll_loss 8.381 | ppl 333.39 | num_updates 2790 | best_loss 9.60826 | length_loss 5.63923
| epoch 011 | valid on 'valid' subset | loss 9.316 | nll_loss 8.035 | ppl 262.27 | num_updates 3069 | best_loss 9.31551 | length_loss 5.4725
| epoch 012 | valid on 'valid' subset | loss 8.915 | nll_loss 7.551 | ppl 187.55 | num_updates 3349 | best_loss 8.91503 | length_loss 5.32048
| epoch 013 | valid on 'valid' subset | loss 8.565 | nll_loss 7.135 | ppl 140.53 | num_updates 3629 | best_loss 8.56478 | length_loss 5.32925
| epoch 014 | valid on 'valid' subset | loss 8.234 | nll_loss 6.740 | ppl 106.93 | num_updates 3909 | best_loss 8.23377 | length_loss 5.26823
| epoch 015 | valid on 'valid' subset | loss 8.124 | nll_loss 6.609 | ppl 97.63 | num_updates 4000 | best_loss 8.1238 | length_loss 5.40804
| epoch 015 | valid on 'valid' subset | loss 7.963 | nll_loss 6.422 | ppl 85.73 | num_updates 4187 | best_loss 7.96343 | length_loss 5.74557
| epoch 016 | valid on 'valid' subset | loss 7.689 | nll_loss 6.106 | ppl 68.87 | num_updates 4467 | best_loss 7.68925 | length_loss 5.02578
| epoch 017 | valid on 'valid' subset | loss 7.498 | nll_loss 5.862 | ppl 58.16 | num_updates 4747 | best_loss 7.49752 | length_loss 5.48479
| epoch 018 | valid on 'valid' subset | loss 7.295 | nll_loss 5.646 | ppl 50.07 | num_updates 5027 | best_loss 7.2946 | length_loss 5.25694
| epoch 019 | valid on 'valid' subset | loss 7.158 | nll_loss 5.503 | ppl 45.35 | num_updates 5307 | best_loss 7.1578 | length_loss 5.05596
| epoch 020 | valid on 'valid' subset | loss 7.089 | nll_loss 5.404 | ppl 42.33 | num_updates 5586 | best_loss 7.08919 | length_loss 5.22389
| epoch 021 | valid on 'valid' subset | loss 7.011 | nll_loss 5.330 | ppl 40.23 | num_updates 5866 | best_loss 7.01102 | length_loss 5.30964
| epoch 022 | valid on 'valid' subset | loss 6.965 | nll_loss 5.272 | ppl 38.65 | num_updates 6000 | best_loss 6.96542 | length_loss 5.11045
| epoch 022 | valid on 'valid' subset | loss 6.877 | nll_loss 5.184 | ppl 36.34 | num_updates 6144 | best_loss 6.87745 | length_loss 5.0194
| epoch 023 | valid on 'valid' subset | loss 6.811 | nll_loss 5.097 | ppl 34.22 | num_updates 6424 | best_loss 6.81102 | length_loss 5.7039
| epoch 024 | valid on 'valid' subset | loss 6.763 | nll_loss 5.062 | ppl 33.41 | num_updates 6704 | best_loss 6.76268 | length_loss 5.10046
| epoch 025 | valid on 'valid' subset | loss 6.677 | nll_loss 4.964 | ppl 31.22 | num_updates 6983 | best_loss 6.67652 | length_loss 5.17359
| epoch 026 | valid on 'valid' subset | loss 6.647 | nll_loss 4.923 | ppl 30.34 | num_updates 7263 | best_loss 6.64697 | length_loss 5.66601
| epoch 027 | valid on 'valid' subset | loss 6.578 | nll_loss 4.859 | ppl 29.01 | num_updates 7542 | best_loss 6.57783 | length_loss 5.3032
| epoch 028 | valid on 'valid' subset | loss 6.451 | nll_loss 4.706 | ppl 26.10 | num_updates 7821 | best_loss 6.45056 | length_loss 5.37859
| epoch 029 | valid on 'valid' subset | loss 6.489 | nll_loss 4.750 | ppl 26.92 | num_updates 8000 | best_loss 6.45056 | length_loss 5.34189
| epoch 029 | valid on 'valid' subset | loss 6.447 | nll_loss 4.718 | ppl 26.32 | num_updates 8101 | best_loss 6.44725 | length_loss 5.08036
| epoch 030 | valid on 'valid' subset | loss 6.417 | nll_loss 4.684 | ppl 25.71 | num_updates 8381 | best_loss 6.41685 | length_loss 4.99973
| epoch 031 | valid on 'valid' subset | loss 6.345 | nll_loss 4.602 | ppl 24.28 | num_updates 8659 | best_loss 6.34513 | length_loss 5.10561
| epoch 032 | valid on 'valid' subset | loss 6.365 | nll_loss 4.612 | ppl 24.45 | num_updates 8939 | best_loss 6.34513 | length_loss 5.76396
| epoch 033 | valid on 'valid' subset | loss 6.295 | nll_loss 4.563 | ppl 23.63 | num_updates 9219 | best_loss 6.29542 | length_loss 4.95261
| epoch 034 | valid on 'valid' subset | loss 6.218 | nll_loss 4.461 | ppl 22.03 | num_updates 9499 | best_loss 6.21752 | length_loss 5.3258
| epoch 035 | valid on 'valid' subset | loss 6.199 | nll_loss 4.415 | ppl 21.33 | num_updates 9779 | best_loss 6.19946 | length_loss 6.13781
| epoch 036 | valid on 'valid' subset | loss 6.169 | nll_loss 4.426 | ppl 21.49 | num_updates 10000 | best_loss 6.16876 | length_loss 5.08362
| epoch 036 | valid on 'valid' subset | loss 6.192 | nll_loss 4.444 | ppl 21.77 | num_updates 10059 | best_loss 6.16876 | length_loss 4.96137
| epoch 037 | valid on 'valid' subset | loss 6.142 | nll_loss 4.395 | ppl 21.05 | num_updates 10338 | best_loss 6.14181 | length_loss 5.11
| epoch 038 | valid on 'valid' subset | loss 6.100 | nll_loss 4.360 | ppl 20.53 | num_updates 10617 | best_loss 6.10049 | length_loss 5.06722
| epoch 039 | valid on 'valid' subset | loss 6.117 | nll_loss 4.377 | ppl 20.77 | num_updates 10897 | best_loss 6.10049 | length_loss 5.19192
| epoch 040 | valid on 'valid' subset | loss 6.057 | nll_loss 4.315 | ppl 19.90 | num_updates 11177 | best_loss 6.05656 | length_loss 5.11511
| epoch 041 | valid on 'valid' subset | loss 6.034 | nll_loss 4.293 | ppl 19.61 | num_updates 11456 | best_loss 6.03448 | length_loss 4.86498
| epoch 042 | valid on 'valid' subset | loss 5.988 | nll_loss 4.254 | ppl 19.08 | num_updates 11736 | best_loss 5.98752 | length_loss 4.73533
| epoch 043 | valid on 'valid' subset | loss 5.989 | nll_loss 4.234 | ppl 18.82 | num_updates 12000 | best_loss 5.98752 | length_loss 5.12521
| epoch 043 | valid on 'valid' subset | loss 6.040 | nll_loss 4.292 | ppl 19.59 | num_updates 12016 | best_loss 5.98752 | length_loss 4.97918
| epoch 044 | valid on 'valid' subset | loss 6.018 | nll_loss 4.270 | ppl 19.30 | num_updates 12295 | best_loss 5.98752 | length_loss 4.90497
| epoch 045 | valid on 'valid' subset | loss 5.954 | nll_loss 4.196 | ppl 18.33 | num_updates 12575 | best_loss 5.95385 | length_loss 5.30288
| epoch 046 | valid on 'valid' subset | loss 5.990 | nll_loss 4.250 | ppl 19.03 | num_updates 12854 | best_loss 5.95385 | length_loss 4.72608
| epoch 047 | valid on 'valid' subset | loss 12.413 | nll_loss 11.636 | ppl 3183.69 | num_updates 13133 | best_loss 5.95385 | length_loss 11.388
| epoch 048 | valid on 'valid' subset | loss 11.669 | nll_loss 10.782 | ppl 1761.11 | num_updates 13413 | best_loss 5.95385 | length_loss 9.24891
| epoch 049 | valid on 'valid' subset | loss 9.240 | nll_loss 7.688 | ppl 206.19 | num_updates 13693 | best_loss 5.95385 | length_loss 11.8834
| epoch 050 | valid on 'valid' subset | loss 8.942 | nll_loss 7.535 | ppl 185.45 | num_updates 13973 | best_loss 5.95385 | length_loss 8.7949
| epoch 051 | valid on 'valid' subset | loss 8.918 | nll_loss 7.521 | ppl 183.67 | num_updates 14000 | best_loss 5.95385 | length_loss 8.71842
| epoch 051 | valid on 'valid' subset | loss 8.832 | nll_loss 7.443 | ppl 174.05 | num_updates 14253 | best_loss 5.95385 | length_loss 8.87974
| epoch 052 | valid on 'valid' subset | loss 8.751 | nll_loss 7.362 | ppl 164.50 | num_updates 14533 | best_loss 5.95385 | length_loss 8.94963
| epoch 053 | valid on 'valid' subset | loss 11.789 | nll_loss 11.018 | ppl 2074.01 | num_updates 14813 | best_loss 5.95385 | length_loss 9.54743
| epoch 054 | valid on 'valid' subset | loss 11.798 | nll_loss 11.021 | ppl 2078.72 | num_updates 15093 | best_loss 5.95385 | length_loss 9.2394
| epoch 055 | valid on 'valid' subset | loss 11.751 | nll_loss 10.963 | ppl 1995.98 | num_updates 15373 | best_loss 5.95385 | length_loss 9.27329
| epoch 056 | valid on 'valid' subset | loss 11.731 | nll_loss 10.935 | ppl 1957.32 | num_updates 15653 | best_loss 5.95385 | length_loss 9.23337
| epoch 057 | valid on 'valid' subset | loss 11.747 | nll_loss 10.945 | ppl 1971.58 | num_updates 15931 | best_loss 5.95385 | length_loss 9.15427
| epoch 058 | valid on 'valid' subset | loss 11.756 | nll_loss 10.948 | ppl 1975.56 | num_updates 16000 | best_loss 5.95385 | length_loss 9.1642
| epoch 058 | valid on 'valid' subset | loss 11.666 | nll_loss 10.847 | ppl 1841.53 | num_updates 16211 | best_loss 5.95385 | length_loss 9.10518
| epoch 059 | valid on 'valid' subset | loss 11.636 | nll_loss 10.811 | ppl 1796.58 | num_updates 16490 | best_loss 5.95385 | length_loss 9.11344
| epoch 060 | valid on 'valid' subset | loss 11.525 | nll_loss 10.682 | ppl 1643.20 | num_updates 16770 | best_loss 5.95385 | length_loss 9.14741
| epoch 061 | valid on 'valid' subset | loss 11.321 | nll_loss 10.452 | ppl 1400.71 | num_updates 17049 | best_loss 5.95385 | length_loss 9.16941
| epoch 062 | valid on 'valid' subset | loss 10.823 | nll_loss 9.900 | ppl 955.17 | num_updates 17328 | best_loss 5.95385 | length_loss 8.71178
| epoch 063 | valid on 'valid' subset | loss 10.626 | nll_loss 9.679 | ppl 819.49 | num_updates 17608 | best_loss 5.95385 | length_loss 8.95637
| epoch 064 | valid on 'valid' subset | loss 10.565 | nll_loss 9.603 | ppl 777.90 | num_updates 17888 | best_loss 5.95385 | length_loss 9.092
| epoch 065 | valid on 'valid' subset | loss 10.546 | nll_loss 9.586 | ppl 768.39 | num_updates 18000 | best_loss 5.95385 | length_loss 8.9947
| epoch 065 | valid on 'valid' subset | loss 10.622 | nll_loss 9.670 | ppl 814.79 | num_updates 18168 | best_loss 5.95385 | length_loss 9.0684
| epoch 066 | valid on 'valid' subset | loss 10.550 | nll_loss 9.624 | ppl 788.86 | num_updates 18447 | best_loss 5.95385 | length_loss 9.13353
| epoch 067 | valid on 'valid' subset | loss 10.585 | nll_loss 9.679 | ppl 819.54 | num_updates 18726 | best_loss 5.95385 | length_loss 8.87808
| epoch 068 | valid on 'valid' subset | loss 10.725 | nll_loss 9.841 | ppl 916.99 | num_updates 19006 | best_loss 5.95385 | length_loss 9.05298
| epoch 069 | valid on 'valid' subset | loss 10.739 | nll_loss 9.871 | ppl 936.56 | num_updates 19286 | best_loss 5.95385 | length_loss 8.99715
| epoch 070 | valid on 'valid' subset | loss 10.854 | nll_loss 10.007 | ppl 1029.08 | num_updates 19566 | best_loss 5.95385 | length_loss 9.01016
| epoch 071 | valid on 'valid' subset | loss 10.864 | nll_loss 10.032 | ppl 1046.96 | num_updates 19845 | best_loss 5.95385 | length_loss 8.86554
| epoch 072 | valid on 'valid' subset | loss 8.418 | nll_loss 7.003 | ppl 128.30 | num_updates 20000 | best_loss 5.95385 | length_loss 8.79381
| epoch 072 | valid on 'valid' subset | loss 8.412 | nll_loss 6.999 | ppl 127.91 | num_updates 20123 | best_loss 5.95385 | length_loss 8.84561
| epoch 073 | valid on 'valid' subset | loss 10.533 | nll_loss 9.672 | ppl 815.50 | num_updates 20403 | best_loss 5.95385 | length_loss 9.18969
| epoch 074 | valid on 'valid' subset | loss 10.665 | nll_loss 9.823 | ppl 905.86 | num_updates 20683 | best_loss 5.95385 | length_loss 9.22083
| epoch 075 | valid on 'valid' subset | loss 10.880 | nll_loss 10.068 | ppl 1073.60 | num_updates 20963 | best_loss 5.95385 | length_loss 9.157
| epoch 076 | valid on 'valid' subset | loss 11.071 | nll_loss 10.286 | ppl 1248.62 | num_updates 21243 | best_loss 5.95385 | length_loss 8.96544
| epoch 077 | valid on 'valid' subset | loss 11.182 | nll_loss 10.408 | ppl 1359.12 | num_updates 21523 | best_loss 5.95385 | length_loss 9.07412
| epoch 078 | valid on 'valid' subset | loss 11.236 | nll_loss 10.477 | ppl 1425.37 | num_updates 21803 | best_loss 5.95385 | length_loss 8.85913
| epoch 079 | valid on 'valid' subset | loss 10.301 | nll_loss 9.398 | ppl 674.81 | num_updates 22000 | best_loss 5.95385 | length_loss 8.96468
| epoch 079 | valid on 'valid' subset | loss 9.806 | nll_loss 8.855 | ppl 462.96 | num_updates 22082 | best_loss 5.95385 | length_loss 8.93285
| epoch 080 | valid on 'valid' subset | loss 10.276 | nll_loss 9.401 | ppl 676.18 | num_updates 22362 | best_loss 5.95385 | length_loss 9.12087
| epoch 081 | valid on 'valid' subset | loss 10.474 | nll_loss 9.630 | ppl 792.14 | num_updates 22641 | best_loss 5.95385 | length_loss 9.18552
| epoch 082 | valid on 'valid' subset | loss 10.584 | nll_loss 9.755 | ppl 863.88 | num_updates 22921 | best_loss 5.95385 | length_loss 8.87157
| epoch 083 | valid on 'valid' subset | loss 10.820 | nll_loss 10.015 | ppl 1034.79 | num_updates 23200 | best_loss 5.95385 | length_loss 8.92633
| epoch 084 | valid on 'valid' subset | loss 11.004 | nll_loss 10.226 | ppl 1198.01 | num_updates 23480 | best_loss 5.95385 | length_loss 9.01677
| epoch 085 | valid on 'valid' subset | loss 11.027 | nll_loss 10.249 | ppl 1217.05 | num_updates 23758 | best_loss 5.95385 | length_loss 8.90608
| epoch 086 | valid on 'valid' subset | loss 11.099 | nll_loss 10.324 | ppl 1282.21 | num_updates 24000 | best_loss 5.95385 | length_loss 9.01892
| epoch 086 | valid on 'valid' subset | loss 11.159 | nll_loss 10.400 | ppl 1351.02 | num_updates 24037 | best_loss 5.95385 | length_loss 8.86138
| epoch 087 | valid on 'valid' subset | loss 10.857 | nll_loss 10.056 | ppl 1064.75 | num_updates 24317 | best_loss 5.95385 | length_loss 8.76016
| epoch 088 | valid on 'valid' subset | loss 11.008 | nll_loss 10.226 | ppl 1197.49 | num_updates 24597 | best_loss 5.95385 | length_loss 8.96673
| epoch 089 | valid on 'valid' subset | loss 11.241 | nll_loss 10.479 | ppl 1426.87 | num_updates 24877 | best_loss 5.95385 | length_loss 8.93496
| epoch 090 | valid on 'valid' subset | loss 11.131 | nll_loss 10.350 | ppl 1305.47 | num_updates 25157 | best_loss 5.95385 | length_loss 9.13093
| epoch 091 | valid on 'valid' subset | loss 11.027 | nll_loss 10.242 | ppl 1211.21 | num_updates 25437 | best_loss 5.95385 | length_loss 8.79867
| epoch 092 | valid on 'valid' subset | loss 11.170 | nll_loss 10.394 | ppl 1346.02 | num_updates 25716 | best_loss 5.95385 | length_loss 8.89253
| epoch 093 | valid on 'valid' subset | loss 11.423 | nll_loss 10.670 | ppl 1629.57 | num_updates 25996 | best_loss 5.95385 | length_loss 8.90915
| epoch 094 | valid on 'valid' subset | loss 11.373 | nll_loss 10.618 | ppl 1571.63 | num_updates 26000 | best_loss 5.95385 | length_loss 8.93119
| epoch 094 | valid on 'valid' subset | loss 11.429 | nll_loss 10.679 | ppl 1639.33 | num_updates 26274 | best_loss 5.95385 | length_loss 8.89474
| epoch 095 | valid on 'valid' subset | loss 11.142 | nll_loss 10.341 | ppl 1296.98 | num_updates 26554 | best_loss 5.95385 | length_loss 8.87344
| epoch 096 | valid on 'valid' subset | loss 11.532 | nll_loss 10.778 | ppl 1756.30 | num_updates 26834 | best_loss 5.95385 | length_loss 8.97212
| epoch 097 | valid on 'valid' subset | loss 11.188 | nll_loss 10.393 | ppl 1344.82 | num_updates 27114 | best_loss 5.95385 | length_loss 8.93483
| epoch 098 | valid on 'valid' subset | loss 10.440 | nll_loss 9.561 | ppl 755.15 | num_updates 27393 | best_loss 5.95385 | length_loss 9.11795
| epoch 099 | valid on 'valid' subset | loss 8.302 | nll_loss 6.880 | ppl 117.75 | num_updates 27672 | best_loss 5.95385 | length_loss 8.83351
| epoch 100 | valid on 'valid' subset | loss 9.792 | nll_loss 8.838 | ppl 457.54 | num_updates 27952 | best_loss 5.95385 | length_loss 8.92019
| epoch 101 | valid on 'valid' subset | loss 9.602 | nll_loss 8.621 | ppl 393.71 | num_updates 28000 | best_loss 5.95385 | length_loss 8.90594
| epoch 101 | valid on 'valid' subset | loss 8.294 | nll_loss 6.871 | ppl 117.07 | num_updates 28229 | best_loss 5.95385 | length_loss 8.7386
| epoch 102 | valid on 'valid' subset | loss 8.288 | nll_loss 6.862 | ppl 116.29 | num_updates 28509 | best_loss 5.95385 | length_loss 8.86053
| epoch 103 | valid on 'valid' subset | loss 8.282 | nll_loss 6.858 | ppl 116.03 | num_updates 28789 | best_loss 5.95385 | length_loss 8.84397
| epoch 104 | valid on 'valid' subset | loss 8.285 | nll_loss 6.862 | ppl 116.30 | num_updates 29069 | best_loss 5.95385 | length_loss 8.85538
| epoch 105 | valid on 'valid' subset | loss 8.285 | nll_loss 6.858 | ppl 116.02 | num_updates 29349 | best_loss 5.95385 | length_loss 8.87207
| epoch 106 | valid on 'valid' subset | loss 8.287 | nll_loss 6.858 | ppl 115.97 | num_updates 29629 | best_loss 5.95385 | length_loss 8.8873
| epoch 107 | valid on 'valid' subset | loss 8.287 | nll_loss 6.858 | ppl 116.03 | num_updates 29909 | best_loss 5.95385 | length_loss 8.86653
| epoch 108 | valid on 'valid' subset | loss 8.278 | nll_loss 6.853 | ppl 115.63 | num_updates 30000 | best_loss 5.95385 | length_loss 8.87427
| epoch 108 | valid on 'valid' subset | loss 8.263 | nll_loss 6.846 | ppl 115.03 | num_updates 30189 | best_loss 5.95385 | length_loss 8.84151
| epoch 109 | valid on 'valid' subset | loss 8.265 | nll_loss 6.843 | ppl 114.83 | num_updates 30468 | best_loss 5.95385 | length_loss 8.85339
| epoch 110 | valid on 'valid' subset | loss 8.272 | nll_loss 6.845 | ppl 114.99 | num_updates 30747 | best_loss 5.95385 | length_loss 8.89397
| epoch 111 | valid on 'valid' subset | loss 8.263 | nll_loss 6.839 | ppl 114.51 | num_updates 31026 | best_loss 5.95385 | length_loss 8.84809
| epoch 112 | valid on 'valid' subset | loss 8.263 | nll_loss 6.838 | ppl 114.44 | num_updates 31306 | best_loss 5.95385 | length_loss 8.83498
| epoch 113 | valid on 'valid' subset | loss 8.268 | nll_loss 6.843 | ppl 114.76 | num_updates 31586 | best_loss 5.95385 | length_loss 8.85753
| epoch 114 | valid on 'valid' subset | loss 8.258 | nll_loss 6.833 | ppl 114.01 | num_updates 31866 | best_loss 5.95385 | length_loss 8.81941
| epoch 115 | valid on 'valid' subset | loss 8.274 | nll_loss 6.843 | ppl 114.83 | num_updates 32000 | best_loss 5.95385 | length_loss 8.78144
| epoch 115 | valid on 'valid' subset | loss 8.266 | nll_loss 6.839 | ppl 114.49 | num_updates 32146 | best_loss 5.95385 | length_loss 8.88354
| epoch 116 | valid on 'valid' subset | loss 8.265 | nll_loss 6.838 | ppl 114.38 | num_updates 32426 | best_loss 5.95385 | length_loss 8.84121
| epoch 117 | valid on 'valid' subset | loss 8.259 | nll_loss 6.831 | ppl 113.89 | num_updates 32705 | best_loss 5.95385 | length_loss 8.84998
| epoch 118 | valid on 'valid' subset | loss 8.255 | nll_loss 6.832 | ppl 113.95 | num_updates 32985 | best_loss 5.95385 | length_loss 8.82119
| epoch 119 | valid on 'valid' subset | loss 8.267 | nll_loss 6.836 | ppl 114.27 | num_updates 33264 | best_loss 5.95385 | length_loss 8.83913
| epoch 120 | valid on 'valid' subset | loss 8.254 | nll_loss 6.826 | ppl 113.43 | num_updates 33544 | best_loss 5.95385 | length_loss 8.86863
| epoch 121 | valid on 'valid' subset | loss 8.256 | nll_loss 6.826 | ppl 113.48 | num_updates 33823 | best_loss 5.95385 | length_loss 8.84371
| epoch 122 | valid on 'valid' subset | loss 8.248 | nll_loss 6.817 | ppl 112.79 | num_updates 34000 | best_loss 5.95385 | length_loss 8.85724
| epoch 122 | valid on 'valid' subset | loss 8.261 | nll_loss 6.831 | ppl 113.83 | num_updates 34103 | best_loss 5.95385 | length_loss 8.83933
| epoch 123 | valid on 'valid' subset | loss 8.251 | nll_loss 6.822 | ppl 113.17 | num_updates 34382 | best_loss 5.95385 | length_loss 8.82543
| epoch 124 | valid on 'valid' subset | loss 8.251 | nll_loss 6.823 | ppl 113.23 | num_updates 34661 | best_loss 5.95385 | length_loss 8.83107
| epoch 125 | valid on 'valid' subset | loss 8.239 | nll_loss 6.814 | ppl 112.52 | num_updates 34940 | best_loss 5.95385 | length_loss 8.84941
| epoch 126 | valid on 'valid' subset | loss 8.243 | nll_loss 6.812 | ppl 112.37 | num_updates 35219 | best_loss 5.95385 | length_loss 8.83323
| epoch 127 | valid on 'valid' subset | loss 8.240 | nll_loss 6.815 | ppl 112.57 | num_updates 35499 | best_loss 5.95385 | length_loss 8.81804
| epoch 128 | valid on 'valid' subset | loss 8.240 | nll_loss 6.811 | ppl 112.25 | num_updates 35779 | best_loss 5.95385 | length_loss 8.83056
| epoch 129 | valid on 'valid' subset | loss 8.248 | nll_loss 6.816 | ppl 112.70 | num_updates 36000 | best_loss 5.95385 | length_loss 8.8526
| epoch 129 | valid on 'valid' subset | loss 8.244 | nll_loss 6.816 | ppl 112.69 | num_updates 36059 | best_loss 5.95385 | length_loss 8.81925
| epoch 130 | valid on 'valid' subset | loss 8.232 | nll_loss 6.807 | ppl 111.98 | num_updates 36339 | best_loss 5.95385 | length_loss 8.83944
| epoch 131 | valid on 'valid' subset | loss 8.239 | nll_loss 6.809 | ppl 112.13 | num_updates 36619 | best_loss 5.95385 | length_loss 8.87711
| epoch 132 | valid on 'valid' subset | loss 8.241 | nll_loss 6.810 | ppl 112.24 | num_updates 36899 | best_loss 5.95385 | length_loss 8.87148
| epoch 133 | valid on 'valid' subset | loss 8.238 | nll_loss 6.808 | ppl 112.09 | num_updates 37177 | best_loss 5.95385 | length_loss 8.87225
| epoch 134 | valid on 'valid' subset | loss 8.240 | nll_loss 6.810 | ppl 112.22 | num_updates 37457 | best_loss 5.95385 | length_loss 8.84351
| epoch 135 | valid on 'valid' subset | loss 8.244 | nll_loss 6.813 | ppl 112.42 | num_updates 37737 | best_loss 5.95385 | length_loss 8.84976
| epoch 136 | valid on 'valid' subset | loss 8.236 | nll_loss 6.806 | ppl 111.92 | num_updates 38000 | best_loss 5.95385 | length_loss 8.85678
| epoch 136 | valid on 'valid' subset | loss 8.235 | nll_loss 6.804 | ppl 111.73 | num_updates 38016 | best_loss 5.95385 | length_loss 8.86879
| epoch 137 | valid on 'valid' subset | loss 8.237 | nll_loss 6.807 | ppl 111.95 | num_updates 38296 | best_loss 5.95385 | length_loss 8.85905
| epoch 138 | valid on 'valid' subset | loss 8.234 | nll_loss 6.808 | ppl 112.03 | num_updates 38575 | best_loss 5.95385 | length_loss 8.88407
| epoch 139 | valid on 'valid' subset | loss 8.231 | nll_loss 6.800 | ppl 111.43 | num_updates 38855 | best_loss 5.95385 | length_loss 8.86341
| epoch 140 | valid on 'valid' subset | loss 8.220 | nll_loss 6.794 | ppl 111.01 | num_updates 39135 | best_loss 5.95385 | length_loss 8.84081
| epoch 141 | valid on 'valid' subset | loss 8.236 | nll_loss 6.804 | ppl 111.73 | num_updates 39414 | best_loss 5.95385 | length_loss 8.82266
| epoch 142 | valid on 'valid' subset | loss 8.222 | nll_loss 6.795 | ppl 111.04 | num_updates 39694 | best_loss 5.95385 | length_loss 8.84459
| epoch 143 | valid on 'valid' subset | loss 8.226 | nll_loss 6.797 | ppl 111.16 | num_updates 39974 | best_loss 5.95385 | length_loss 8.88616
| epoch 144 | valid on 'valid' subset | loss 8.218 | nll_loss 6.792 | ppl 110.79 | num_updates 40000 | best_loss 5.95385 | length_loss 8.88367
| epoch 144 | valid on 'valid' subset | loss 8.230 | nll_loss 6.800 | ppl 111.44 | num_updates 40253 | best_loss 5.95385 | length_loss 8.8395
| epoch 145 | valid on 'valid' subset | loss 8.247 | nll_loss 6.809 | ppl 112.09 | num_updates 40533 | best_loss 5.95385 | length_loss 8.86908
| epoch 146 | valid on 'valid' subset | loss 8.219 | nll_loss 6.790 | ppl 110.67 | num_updates 40813 | best_loss 5.95385 | length_loss 8.86673
| epoch 147 | valid on 'valid' subset | loss 8.226 | nll_loss 6.790 | ppl 110.68 | num_updates 41093 | best_loss 5.95385 | length_loss 8.83123
| epoch 148 | valid on 'valid' subset | loss 8.222 | nll_loss 6.794 | ppl 110.98 | num_updates 41372 | best_loss 5.95385 | length_loss 8.85367
| epoch 149 | valid on 'valid' subset | loss 8.221 | nll_loss 6.789 | ppl 110.55 | num_updates 41652 | best_loss 5.95385 | length_loss 8.84598
| epoch 150 | valid on 'valid' subset | loss 8.226 | nll_loss 6.794 | ppl 110.99 | num_updates 41932 | best_loss 5.95385 | length_loss 8.835
| epoch 151 | valid on 'valid' subset | loss 8.218 | nll_loss 6.784 | ppl 110.21 | num_updates 42000 | best_loss 5.95385 | length_loss 8.82967
| epoch 151 | valid on 'valid' subset | loss 8.215 | nll_loss 6.783 | ppl 110.12 | num_updates 42211 | best_loss 5.95385 | length_loss 8.83456
| epoch 152 | valid on 'valid' subset | loss 9.292 | nll_loss 8.233 | ppl 300.93 | num_updates 42491 | best_loss 5.95385 | length_loss 8.91676
| epoch 153 | valid on 'valid' subset | loss 9.855 | nll_loss 8.861 | ppl 465.13 | num_updates 42770 | best_loss 5.95385 | length_loss 8.94109
| epoch 154 | valid on 'valid' subset | loss 10.298 | nll_loss 9.339 | ppl 647.77 | num_updates 43049 | best_loss 5.95385 | length_loss 8.89363
| epoch 155 | valid on 'valid' subset | loss 10.358 | nll_loss 9.405 | ppl 677.85 | num_updates 43329 | best_loss 5.95385 | length_loss 8.80955
| epoch 156 | valid on 'valid' subset | loss 10.896 | nll_loss 9.991 | ppl 1017.96 | num_updates 43609 | best_loss 5.95385 | length_loss 9.08136
| epoch 157 | valid on 'valid' subset | loss 10.711 | nll_loss 9.779 | ppl 878.42 | num_updates 43889 | best_loss 5.95385 | length_loss 8.95622
| epoch 158 | valid on 'valid' subset | loss 10.923 | nll_loss 9.994 | ppl 1019.84 | num_updates 44000 | best_loss 5.95385 | length_loss 8.91165
| epoch 158 | valid on 'valid' subset | loss 10.706 | nll_loss 9.774 | ppl 875.50 | num_updates 44169 | best_loss 5.95385 | length_loss 8.89074
| epoch 159 | valid on 'valid' subset | loss 10.983 | nll_loss 10.055 | ppl 1063.43 | num_updates 44448 | best_loss 5.95385 | length_loss 8.99308
| epoch 160 | valid on 'valid' subset | loss 11.061 | nll_loss 10.145 | ppl 1132.42 | num_updates 44728 | best_loss 5.95385 | length_loss 8.83523
| epoch 161 | valid on 'valid' subset | loss 11.176 | nll_loss 10.264 | ppl 1229.85 | num_updates 45007 | best_loss 5.95385 | length_loss 8.98247
| epoch 162 | valid on 'valid' subset | loss 11.354 | nll_loss 10.453 | ppl 1401.46 | num_updates 45287 | best_loss 5.95385 | length_loss 9.01076
| epoch 163 | valid on 'valid' subset | loss 11.283 | nll_loss 10.376 | ppl 1328.46 | num_updates 45567 | best_loss 5.95385 | length_loss 9.13525
| epoch 164 | valid on 'valid' subset | loss 11.396 | nll_loss 10.496 | ppl 1444.11 | num_updates 45847 | best_loss 5.95385 | length_loss 8.98389
| epoch 165 | valid on 'valid' subset | loss 11.258 | nll_loss 10.361 | ppl 1315.40 | num_updates 46000 | best_loss 5.95385 | length_loss 8.96401
| epoch 165 | valid on 'valid' subset | loss 11.380 | nll_loss 10.480 | ppl 1428.03 | num_updates 46125 | best_loss 5.95385 | length_loss 9.03602
| epoch 166 | valid on 'valid' subset | loss 11.378 | nll_loss 10.473 | ppl 1421.48 | num_updates 46405 | best_loss 5.95385 | length_loss 8.99594
| epoch 167 | valid on 'valid' subset | loss 11.364 | nll_loss 10.456 | ppl 1404.80 | num_updates 46684 | best_loss 5.95385 | length_loss 8.91066
| epoch 168 | valid on 'valid' subset | loss 11.408 | nll_loss 10.515 | ppl 1463.64 | num_updates 46964 | best_loss 5.95385 | length_loss 8.9694
| epoch 169 | valid on 'valid' subset | loss 11.339 | nll_loss 10.450 | ppl 1399.20 | num_updates 47243 | best_loss 5.95385 | length_loss 8.95181
| epoch 170 | valid on 'valid' subset | loss 11.419 | nll_loss 10.521 | ppl 1469.72 | num_updates 47523 | best_loss 5.95385 | length_loss 8.93635
| epoch 171 | valid on 'valid' subset | loss 11.667 | nll_loss 10.783 | ppl 1761.77 | num_updates 47803 | best_loss 5.95385 | length_loss 8.98119
| epoch 172 | valid on 'valid' subset | loss 11.565 | nll_loss 10.672 | ppl 1631.77 | num_updates 48000 | best_loss 5.95385 | length_loss 9.00285
| epoch 172 | valid on 'valid' subset | loss 11.728 | nll_loss 10.840 | ppl 1832.38 | num_updates 48083 | best_loss 5.95385 | length_loss 8.91423
| epoch 173 | valid on 'valid' subset | loss 11.612 | nll_loss 10.728 | ppl 1695.63 | num_updates 48363 | best_loss 5.95385 | length_loss 8.96803
| epoch 174 | valid on 'valid' subset | loss 11.691 | nll_loss 10.800 | ppl 1782.33 | num_updates 48641 | best_loss 5.95385 | length_loss 8.94515
| epoch 175 | valid on 'valid' subset | loss 11.562 | nll_loss 10.678 | ppl 1638.23 | num_updates 48921 | best_loss 5.95385 | length_loss 9.14294
| epoch 176 | valid on 'valid' subset | loss 11.579 | nll_loss 10.696 | ppl 1658.43 | num_updates 49200 | best_loss 5.95385 | length_loss 8.99614
| epoch 177 | valid on 'valid' subset | loss 8.192 | nll_loss 6.761 | ppl 108.43 | num_updates 49480 | best_loss 5.95385 | length_loss 8.85377
| epoch 178 | valid on 'valid' subset | loss 11.600 | nll_loss 10.707 | ppl 1671.80 | num_updates 49760 | best_loss 5.95385 | length_loss 8.9394
| epoch 179 | valid on 'valid' subset | loss 11.534 | nll_loss 10.642 | ppl 1597.87 | num_updates 50000 | best_loss 5.95385 | length_loss 8.74317
| epoch 179 | valid on 'valid' subset | loss 8.202 | nll_loss 6.770 | ppl 109.13 | num_updates 50039 | best_loss 5.95385 | length_loss 8.85002
| epoch 180 | valid on 'valid' subset | loss 11.481 | nll_loss 10.586 | ppl 1536.68 | num_updates 50319 | best_loss 5.95385 | length_loss 8.89714
| epoch 181 | valid on 'valid' subset | loss 8.193 | nll_loss 6.759 | ppl 108.33 | num_updates 50599 | best_loss 5.95385 | length_loss 8.85359
| epoch 182 | valid on 'valid' subset | loss 11.492 | nll_loss 10.597 | ppl 1548.90 | num_updates 50878 | best_loss 5.95385 | length_loss 8.95791
| epoch 183 | valid on 'valid' subset | loss 11.591 | nll_loss 10.708 | ppl 1672.66 | num_updates 51158 | best_loss 5.95385 | length_loss 9.00685
| epoch 184 | valid on 'valid' subset | loss 8.188 | nll_loss 6.755 | ppl 108.01 | num_updates 51437 | best_loss 5.95385 | length_loss 8.85051
| epoch 185 | valid on 'valid' subset | loss 8.201 | nll_loss 6.766 | ppl 108.81 | num_updates 51717 | best_loss 5.95385 | length_loss 8.85741
| epoch 186 | valid on 'valid' subset | loss 11.735 | nll_loss 10.846 | ppl 1841.29 | num_updates 51996 | best_loss 5.95385 | length_loss 8.93106
| epoch 187 | valid on 'valid' subset | loss 11.512 | nll_loss 10.628 | ppl 1582.03 | num_updates 52000 | best_loss 5.95385 | length_loss 8.8988
| epoch 187 | valid on 'valid' subset | loss 11.522 | nll_loss 10.627 | ppl 1581.80 | num_updates 52276 | best_loss 5.95385 | length_loss 8.9317
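
The log above alternates between a validation loss of about 8.2 and about 11.5 from update 42491 onward. A minimal sketch for locating such jumps automatically, assuming the `| key value | ...` line layout shown in the log; the helper names `parse_valid_line` and `find_jumps` and the `threshold` default are illustrative assumptions, not part of the repository:

```python
def parse_valid_line(line):
    """Return the numeric fields of one '| key value | ...' log line as a dict."""
    fields = {}
    for part in line.split("|"):
        tokens = part.split()
        # Keep only "key value" pairs; "valid on 'valid' subset" is skipped.
        if len(tokens) == 2:
            try:
                fields[tokens[0]] = float(tokens[1])
            except ValueError:
                pass
    return fields


def find_jumps(lines, key="loss", threshold=1.0):
    """List (num_updates, previous, current) wherever `key` rises by more
    than `threshold` between two consecutive validation lines."""
    jumps = []
    previous = None
    for line in lines:
        fields = parse_valid_line(line)
        if key not in fields:
            continue
        current = fields[key]
        if previous is not None and current - previous > threshold:
            jumps.append((int(fields.get("num_updates", -1)), previous, current))
        previous = current
    return jumps
```

Passing the lines of a saved log file to `find_jumps` returns the update counts at which the validation loss rose sharply, which may help narrow down which checkpoint to resume from.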

Seeking the WMT14 EN-DE data and the BPE codes

It seems that the wmt14_en_de data link in get_data.sh no longer works. The pretrained model also does not include the BPE codes, so I cannot obtain the BPE-processed data. It would be helpful if you could update the data link or provide the BPE codes.

Would you like to share trained CMLM transformers?

The CMLM models you implemented generally perform better than the original models. I'm running some tests and am eager to compare the models. Would you be willing to share your trained CMLM transformers? Thanks!

How to evaluate the wmt17 English-Chinese model?

Hi! Thanks for your nice code!
When I evaluated the WMT17 English-Chinese model, I only got 19.14 (DisCo + Mask-Predict, step 4).

My script:
CUDA_VISIBLE_DEVICES=0 python generate_disco.py ${data_path} --path ${model_dir}/checkpoint_top5_average.pt \
    --task translation_self --max-sentences 10 --remove-bpe --decoding-iterations 4 \
    --decoding-strategy mask_predict --length-beam 5 --sacrebleu

Using DisCo to train an autoregressive model

Hi, Jungo. Thanks for your nice code!

I want to use your DisCo model to train an autoregressive model, as described in your paper (Sec. 5.1: AT with Contextless KVs). I saw an argument called at-only in disco_transformer.py, but I am unsure how it is used to train the AT model.

Additionally, q_mask[:, :, eos] is always set to True. Does that mean that even for the AT model we need to predict eos and its corresponding position first?

Seeking suggestions on alignment analysis

Hi, Jungo. I am trying to analyze the alignment as the refinement iterations proceed, using forced decoding.

The basic idea is to replace the gold target tokens with MASK in a left-to-right (L2R) fashion, like the BERT language model. However, this cannot capture the dynamics across different refinement iterations. Can you recommend some effective approaches for analyzing the alignment?

distilled validation data

Hi, thanks for providing the distillation data. I found that the validation set in the data is raw data. Did you distill the validation set? Thanks!
