
Less is More: Pre-train a Strong Text Encoder for Dense Retrieval Using a Weak Decoder

Shuqi Lu∗, Di He, Chenyan Xiong, Guolin Ke, Waleed Malik, Zhicheng Dou, Paul Bennett, Tie-Yan Liu, Arnold Overwijk

This repository provides the fine-tuning code for SEED-Encoder on the MS MARCO ranking tasks and is based on ANCE (https://github.com/microsoft/ANCE).

Requirements and Installation

  • PyTorch version >= 1.6
  • Python version >= 3.6
  • Please install Apex with CUDA and C++ extensions (https://github.com/NVIDIA/apex); a typical install is sketched below.
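
A typical from-source Apex install looks like the following (commands adapted from the Apex README of that era; the flags may differ for newer Apex or pip releases):

    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./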

Fine-tuning for SEED-Encoder

Requirements

To install requirements, run the following commands:

git clone https://github.com/microsoft/SEED-Encoder
cd SEED-Encoder
pip install fairseq==0.10.2
pip install transformers==3.4

python setup.py install   # install the SEED-Encoder package itself
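
A quick sanity check that the pinned versions were installed correctly:

    python -c "import fairseq, transformers; print(fairseq.__version__, transformers.__version__)"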

Environment

You can refer to the Dockerfile in docker/pytorch-1.6-itp/Dockerfile.
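
If you go the Docker route, a typical build-and-run sequence looks like this (the image tag and mount point are arbitrary choices, not part of the repository):

    docker build -t seed-encoder -f docker/pytorch-1.6-itp/Dockerfile .
    docker run --gpus all -it -v $(pwd):/workspace seed-encoder bash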

Data Download

To download all the needed data, run:

bash commands/data_download.sh 

Our Checkpoints

  • Pretrained SEED-Encoder with a 3-layer decoder, attention span = 2
  • Pretrained SEED-Encoder with a 1-layer decoder, attention span = 8
  • SEED-Encoder warmup checkpoint
  • ANCE-finetuned SEED-Encoder checkpoint on the passage ranking task
  • ANCE-finetuned SEED-Encoder checkpoint on the document ranking task
  • BPE vocabulary file used by our tokenizer
  • DPR-finetuned SEED-Encoder checkpoint on the NQ task
  • ANCE-finetuned SEED-Encoder checkpoint on the NQ task
  • SEED-Encoder checkpoint finetuned on MIND

Data Preprocessing

The command to preprocess passage and document data is listed below:

python data/msmarco_data.py \
--data_dir $raw_data_dir \
--out_data_dir $preprocessed_data_dir \
--train_model_type {use rdot_nll_fairseq_fast for SEED-Encoder ANCE FirstP} \
--max_seq_length {use 512 for ANCE FirstP, 2048 for ANCE MaxP} \
--data_type {use 1 for passage, 0 for document} \
--bpe_vocab_file $bpe_vocab_file

The data preprocessing command is also included as the first step in the training script commands/run_train.sh.
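
For example, a concrete passage-preprocessing invocation could look like this (all paths are placeholders; substitute your own):

    raw_data_dir=./data/raw_data               # downloaded MS MARCO files
    preprocessed_data_dir=./data/preprocessed  # output directory
    bpe_vocab_file=./checkpoints/vocab.txt     # BPE vocabulary file from Our Checkpoints

    python data/msmarco_data.py \
    --data_dir $raw_data_dir \
    --out_data_dir $preprocessed_data_dir \
    --train_model_type rdot_nll_fairseq_fast \
    --max_seq_length 512 \
    --data_type 1 \
    --bpe_vocab_file $bpe_vocab_file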

Warmup for Training

    model_file=SEED-Encoder-3-decoder-2-attn.pt
    vocab=vocab.txt

    python3 -m torch.distributed.launch --nproc_per_node=8 ../drivers/run_warmup.py \
    --train_model_type rdot_nll_fairseq_fast --model_name_or_path $LOAD_DIR --model_file $model_file --task_name MSMarco --do_train \
    --evaluate_during_training --data_dir $DATA_DIR \
    --max_seq_length 128 --per_gpu_eval_batch_size=256  --per_gpu_train_batch_size=32 --learning_rate 2e-4 --logging_steps 100 --num_train_epochs 2.0 \
    --output_dir $SAVE_DIR --warmup_steps 1000 --overwrite_output_dir --save_steps 10000 --gradient_accumulation_steps 1 --expected_train_size 35000000 \
    --logging_steps_per_eval 100 --fp16 --optimizer lamb --log_dir $SAVE_DIR/log --bpe_vocab_file $vocab
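
The command above assumes $LOAD_DIR, $DATA_DIR, and $SAVE_DIR are already set, for example (paths are illustrative):

    DATA_DIR=../../data/raw_data   # MS MARCO data directory
    SAVE_DIR=../../temp/           # checkpoints and logs are written here
    LOAD_DIR=./checkpoints/        # directory containing $model_file and $vocab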

ANCE Training (passage; run the run_ann_data_gen.py command below first to generate the initial ANN data)

    gpu_no=4
    seq_length=512
    tokenizer_type="roberta-base-fast"
    model_type=rdot_nll_fairseq_fast
    base_data_dir={} # set this to your base data directory
    preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
    job_name=$exp_name
    pretrained_checkpoint_dir=SEED-Encoder-warmup-90000.pt
    data_type=1
    warmup_steps=5000
    per_gpu_train_batch_size=16
    gradient_accumulation_steps=1
    learning_rate=1e-6
    vocab=vocab.txt

    blob_model_dir="${base_data_dir}${job_name}/"
    blob_model_ann_data_dir="${blob_model_dir}ann_data/"

    model_dir="./${job_name}/"
    model_ann_data_dir="${model_dir}ann_data/"

    
    CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
    --init_model_dir $pretrained_checkpoint_dir --train_model_type $model_type --output_dir $model_ann_data_dir \
    --cache_dir {} --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
    --per_gpu_eval_batch_size 64 --topk_training 200 --negative_sample 20 --bpe_vocab_file $vocab




    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$gpu_no --master_addr 127.0.0.2 --master_port 35000 ../drivers/run_ann.py --train_model_type $model_type \
    --model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
    --ann_dir $model_ann_data_dir --max_seq_length $seq_length --per_gpu_train_batch_size=$per_gpu_train_batch_size \
    --gradient_accumulation_steps $gradient_accumulation_steps --learning_rate $learning_rate --output_dir $model_dir \
    --warmup_steps $warmup_steps --logging_steps 100 --save_steps 10000 --optimizer lamb --single_warmup --bpe_vocab_file $vocab \
    --blob_ann_dir $blob_model_ann_data_dir --blob_output_dir $blob_model_dir
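
In ANCE, negative mining and training run side by side: run_ann_data_gen.py keeps scanning $model_dir for new checkpoints and refreshes the ANN negatives in $model_ann_data_dir, while run_ann.py trains on the latest negatives. A minimal single-machine orchestration sketch (hypothetical; the two commands above are simply launched concurrently on disjoint GPU sets):

    # Hypothetical wrapper: run both ANCE processes at the same time.
    # "..." stands for the full argument lists of the two commands above.
    CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$gpu_no \
        ../drivers/run_ann_data_gen.py ... &   # refresh ANN negatives in the background
    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$gpu_no \
        --master_addr 127.0.0.2 --master_port 35000 ../drivers/run_ann.py ...   # train
    wait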

ANCE Training (document)

    gpu_no=4
    seq_length=512
    tokenizer_type="roberta-base-fast-docdev2"
    model_type=rdot_nll_fairseq_fast
    base_data_dir={} # set this to your base data directory
    preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
    job_name=$exp_name
    pretrained_checkpoint_dir=SEED-Encoder-warmup-90000.pt
    data_type=0
    warmup_steps=3000
    per_gpu_train_batch_size=4
    gradient_accumulation_steps=4
    learning_rate=5e-6
    vocab=vocab.txt

    blob_model_dir="${base_data_dir}${job_name}/"
    blob_model_ann_data_dir="${blob_model_dir}ann_data/"

    model_dir="./${job_name}/"
    model_ann_data_dir="${model_dir}ann_data/"

    CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
    --init_model_dir $pretrained_checkpoint_dir --train_model_type $model_type --output_dir $model_ann_data_dir \
    --cache_dir {} --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
    --per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 --bpe_vocab_file $vocab



    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$gpu_no --master_addr 127.0.0.2 --master_port 35000 ../drivers/run_ann.py --train_model_type $model_type \
    --model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
    --ann_dir $model_ann_data_dir --max_seq_length $seq_length --per_gpu_train_batch_size=$per_gpu_train_batch_size \
    --gradient_accumulation_steps $gradient_accumulation_steps --learning_rate $learning_rate --output_dir $model_dir \
    --warmup_steps $warmup_steps --logging_steps 100 --save_steps 10000 --optimizer lamb --single_warmup --bpe_vocab_file $vocab \
    --blob_ann_dir $blob_model_ann_data_dir --blob_output_dir $blob_model_dir --cache_dir {}
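
The document setup mirrors the passage setup; the smaller per-GPU batch size (4) combined with gradient_accumulation_steps=4 keeps the same effective batch size of 64 across 4 GPUs while fitting the longer document inputs in memory.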

To reproduce our results, you can use our checkpoints to generate the embeddings and then evaluate them:

    python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
    --init_model_dir $pretrained_checkpoint_dir --train_model_type $model_type --output_dir $blob_model_ann_data_dir \
    --cache_dir {} --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
    --per_gpu_eval_batch_size 64 --topk_training 200 --negative_sample 20 --end_output_num 0 --inference --bpe_vocab_file $vocab
    
    
    python ../evaluation/eval.py

NQ scripts

The running scripts are commands/run_ann_data_gen_dpr.sh and commands/run_tran_dpr.sh.
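
Assuming the environment variables used inside those scripts are set, they can be launched directly:

    bash commands/run_ann_data_gen_dpr.sh   # generate ANN data for the NQ (DPR) task
    bash commands/run_tran_dpr.sh           # run the NQ fine-tuning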

Results of SEED-Encoder

MSMARCO Dev Passage Retrieval        MRR@10    Recall@1k
BM25 warmup checkpoint               0.329     0.953
ANCE Passage checkpoint              0.334     0.961

MSMARCO Document Retrieval           MRR@10 (Dev)    MRR@10 (Eval)
ANCE Document (FirstP) checkpoint    0.394           0.362

NQ Task               Top-1    Top-5    Top-20    Top-100    MRR@20    P@20
DPR checkpoint        46.1     68.8     80.4      87.1       56.2      20.1
ANCE NQ checkpoint    52.5     73.1     83.1      88.7       61.5      22.5

Our huggingface Checkpoints

  • Pretrained SEED-Encoder with a 3-layer decoder, attention span = 2
  • Pretrained SEED-Encoder with a 1-layer decoder, attention span = 8
  • SEED-Encoder warmup checkpoint
  • ANCE-finetuned SEED-Encoder checkpoint on the passage ranking task
  • ANCE-finetuned SEED-Encoder checkpoint on the document ranking task

Load the huggingface checkpoints and run

DATA_DIR=../../data/raw_data
SAVE_DIR=../../temp/
LOAD_DIR=$your_dir/SEED-Encoder-warmup-90000/

python3 -m torch.distributed.launch --nproc_per_node=8 ../drivers/run_warmup.py \
--train_model_type seeddot_nll --model_name_or_path $LOAD_DIR --task_name MSMarco --do_train \
--evaluate_during_training --data_dir $DATA_DIR \
--max_seq_length 128 --per_gpu_eval_batch_size=512  --per_gpu_train_batch_size=2 --learning_rate 2e-4 --logging_steps 1 --num_train_epochs 2.0 \
--output_dir $SAVE_DIR --warmup_steps 1000 --overwrite_output_dir --save_steps 1 --gradient_accumulation_steps 1 --expected_train_size 35000000 \
--logging_steps_per_eval 1 --fp16 --optimizer lamb --log_dir $SAVE_DIR/log --do_lower_case
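
Note that the very small batch size and the save/logging intervals of 1 in this command look like smoke-test settings for verifying that the checkpoint loads correctly; for a real warmup run, use values closer to those in the Warmup for Training section above.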


Issues

How about pre-training code and scripts?

Hi guys,

Thanks for your great paper and for sharing the code. Do you plan to release the pre-training code, and if so, when? It would help others follow up on your work, since the pre-training method is the key contribution of SEED-Encoder.

Thanks in advance.

Preprocessing error: KeyError: 140227

I encountered KeyError: 140227; the error, the script I ran, and its log are below.

Error:

    Traceback (most recent call last):
      File "data/msmarco_data.py", line 518, in <module>
        main()
      File "data/msmarco_data.py", line 514, in main
        preprocess(args)
      File "data/msmarco_data.py", line 213, in preprocess
        "train-qrel.tsv")
      File "data/msmarco_data.py", line 122, in write_query_rel
        rel +
    KeyError: 140227

Script:

    python data/msmarco_data.py --data_dir /raid/bae_sh/codes/ance/raw_data --out_data_dir ./preprocessed_data --train_model_type rdot_nll_fairseq_fast --max_seq_length 512 --data_type 0 --bpe_vocab_file checkpoints_msmarco/vocab.txt

Log:

    start passage file split processing
    start merging splits
    0 1555982
    1 3257085
    2 898719
    Total lines written: 1085721
    First line
    (512, array([    0, 19100,  1036, ... (token-ID dump abridged) ...,    12,    25], dtype=int32))
    done saving pid2offset
    Writing query files train-query and train-qrel.tsv
    Loading query_2_pos_docid
    start query file split processing
    start merging splits
    1 1185869
    2 295446
    done saving qid2offset
    Total lines written: 367013
    First line
    (16, array([    0,    13,  1959,  1686,  1653,  8786,  5474,  1666,  1653,
                 2834,  1666,  1653, 10881,  3012,    35,     2,     1,     1,
                  ... (padding abridged) ...,     1], dtype=int32))
    Writing qrels
    Traceback (most recent call last):
      File "data/msmarco_data.py", line 518, in <module>
        main()
      File "data/msmarco_data.py", line 514, in main
        preprocess(args)
      File "data/msmarco_data.py", line 213, in preprocess
        "train-qrel.tsv")
      File "data/msmarco_data.py", line 122, in write_query_rel
        rel +
    KeyError: 140227
