scan's Introduction

Open-Vocabulary Segmentation with Semantic-Assisted Calibration [CVPR 2024]

Yong Liu*, Sule Bai*, Guanbin Li, Yitong Wang, Yansong Tang (*equal contribution)

This repository contains the official implementation of "Open-Vocabulary Segmentation with Semantic-Assisted Calibration".

Paper


📖 Pipeline & Results

Table of Contents

  • Installation
  • Data Preparation
  • Usage
  • Cite

If you find any bugs caused by carelessness on our part in organizing the code, feel free to contact us and point them out!

Installation

Please see the installation guide.

Data Preparation

Please follow the instructions of ov-seg to prepare the training and test data. The data should be organized as follows:

$DETECTRON2_DATASETS/
  coco/                 # COCOStuff-171
  ADEChallengeData2016/ # ADE20K-150
  ADE20K_2021_17_01/    # ADE20K-847
  VOCdevkit/
    VOC2012/            # PASCALVOC-20
    VOC2010/            # PASCALContext-59, PASCALContext-459
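Detectron2 locates these datasets through the DETECTRON2_DATASETS environment variable (it typically falls back to ./datasets when unset), so a minimal setup looks like the sketch below; the path is a placeholder.

export DETECTRON2_DATASETS=/path/to/your/datasets    # placeholder path; set it to your dataset root
ls $DETECTRON2_DATASETS                               # should list coco/, ADEChallengeData2016/, ADE20K_2021_17_01/, VOCdevkit/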

Usage

  • Pretrained Weights

    We have provided the pretrained SCAN-VitL weights and the finetuned Contextual-shifted CLIP weights. Please download them from here.

Evaluation

python train_net.py --eval-only --config-file <CONFIG_FILE> --num-gpus <NUM_GPU> OUTPUT_DIR <OUTPUT_PATH> MODEL.WEIGHTS <TRAINED_MODEL_PATH>
  • Here is an example:
python train_net.py --num-gpus 8 --eval-only --config-file configs/scan_vitL.yaml MODEL.WEIGHTS ./SCAN.pth DATASETS.TEST \(\"ade20k_sem_seg_val\",\) MODEL.CLIP_ADAPTER.REPLACE_RATIO 0.05 MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT 0.75 MODEL.CLIP_ADAPTER.MASK_THR 0.55
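The DATASETS.TEST value is a Python-style tuple, which is why its parentheses and quotes are backslash-escaped above. Wrapping the tuple in single quotes is an equivalent and often easier-to-read alternative; the sketch below simply reuses the same config, weights, and thresholds as the example above.

# Same evaluation as above, with the tuple protected by single quotes instead of backslash escapes.
python train_net.py --num-gpus 8 --eval-only --config-file configs/scan_vitL.yaml \
  MODEL.WEIGHTS ./SCAN.pth \
  DATASETS.TEST '("ade20k_sem_seg_val",)' \
  MODEL.CLIP_ADAPTER.REPLACE_RATIO 0.05 \
  MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT 0.75 \
  MODEL.CLIP_ADAPTER.MASK_THR 0.55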

Training

  1. Train the segmentation model:
python train_net.py  --config-file <CONFIG_FILE> --num-gpus <NUM_GPU>
  • Here is an example:
python train_net.py --num-gpus 8 --config-file configs/scan_vitL.yaml
  2. Fuse the segmentation model with the finetuned CLIP.

We have provided the finetuned CLIP weights, so you can directly fuse them with the segmentation model to obtain the final model. The fusion command is:

cd tools
python replace_clip.py

You need to set "clip_ckpt" and "ovseg_model" in that file to your finetuned CLIP checkpoint path and your segmentation model path, respectively.
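For reference, here is a minimal sketch of this step; the variable names come from the instructions above, while the path values are placeholders rather than the script's actual defaults.

cd tools
# Inside replace_clip.py, point these two variables at your own files before running:
#   clip_ckpt   = "/path/to/finetuned_clip.pth"       # finetuned contextual-shifted CLIP weights
#   ovseg_model = "/path/to/segmentation_model.pth"   # segmentation model trained in step 1
python replace_clip.py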

(Optional) If you want to finetune the CLIP model from scratch, please follow ov-seg to prepare the corresponding data. The finetuning command is:

cd open_clip_training
cd src
bash scripts/finetune_VitL_with_mask.sh

Cite

If you find our work helpful, we'd appreciate it if you could cite our paper in your work.

@article{liu2023open,
  title={Open-Vocabulary Segmentation with Semantic-Assisted Calibration},
  author={Liu, Yong and Bai, Sule and Li, Guanbin and Wang, Yitong and Tang, Yansong},
  journal={arXiv preprint arXiv:2312.04089},
  year={2023}
}


scan's Issues

Issue with Loading Converted Weights Due to Shape Mismatch

Thank you for your amazing work on this project.

I'm currently experiencing issues when trying to load weights into the model. Specifically, I used the convert-pretrained-swin-model-to-d2.py script to convert swin_base_patch4_window12_384_22k.pth to swin_base_patch4_window12_384_22k.pkl.

However, when loading the converted weights, I encounter several warnings related to shape mismatches between the checkpoint and the model parameters. Here are some examples of the warnings:

d2.checkpoint.c2_model_loading WARNING: Shape of norm.bias in checkpoint is torch.Size([1024]), while shape of sem_seg_head.pixel_decoder.adapter_1.norm.bias in model is torch.Size([256]).
d2.checkpoint.c2_model_loading WARNING: norm.bias will not be loaded. Please double check and see if this is desired.
d2.checkpoint.c2_model_loading WARNING: Shape of norm.weight in checkpoint is torch.Size([1024]), while shape of sem_seg_head.pixel_decoder.adapter_1.norm.weight in model is torch.Size([256]).
d2.checkpoint.c2_model_loading WARNING: norm.weight will not be loaded. Please double check and see if this is desired.
d2.checkpoint.c2_model_loading WARNING: Shape of norm.bias in checkpoint is torch.Size([1024]), while shape of sem_seg_head.pixel_decoder.layer_1.norm.bias in model is torch.Size([256]).
d2.checkpoint.c2_model_loading WARNING: norm.bias will not be loaded. Please double check and see if this is desired.
d2.checkpoint.c2_model_loading WARNING: Shape of norm.weight in checkpoint is torch.Size([1024]), while shape of sem_seg_head.pixel_decoder.layer_1.norm.weight in model is torch.Size([256]).
d2.checkpoint.c2_model_loading WARNING: norm.weight will not be loaded. Please double check and see if this is desired.
d2.checkpoint.c2_model_loading WARNING: Shape of norm.bias in checkpoint is torch.Size([1024]), while shape of sem_seg_head.predictor.transformer_cross_attention_layers.0.norm.bias in model is torch.Size([256]).

Due to these warnings, it seems that some of the model parameters are not being correctly loaded, which is impacting the performance of the model.

Could you please provide some guidance on how to resolve these shape mismatches? Is there a recommended way to adjust the model or the checkpoint to ensure compatibility? Any insights or suggestions would be greatly appreciated.

Thank you in advance for your help!

Details

Thank you for your work.
Which graphics card was used for training, and how much VRAM was used with the batch size set to 32?

Issue: checkpoint weights not fully used by the model

Hello, I noticed the following warning during testing. Could you please take a look? Thank you.
WARNING [07/15 08:50:58 fvcore.common.checkpoint]: The checkpoint state_dict contains keys that are not used by the model:
clip_adapter.clip_model.visual.learnable_weight
clip_adapter.clip_model.visual.cxt_decoder.layers.0.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.self_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.multihead_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.linear1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.linear2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.norm1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.norm2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.0.norm3.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.self_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.multihead_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.linear1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.linear2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.norm1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.norm2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.1.norm3.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.self_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.multihead_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.linear1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.linear2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.norm1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.norm2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.2.norm3.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.self_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.multihead_attn.out_proj.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.linear1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.linear2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.norm1.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.norm2.{bias, weight}
clip_adapter.clip_model.visual.cxt_decoder.layers.3.norm3.{bias, weight}
clip_adapter.original_clip.learnable_weight
clip_adapter.original_clip.cxt_decoder.layers.0.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.0.self_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.0.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.0.multihead_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.0.linear1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.0.linear2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.0.norm1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.0.norm2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.0.norm3.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.1.self_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.1.multihead_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.linear1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.linear2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.norm1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.norm2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.1.norm3.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.2.self_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.2.multihead_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.linear1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.linear2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.norm1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.norm2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.2.norm3.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.self_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.3.self_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.multihead_attn.{in_proj_bias, in_proj_weight}
clip_adapter.original_clip.cxt_decoder.layers.3.multihead_attn.out_proj.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.linear1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.linear2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.norm1.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.norm2.{bias, weight}
clip_adapter.original_clip.cxt_decoder.layers.3.norm3.{bias, weight}
0%| | 0/1 [00:00<?, ?it/s]
G:\program\SCAN-main\scan\frequency.py:42: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ..\aten\src\ATen\native\Copy.cpp:244.)
  y = torch.fft.ifft2(y, s=(h, w)).float()
G:\program\SCAN-main\scan\modeling\transformer_decoder\position_encoding.py:41: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
(the same position_encoding.py warning is printed a second time)
Does this mean the model parameters are not being fully used? Do you know how to resolve this problem, and how should these shape mismatches be fixed?
