
tcl's Introduction

Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022

News

(03/16/2022) Uploaded retrieval checkpoints fine-tuned on COCO and Flickr


This is the official PyTorch implementation of TCL.


Requirements:

conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
pip install transformers==4.8.1
pip install timm==0.4.9
conda install ruamel_yaml
pip install opencv-python
pip install --upgrade Pillow
pip install einops

Pre-training Datasets:

Downstream-task Datasets:

Json Files from Pre-training and Downstream Tasks:

  • refer to Download in ALBEF
  • you need to change the image paths in the json files to match where your images are downloaded (see the sketch below)
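
A minimal sketch of how one might rewrite those image paths, assuming each json file is a list of entries that keep a relative path under an "image" key (the exact keys may differ, so check one file first); the helper name is hypothetical:

import json

def rewrite_image_paths(json_path, old_prefix, new_prefix):
    # Replace the leading image-path prefix in an ALBEF/TCL-style json file.
    with open(json_path) as f:
        entries = json.load(f)
    for entry in entries:
        if "image" in entry:
            entry["image"] = entry["image"].replace(old_prefix, new_prefix, 1)
    with open(json_path, "w") as f:
        json.dump(entries, f)

# Example with hypothetical paths:
# rewrite_image_paths("coco_train.json", "/export/share/datasets/coco/", "/data/coco/")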

Pre-trained checkpoint:

Pre-training:

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Pretrain.py \
--config ./configs/Pretrain.yaml \
--output_dir output/pretrain

Downstream Tasks:

Image-Text Retrieval

# zero-shot coco 
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir output/pretrain_e30_Retrieval_coco_zeroshot \
--checkpoint output/pretrain/checkpoint_29.pth \
--evaluate

# fine-tune flickr
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/pretrain_e30_Retrieval_flickr \
--checkpoint output/pretrain/checkpoint_29.pth

# fine-tune coco
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir output/pretrain_e30_Retrieval_coco \
--checkpoint output/pretrain/checkpoint_29.pth

# zero-shot flickr 
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/pretrain_e30_Retrieval_flickr_zeroshot \
--checkpoint output/pretrain_e30_Retrieval_coco/checkpoint_best.pth \
--evaluate

VQA

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/pretrain_e30_vqa \
--checkpoint output/pretrain/checkpoint_29.pth

Visual Entailment

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/pretrain_e30_VE \
--checkpoint output/pretrain/checkpoint_29.pth

NLVR2

# pre-train nlvr
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/pretrain_e30_NLVR_pretrain \
--checkpoint output/pretrain/checkpoint_29.pth

# fine-tune nlvr
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/pretrain_e30_NLVR \
--checkpoint output/pretrain_e30_NLVR_pretrain/checkpoint_00.pth

Citation:

@inproceedings{yang2022vision,
  title={Vision-Language Pre-Training with Triple Contrastive Learning},
  author={Yang, Jinyu and Duan, Jiali and Tran, Son and Xu, Yi and Chanda, Sampath and Chen, Liqun and Zeng, Belinda and Chilimbi, Trishul and Huang, Junzhou},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}

Our code is largely borrowed from ALBEF.

tcl's People

Contributors

viyjy


tcl's Issues

Release finetuned checkpoint

Hi,
First of all, thanks for sharing your code and excellent paper.
I wonder if you could provide the checkpoints fine-tuned on Flickr and COCO?

How to obtain `T+`?

Hi!

Thanks for releasing the code!
Sorry to bother you during this busy CVPR week, but here's one minor question:

It's said in Sec. 3.2 that two sets of textual inputs, T and T+, are fed to h(.) and h_hat(.), respectively. Could you point me to where exactly T+ is obtained?

Thanks!

Is the dataset wrong?

My dataset uses "MSCOCO/COCO-2015/annotations/captions_train2014.json". When I run the code, the following error occurred. Is the dataset wrong?
#########################################################################
Not using distributed mode
Creating dataset
Creating model
reshape position embedding from 196 to 256
_IncompatibleKeys(missing_keys=[], unexpected_keys=['head.weight', 'head.bias'])
Downloading: 100% 440M/440M [05:36<00:00, 1.31MB/s]
Traceback (most recent call last):
File "/home/whhhh/ZTTsar/TCL-main/Pretrain.py", line 204, in
main(args, config)
File "/home/whhhh/ZTTsar/TCL-main/Pretrain.py", line 121, in main
model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer, init_deit=True)
File "/home/whhhh/ZTTsar/TCL-main/models/model_pretrain.py", line 74, in init
[self.text_encoder,self.text_encoder_m],
File "/home/whhhh/anaconda3/envs/yolov7/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1265, in getattr
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ALBEF' object has no attribute 'text_encoder_m'

Process finished with exit code 1

About GPU Usage and Training Time

Hi, thanks for your great work and code sharing.

According to config/Pretrain.yaml, the batch_size is set to 64 (i.e., each GPU will process 64 image-text pairs) during pretraining. I would like to know how much GPU memory will be used and how much time will be taken per epoch (on 4M dataset) under this setting.

By the way, I have read your excellent paper but cannot find the supplementary material on the web. Could you share a link to download it? Thanks a lot.

Why does MLM work?

Thank you for the excellent work and for releasing the code.
After reading your model_pretrain.py, I have a question: you first encode the text and image to get the features for the contrastive loss, then you mask the features encoded by the text_encoder and send them to the fusion layer.
To the best of my knowledge, after the text_encoder the text features for every token have already attended to the other tokens, so I think the masking strategy doesn't work.
Could you answer this question, or am I missing some details?

Question about the MLM masking

Hi,

10% of the time, we replace masked input tokens with a random word

    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced

Here is the code you use to replace a token with a random word. Is it correct to use 0.5 as the parameter here?
Thank you for your answer.
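
Note: the quoted line appears to follow the standard BERT masking recipe, in which 80% of the selected positions become [MASK] and the remaining 20% are split evenly between a random token and the original token, so the 0.5 applies only to positions not already replaced by [MASK] and corresponds to 10% of all masked positions. A minimal sketch of that recipe (not the repository's exact code; argument names are illustrative):

import torch

def bert_style_mask(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    input_ids = input_ids.clone()
    labels = input_ids.clone()
    # select ~15% of positions; only these contribute to the MLM loss
    masked_indices = torch.bernoulli(torch.full(input_ids.shape, mlm_prob)).bool()
    labels[~masked_indices] = -100
    # 80% of the selected positions become [MASK]
    indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    input_ids[indices_replaced] = mask_token_id
    # half of the remaining 20% (10% overall) become a random token
    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long)
    input_ids[indices_random] = random_words[indices_random]
    # the final 10% keep their original token
    return input_ids, labels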

Question about Data augmentation for MoCo.

Dear author, I feel thankful of your great masterpiece, and I really appreciate about your work these days.

Reading your paper with comprehending your code,

I got in my mind about data augmentation.

I found that you give data augmentation on Image, then, did you do same thing on text modalities?

If it is right, then can you check out that line in code?

Thx.

Inference + Hugging Face

Is there an easy way to run inference with the model on some new examples? Also, are there any plans to put the model on Hugging Face?

Question about VQA fine-tuning

Hi Jinyu,
Thanks for sharing the code of the great work TCL. I have some questions about the code of model_vqa.py.
1. Top-k answers for each question: shouldn't the code be answer_ids[b] and answer_atts[b]?
2. Use of the text decoder: based on targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100), the input_ids are almost the same as targets_ids except for the pad token ids, so what is the point of calculating the loss and generating the answer a second time?

Thanks!
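
For context on the quoted masked_fill: PyTorch's cross-entropy loss ignores target positions equal to -100 by default (ignore_index=-100), so mapping pad tokens to -100 simply excludes padding from the decoding loss. A small self-contained sketch with made-up token ids:

import torch
import torch.nn.functional as F

pad_token_id = 0
input_ids = torch.tensor([[101, 2054, 2003, 102, 0, 0]])          # last two positions are padding
targets = input_ids.masked_fill(input_ids == pad_token_id, -100)   # pads -> -100
logits = torch.randn(1, 6, 30522)                                  # (batch, seq_len, vocab)
loss = F.cross_entropy(logits.view(-1, 30522), targets.view(-1))   # -100 targets are ignored
print(targets)  # tensor([[ 101, 2054, 2003,  102, -100, -100]])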

in_batch_g2l_loss

Line 249 in models,
u_p = (temp_mask * u_p) + (10000. * (1-temp_mask)),
should perhaps be
u_p = (temp_mask * u_p) - (10000. * (1-temp_mask)) ?
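
The usual additive-masking pattern before a softmax subtracts a large constant from the masked positions so that they vanish after normalization; adding it would instead make those positions dominate. Whether the plus sign in the repository is intentional is for the authors to confirm; here is a small illustrative sketch:

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5])
temp_mask = torch.tensor([1.0, 1.0, 0.0])                  # 1 = keep, 0 = mask out
masked = temp_mask * scores - 10000. * (1 - temp_mask)     # masked entry becomes very negative
print(F.softmax(masked, dim=0))                            # third probability is ~0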

About the XBert

Hi, thanks for this wonderful work. I am confused about the cross-attention module. In the XBERT code, when layer_num >= 6 the text_encoder switches to cross-attention; however, it does self-attention on text_embeds and then cross-attention between text_embeds and image_embeds. Why do self-attention on text_embeds before the cross-attention? Could it do self-attention on image_embeds first and then cross-attention, or could it do only the cross-attention? Please help me with this problem when convenient. Thank you again!

json file problem

Do the following json files belong to VQA v2?

train_file: ['../data/json_down/vqa_train.json',
'../data/json_down/vqa_val.json',
'../data/json_down/vg.json']

test_file: ['../data/json_down/vqa_test.json']
answer_list: '../data/json_down/answer_list.json'

vqa_root: '../data/VQA/Images/mscoco/' #train2014/
vg_root: '../data/VG/VG_100K/' #image/

momentum models

Dear authors, why use momentum models? Could you recommend some papers or blog posts for a newbie? Thanks a lot!
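
Note: momentum models in this line of work follow MoCo ("Momentum Contrast for Unsupervised Visual Representation Learning") and ALBEF: the momentum encoder is an exponential moving average of the online encoder, is never updated by gradients, and provides stable targets and queue features. A minimal sketch of the EMA update, under those assumptions:

import torch

@torch.no_grad()
def momentum_update(online_model, momentum_model, m=0.995):
    # Fold the online weights into the momentum model as a slow moving average.
    for p, p_m in zip(online_model.parameters(), momentum_model.parameters()):
        p_m.data = p_m.data * m + p.data * (1.0 - m)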

Cannot reproduce zero-shot retrieval performance

Hi, I have downloaded the pre-trained checkpoint TCL_4m.pth you provided and prepared Flickr30k.

I run the following command:

python -m torch.distributed.launch \
--nproc_per_node=4 \
--use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/pretrain_e30_Retrieval_flickr_zeroshot \
--checkpoint ./data/TCL_4M.pth \
--evaluate

Here are the results I get:

{"val_txt_r1": 87.96844181459566, "val_txt_r5": 98.12623274161736, "val_txt_r10": 99.40828402366864, "val_txt_r_mean": 95.16765285996057, "val_img_r1": 72.07100591715977, "val_img_r5": 90.55226824457594, "val_img_r10": 94.5759368836292, "val_img_r_mean": 85.73307034845497, "val_r_mean": 90.45036160420777, "test_txt_r1": 89.4, "test_txt_r5": 98.6, "test_txt_r10": 99.6, "test_txt_r_mean": 95.86666666666667, "test_img_r1": 73.36, "test_img_r5": 92.16, "test_img_r10": 95.52, "test_img_r_mean": 87.01333333333332, "test_r_mean": 91.44, "epoch": 0}

According to Table 2 in your paper, zero-shot R@1 performance on the Flickr30K test set is 93.0 (text retrieval) and 79.6 (image retrieval). But what I get is test_txt_r1 = 89.4 and test_img_r1 = 73.36.

Did I do something wrong?

How about the intra-modal retrieval performance?

Hi, thanks for sharing your great work, TCL.
I have read the paper and have some confusion: what about the performance of intra-modal retrieval, like img2img or text2text? It should be better than what is reported in ALIGN. However, I cannot find experiments on that. Am I missing something?

loss is nan when pre-training on my own dataset

Hi, thanks for your excellent work. When I pre-train on my own Chinese dataset (so I changed bert-base-uncased to bert-base-chinese), the loss becomes nan after several iterations. I have tried decreasing the lr and adding grad_clip, but the problem still exists.
[training log screenshot]
Here is my training config:
[training config screenshot]

Can you give me some suggestions? Thanks in advance.

About the queue

Thanks for this wonderful work. I am confused about the queue size: why is it 65536?

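Note: the queue is a MoCo-style memory bank of features from recent batches that serves as extra negatives for the contrastive losses; 65536 matches MoCo's default and mainly trades memory for more negatives, and the size is normally a multiple of the batch size so whole batches can be enqueued. A minimal sketch of the enqueue/dequeue step, with illustrative sizes:

import torch

queue_size, feat_dim, batch_size = 65536, 256, 64
queue = torch.randn(feat_dim, queue_size)       # columns are negative features
queue_ptr = 0

@torch.no_grad()
def dequeue_and_enqueue(feats):                 # feats: (batch_size, feat_dim)
    global queue_ptr
    n = feats.shape[0]
    queue[:, queue_ptr:queue_ptr + n] = feats.T  # overwrite the oldest entries
    queue_ptr = (queue_ptr + n) % queue_size

dequeue_and_enqueue(torch.randn(batch_size, feat_dim))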

About the loss_distill

Hi, thank you for the excellent work and the release of the code!

I am a little confused about the approach to calculating loss_distill in line 1429 of xbert.py as shown in

                  loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=1)*soft_labels,dim=-1)

I think the size of both prediction_scores and soft_labels would be (batch_size, seq_len, vocab_size). And F.softmax is used in the last dimension for soft_labels in line 237 of model_pretrain.py, as shown in

                  mlm_output = self.text_encoder(input_ids, 
                                                 attention_mask = text.attention_mask,
                                                 encoder_hidden_states = image_embeds,
                                                 encoder_attention_mask = image_atts,      
                                                 return_dict = True,
                                                 labels = labels,   
                                                 soft_labels = F.softmax(logits_m,dim=-1),
                                                 alpha = alpha
                                                )

Why is F.log_softmax used in the second dimension (dimension of seq_len) for prediction_scores?
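
For reference, the usual soft-label distillation term normalizes both the student log-probabilities and the teacher probabilities over the vocabulary (last) axis; whether dim=1 in the quoted line is intentional is for the authors to confirm. A small sketch of the standard formulation:

import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 30522
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)   # student logits
logits_m = torch.randn(batch_size, seq_len, vocab_size)            # momentum (teacher) logits

soft_labels = F.softmax(logits_m, dim=-1)
loss_distill = -(F.log_softmax(prediction_scores, dim=-1) * soft_labels).sum(dim=-1).mean()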

No module named 'refTools'

Hi, I am trying to reproduce the results from inside a Docker container. After installing the dependencies, I hit the following error:

  File "Pretrain.py", line 30, in <module>
    from dataset import create_dataset, create_sampler, create_loader
  File "/workspace/dataset/__init__.py", line 6, in <module>
    from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset
  File "/workspace/dataset/caption_dataset.py", line 12, in <module>
    from dataset.utils import pre_caption
  File "/workspace/dataset/utils.py", line 45, in <module>
    from refTools.evaluation.refEvaluation import RefEvaluation
ModuleNotFoundError: No module named 'refTools'

every time I run:

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Pretrain.py \
--config ./configs/Pretrain.yaml \
--output_dir output/pretrain

I have tried pip3 install reftools but it does not solve the issue. Have you run into this issue before?
