
tcl's Introduction

TCL: Text-grounded Contrastive Learning (CVPR'23)

Official PyTorch implementation of Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs, Junbum Cha, Jonghwan Mun, Byungseok Roh, CVPR 2023.

Text-grounded Contrastive Learning (TCL) is an open-world semantic segmentation framework using only image-text pairs. TCL enables a model to learn region-text alignment without train-test discrepancy.

A demo page is available. Since the demo runs on a free Hugging Face CPU space, inference may take around 5-10 seconds.

Results

TCL can perform segmentation on both (a, c) existing segmentation benchmarks and (b) arbitrary concepts, such as proper nouns and free-form text, in in-the-wild images.


Additional examples in PASCAL VOC

Additional examples in the wild

Dependencies

We used PyTorch 1.12.1 and torchvision 0.13.1.

pip install -U openmim
mim install mmcv-full==1.6.2 mmsegmentation==0.27.0
pip install -r requirements.txt

Note that the order of packages in the requirements file roughly reflects how important it is to match each pinned version. We recommend using the same versions at least for webdataset, mmsegmentation, and timm.
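
To confirm that your environment matches, a quick check like the following (a convenience sketch, not part of the repository) prints the installed versions for comparison against the pinned ones:

import mmcv, mmseg, timm, torch, torchvision, webdataset

# Print installed versions to compare against the pinned requirements.
for name, module in [("torch", torch), ("torchvision", torchvision),
                     ("mmcv", mmcv), ("mmseg", mmseg),
                     ("timm", timm), ("webdataset", webdataset)]:
    print(f"{name}: {module.__version__}")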

Datasets

Note that much of this section is adapted from the data preparation section of the GroupViT README.

We use webdataset as a scalable data format for training and mmsegmentation for semantic segmentation evaluation.

The overall file structure is as follows:

TCL
├── data
│   ├── gcc3m
│   │   ├── gcc-train-000000.tar
│   │   ├── ...
│   ├── gcc12m
│   │   ├── cc-000000.tar
│   │   ├── ...
│   ├── cityscapes
│   │   ├── leftImg8bit
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── gtFine
│   │   │   ├── train
│   │   │   ├── val
│   ├── VOCdevkit
│   │   ├── VOC2012
│   │   │   ├── JPEGImages
│   │   │   ├── SegmentationClass
│   │   │   ├── ImageSets
│   │   │   │   ├── Segmentation
│   │   ├── VOC2010
│   │   │   ├── JPEGImages
│   │   │   ├── SegmentationClassContext
│   │   │   ├── ImageSets
│   │   │   │   ├── SegmentationContext
│   │   │   │   │   ├── train.txt
│   │   │   │   │   ├── val.txt
│   │   │   ├── trainval_merged.json
│   │   ├── VOCaug
│   │   │   ├── dataset
│   │   │   │   ├── cls
│   ├── ade
│   │   ├── ADEChallengeData2016
│   │   │   ├── annotations
│   │   │   │   ├── training
│   │   │   │   ├── validation
│   │   │   ├── images
│   │   │   │   ├── training
│   │   │   │   ├── validation
│   ├── coco_stuff164k
│   │   ├── images
│   │   │   ├── train2017
│   │   │   ├── val2017
│   │   ├── annotations
│   │   │   ├── train2017
│   │   │   ├── val2017
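
Once the shards are in place, you can sanity-check them with a few lines of webdataset code. The sketch below assumes the gcc3m shard naming from the tree above; the shard range {000000..000331} is a placeholder for however many shards you actually downloaded.

import webdataset as wds

# Stream (image, caption) pairs straight from the tar shards.
dataset = (
    wds.WebDataset("data/gcc3m/gcc-train-{000000..000331}.tar")
    .decode("pil")               # decode images into PIL.Image objects
    .to_tuple("jpg;png", "txt")  # select the image and caption fields
)

for image, caption in dataset:
    print(image.size, caption[:80])
    break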

The instructions for preparing each dataset are as follows.

Training datasets

For training, we use Conceptual Captions 3M and 12M. We use the img2dataset tool to download and preprocess the datasets.

GCC3M

Please download the training split annotation file from Conceptual Captions 3M and name it gcc3m.tsv.

Then run img2dataset to download the image-text pairs and save them in the webdataset format.

sed -i '1s/^/caption\turl\n/' gcc3m.tsv
img2dataset --url_list gcc3m.tsv --input_format "tsv" \
            --url_col "url" --caption_col "caption" --output_format webdataset \
            --output_folder data/gcc3m \
            --processes_count 16 --thread_count 64 \
            --image_size 512 --resize_mode keep_ratio --resize_only_if_bigger True \
            --enable_wandb True --save_metadata False --oom_shard_count 6
rename -d 's/^/gcc-train-/' data/gcc3m/*

Please refer to the img2dataset CC3M tutorial for more details.

GCC12M

Please download the annotation file from Conceptual Captions 12M and name it gcc12m.tsv.

Then run img2dataset to download the image-text pairs and save them in the webdataset format.

sed -i '1s/^/caption\turl\n/' gcc12m.tsv
img2dataset --url_list gcc12m.tsv --input_format "tsv" \
            --url_col "url" --caption_col "caption" --output_format webdataset \
            --output_folder data/gcc12m \
            --processes_count 16 --thread_count 64 \
            --image_size 512 --resize_mode keep_ratio --resize_only_if_bigger True \
            --enable_wandb True --save_metadata False --oom_shard_count 6
rename -d 's/^/cc-/' data/gcc12m/*

Please refer to the img2dataset CC12M tutorial for more details.

Evaluation datasets

In the paper, we use 8 benchmarks: (i) with background: PASCAL VOC20, PASCAL Context59, and COCO-Object; and (ii) without background: PASCAL VOC, PASCAL Context, COCO-Stuff, Cityscapes, and ADE20k. Since some benchmarks share data sources (e.g., VOC20 and VOC), we only need to prepare 5 datasets: PASCAL VOC, PASCAL Context, COCO-Stuff164k, Cityscapes, and ADE20k.

Please download and set up the PASCAL VOC, PASCAL Context, COCO-Stuff164k, Cityscapes, and ADE20k datasets following the MMSegmentation data preparation document.

COCO Object

The COCO-Object dataset uses only the object classes from the COCO-Stuff164k dataset, obtained by collecting the instance segmentation annotations. Run the following command to convert instance segmentation annotations to semantic segmentation annotations:

python convert_dataset/convert_coco.py data/coco_stuff164k/ -o data/coco_stuff164k/
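
For reference, the core of such an instance-to-semantic conversion looks roughly like the sketch below. This is an illustration built on the standard pycocotools API, not the repository's convert_coco.py, and the annotation path is an assumption:

import numpy as np
from pycocotools.coco import COCO

# Hypothetical path to the COCO instance annotations.
coco = COCO("data/coco_stuff164k/annotations/instances_val2017.json")

img_id = coco.getImgIds()[0]
info = coco.loadImgs(img_id)[0]

# Paint each instance mask with its object category id on top of an
# all-background (zero) semantic map.
semantic = np.zeros((info["height"], info["width"]), dtype=np.uint8)
for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
    semantic[coco.annToMask(ann) == 1] = ann["category_id"]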

Training

We use 16 and 8 NVIDIA V100 GPUs for the main and ablation experiments, respectively.

Single node

torchrun --rdzv_endpoint=localhost:5 --nproc_per_node=auto main.py --cfg ./configs/tcl.yml

Multi node

torchrun --rdzv_endpoint=$HOST:$PORT --nproc_per_node=auto --nnodes=$NNODES --node_rank=$RANK main.py --cfg ./configs/tcl.yml

Evaluation

We provide an official checkpoint to reproduce the main results of our paper.

  • Zero-shot transfer to semantic segmentation (Table 2):
torchrun --rdzv_endpoint=localhost:5 --nproc_per_node=auto main.py --resume checkpoints/tcl.pth --eval
  • Evaluation without PAMR (Table 3 in Appendix):
torchrun --rdzv_endpoint=localhost:5 --nproc_per_node=auto main.py --resume checkpoints/tcl.pth --eval \
    --opts evaluate.pamr=false evaluate.bg_thresh=0.5

Note that we use a background threshold (bg_thresh) of 0.4 with PAMR and 0.5 without PAMR, since we observed that PAMR tends to reduce the foreground area.
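
To illustrate what the background threshold does, here is a minimal sketch with hypothetical names (see the repository's evaluator for the actual logic): each pixel takes its best-scoring foreground class, and pixels whose best score falls below bg_thresh are assigned to background.

import torch

def assign_labels(mask_probs: torch.Tensor, bg_thresh: float = 0.4) -> torch.Tensor:
    """mask_probs: (num_classes, H, W) per-class mask scores in [0, 1]."""
    scores, labels = mask_probs.max(dim=0)  # best foreground class per pixel
    labels = labels + 1                     # reserve index 0 for background
    labels[scores < bg_thresh] = 0          # low-score pixels become background
    return labels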

Citation

@inproceedings{cha2022tcl,
  title={Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs},
  author={Cha, Junbum and Mun, Jonghwan and Roh, Byungseok},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023}
}

License

This project is released under the MIT license.

tcl's People

Contributors

kentjang, khanrc


tcl's Issues

Doubts about the `ignore_last_attn` parameter

Thank you so much for releasing the codebase and for your wonderful work. While going through the code, I am struggling a bit with the parameter ignore_last_attn. Can you please explain its purpose? I am a bit confused: if all entries except the diagonal are masked away, how does it work?

Training problem

Dear author, I encountered a problem while training the model. No matter how I modify my configuration parameters, I cannot obtain data through the dataloader when training on GCC12M, and training hangs indefinitely at this stage. If you could give me some guidance on this issue, I would greatly appreciate it.

About training cost

Hi, thank you for your wonderful work! I was wondering about the training cost of TCL. The paper reports 12 hours, but which GPUs did you use?

ValueError: need at most 63 handles, got a sequence of length 66

Hello, thanks for your code and README. I ran into a problem when trying to download the GCC3M dataset using img2dataset. The error looks like this:

Exception in thread Thread-1:
Traceback (most recent call last):
  File "D:\Anaconda\envs\opseg\lib\threading.py", line 980, in _bootstrap_inner
    self.run()
  File "D:\Anaconda\envs\opseg\lib\threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "D:\Anaconda\envs\opseg\lib\multiprocessing\pool.py", line 519, in _handle_workers
    cls._wait_for_updates(current_sentinels, change_notifier)
  File "D:\Anaconda\envs\opseg\lib\multiprocessing\pool.py", line 499, in _wait_for_updates
    wait(sentinels, timeout=timeout)
  File "D:\Anaconda\envs\opseg\lib\multiprocessing\connection.py", line 879, in wait
    ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
  File "D:\Anaconda\envs\opseg\lib\multiprocessing\connection.py", line 811, in _exhaustive_wait
    res = _winapi.WaitForMultipleObjects(L, False, timeout)
ValueError: need at most 63 handles, got a sequence of length 66

Could you please help me solve the problem?

CXR-CLIP code is not found

I was unable to find the code for "CXR-CLIP: Toward Large Scale Chest X-ray Language-Image Pre-training" through the link given in the paper or in your repository list. I would really like to use the code; where can I find it?
