
segment-anything's Introduction

Segment Anything

Meta AI Research, FAIR

Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick

[Paper] [Project] [Demo] [Dataset] [Blog] [BibTeX]

SAM design

The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

Installation

The code requires python>=3.8, as well as pytorch>=1.7 and torchvision>=0.8. Please follow the instructions here to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.

Install Segment Anything:

pip install git+https://github.com/facebookresearch/segment-anything.git

or clone the repository locally and install with

git clone [email protected]:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .

The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. jupyter is also required to run the example notebooks.

pip install opencv-python pycocotools matplotlib onnxruntime onnx

Getting Started

First download a model checkpoint. Then the model can be used in just a few lines to get masks from a given prompt:

from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
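
For a concrete illustration, the sketch below prompts with a single foreground point. The checkpoint filename, image path, and click coordinates are placeholder assumptions; the point-label convention (1 for foreground, 0 for background) and the returned mask/score arrays follow the SamPredictor API.

import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Placeholder model type and checkpoint; any model type works with its matching checkpoint.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

# SamPredictor expects an HxWx3 uint8 RGB array; OpenCV loads BGR, so convert.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# A single foreground click at pixel (x, y) = (500, 375).
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,  # return several candidate masks with quality scores
)
best_mask = masks[np.argmax(scores)]  # pick the highest-scoring candidate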

or generate masks for an entire image:

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(<your_image>)
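
Each element of the returned list is a dictionary describing one mask, with fields such as segmentation (the mask itself), area, bbox, stability_score, and predicted_iou. A brief sketch of inspecting the output; the image path is a placeholder:

import cv2

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)

# Sort by mask area and look at the largest segment.
masks = sorted(masks, key=lambda m: m["area"], reverse=True)
largest = masks[0]
print(len(masks), "masks; largest area:", largest["area"], "bbox (XYWH):", largest["bbox"])
binary_mask = largest["segmentation"]  # boolean HxW array in the default output mode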

Additionally, masks can be generated for images from the command line:

python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>

See the example notebooks on using SAM with prompts and on automatically generating masks for more details.

ONNX Export

SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the demo. Export the model with

python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>

See the example notebook for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
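
Once exported, the decoder can be loaded with onnxruntime. The sketch below only loads the model and lists the inputs it expects (the output filename is a placeholder); the full preprocessing and prediction loop is shown in the example notebook.

import onnxruntime as ort

# Path written by scripts/export_onnx_model.py above (placeholder name).
session = ort.InferenceSession("sam_decoder.onnx")

# Inspect the input signature of the exported decoder.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)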

Web demo

The demo/ folder has a simple one-page React app that shows how to run mask prediction with the exported ONNX model in a web browser with multithreading. Please see demo/README.md for more details.

Model Checkpoints

Three versions of the model are available with different backbone sizes. These models can be instantiated by running

from segment_anything import sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")

Click the links below to download the checkpoint for the corresponding model type (default or vit_h, vit_l, or vit_b).
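
For example, a minimal sketch assuming the ViT-B checkpoint (sam_vit_b_01ec64.pth, the filename referenced in the issues below) has been downloaded to the working directory:

import torch
from segment_anything import sam_model_registry

# The model type must match the checkpoint: "default"/"vit_h", "vit_l", or "vit_b".
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to(device="cuda" if torch.cuda.is_available() else "cpu")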

Dataset

See here for an overview of the dataset. The dataset can be downloaded here. By downloading the datasets you agree that you have read and accepted the terms of the SA-1B Dataset Research License.

We save masks per image as a JSON file. It can be loaded as a dictionary in Python in the format below.

{
    "image"                 : image_info,
    "annotations"           : [annotation],
}

image_info {
    "image_id"              : int,              # Image id
    "width"                 : int,              # Image width
    "height"                : int,              # Image height
    "file_name"             : str,              # Image filename
}

annotation {
    "id"                    : int,              # Annotation id
    "segmentation"          : dict,             # Mask saved in COCO RLE format.
    "bbox"                  : [x, y, w, h],     # The box around the mask, in XYWH format
    "area"                  : int,              # The area in pixels of the mask
    "predicted_iou"         : float,            # The model's own prediction of the mask's quality
    "stability_score"       : float,            # A measure of the mask's quality
    "crop_box"              : [x, y, w, h],     # The crop of the image used to generate the mask, in XYWH format
    "point_coords"          : [[x, y]],         # The point coordinates input to the model to generate the mask
}

Image ids can be found in sa_images_ids.txt, which can be downloaded using the above link as well.
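
As an illustration, here is a short sketch of reading one per-image annotation file (the filename is a placeholder) and walking its fields according to the schema above:

import json

# Placeholder name for a single per-image annotation file from SA-1B.
with open("sa_1.json") as f:
    data = json.load(f)

info = data["image"]
annotations = data["annotations"]
print(info["file_name"], info["width"], "x", info["height"])
print(len(annotations), "masks; first predicted_iou:", annotations[0]["predicted_iou"])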

To decode a mask in COCO RLE format into binary:

from pycocotools import mask as mask_utils
mask = mask_utils.decode(annotation["segmentation"])

See here for more instructions to manipulate masks stored in RLE format.
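
Beyond decoding, pycocotools can compute properties directly from the RLE and re-encode binary masks. A small sketch, reusing the annotation loaded above:

import numpy as np
from pycocotools import mask as mask_utils

rle = annotation["segmentation"]      # COCO RLE dict, as in the schema above
binary = mask_utils.decode(rle)       # HxW uint8 binary mask
area = mask_utils.area(rle)           # pixel area; should agree with annotation["area"]
x, y, w, h = mask_utils.toBbox(rle)   # XYWH box; should agree with annotation["bbox"]

# Re-encode a binary mask as RLE (pycocotools expects a Fortran-ordered uint8 array).
rle_again = mask_utils.encode(np.asfortranarray(binary))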

License

The model is licensed under the Apache 2.0 license.

Contributing

See contributing and the code of conduct.

Contributors

The Segment Anything project was made possible with the help of many contributors (alphabetical):

Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom

Citing Segment Anything

If you use SAM or SA-1B in your research, please use the following BibTeX entry.

@article{kirillov2023segany,
  title={Segment Anything},
  author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
  journal={arXiv:2304.02643},
  year={2023}
}

segment-anything's Issues

Pretrained models

I notice that the first stage of the data engine trains a ViT-B model, and the ablations also train a ViT-B model. Is the released model pretrained on the final SA-1B dataset, only on data from the first stage of the data engine, or something else?

RuntimeError: Error(s) in loading state_dict for Sam

I tried to run the model based on the given code but got the error below:

---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_13280\2123822288.py in <module>
----> 1 mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="segment_anything/models/sam_vit_b_01ec64.pth"))

~\Desktop\My Projects\segment-anything\segment_anything\build_sam.py in build_sam_vit_h(checkpoint)
     13
     14 def build_sam_vit_h(checkpoint=None):
---> 15     return _build_sam(
     16         encoder_embed_dim=1280,
     17         encoder_depth=32,

~\Desktop\My Projects\segment-anything\segment_anything\build_sam.py in _build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint)
    104     with open(checkpoint, "rb") as f:
    105         state_dict = torch.load(f)
--> 106     sam.load_state_dict(state_dict)
    107 return sam

~\anaconda3\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
   1669
   1670         if len(error_msgs) > 0:
-> 1671             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                 self.__class__.__name__, "\n\t".join(error_msgs)))
   1673         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Sam:
Missing key(s) in state_dict: "image_encoder.blocks.12.norm1.weight", "image_encoder.blocks.12.norm1.bias", "image_encoder.blocks.12.attn.rel_pos_h", "image_encoder.blocks.12.attn.rel_pos_w", "image_encoder.blocks.12.attn.qkv.weight", "image_encoder.blocks.12.attn.qkv.bias", "image_encoder.blocks.12.attn.proj.weight", "image_encoder.blocks.12.attn.proj.bias", "image_encoder.blocks.12.norm2.weight", "image_encoder.blocks.12.norm2.bias", "image_encoder.blocks.12.mlp.lin1.weight", "image_encoder.blocks.12.mlp.lin1.bias", "image_encoder.blocks.12.mlp.lin2.weight", "image_encoder.blocks.12.mlp.lin2.bias", "image_encoder.blocks.13.norm1.weight", "image_encoder.blocks.13.norm1.bias", "image_encoder.blocks.13.attn.rel_pos_h", "image_encoder.blocks.13.attn.rel_pos_w", "image_encoder.blocks.13.attn.qkv.weight", "image_encoder.blocks.13.attn.qkv.bias", "image_encoder.blocks.13.attn.proj.weight", "image_encoder.blocks.13.attn.proj.bias", "image_encoder.blocks.13.norm2.weight", "image_encoder.blocks.13.norm2.bias", "image_encoder.blocks.13.mlp.lin1.weight", "image_encoder.blocks.13.mlp.lin1.bias", "image_encoder.blocks.13.mlp.lin2.weight", "image_encoder.blocks.13.mlp.lin2.bias", "image_encoder.blocks.14.norm1.weight", "image_encoder.blocks.14.norm1.bias", "image_encoder.blocks.14.attn.rel_pos_h", "image_encoder.blocks.14.attn.rel_pos_w", "image_encoder.blocks.14.attn.qkv.weight", "image_encoder.blocks.14.attn.qkv.bias", "image_encoder.blocks.14.attn.proj.weight", "image_encoder.blocks.14.attn.proj.bias", "image_encoder.blocks.14.norm2.weight", "image_encoder.blocks.14.norm2.bias", "image_encoder.blocks.14.mlp.lin1.weight", "image_encoder.blocks.14.mlp.lin1.bias", "image_encoder.blocks.14.mlp.lin2.weight", "image_encoder.blocks.14.mlp.lin2.bias", "image_encoder.blocks.15.norm1.weight", "image_encoder.blocks.15.norm1.bias", "image_encoder.blocks.15.attn.rel_pos_h", "image_encoder.blocks.15.attn.rel_pos_w", "image_encoder.blocks.15.attn.qkv.weight", "image_encoder.blocks.15.attn.qkv.bias", "image_encoder.blocks.15.attn.proj.weight", "image_encoder.blocks.15.attn.proj.bias", "image_encoder.blocks.15.norm2.weight", "image_encoder.blocks.15.norm2.bias", "image_encoder.blocks.15.mlp.lin1.weight", "image_encoder.blocks.15.mlp.lin1.bias", "image_encoder.blocks.15.mlp.lin2.weight", "image_encoder.blocks.15.mlp.lin2.bias", "image_encoder.blocks.16.norm1.weight", "image_encoder.blocks.16.norm1.bias", "image_encoder.blocks.16.attn.rel_pos_h", "image_encoder.blocks.16.attn.rel_pos_w", "image_encoder.blocks.16.attn.qkv.weight", "image_encoder.blocks.16.attn.qkv.bias", "image_encoder.blocks.16.attn.proj.weight", "image_encoder.blocks.16.attn.proj.bias", "image_encoder.blocks.16.norm2.weight", "image_encoder.blocks.16.norm2.bias", "image_encoder.blocks.16.mlp.lin1.weight", "image_encoder.blocks.16.mlp.lin1.bias", "image_encoder.blocks.16.mlp.lin2.weight", "image_encoder.blocks.16.mlp.lin2.bias", "image_encoder.blocks.17.norm1.weight", "image_encoder.blocks.17.norm1.bias", "image_encoder.blocks.17.attn.rel_pos_h", "image_encoder.blocks.17.attn.rel_pos_w", "image_encoder.blocks.17.attn.qkv.weight", "image_encoder.blocks.17.attn.qkv.bias", "image_encoder.blocks.17.attn.proj.weight", "image_encoder.blocks.17.attn.proj.bias", "image_encoder.blocks.17.norm2.weight", "image_encoder.blocks.17.norm2.bias", "image_encoder.blocks.17.mlp.lin1.weight", "image_encoder.blocks.17.mlp.lin1.bias", "image_encoder.blocks.17.mlp.lin2.weight", "image_encoder.blocks.17.mlp.lin2.bias", "image_encoder.blocks.18.norm1.weight", 
"image_encoder.blocks.18.norm1.bias", "image_encoder.blocks.18.attn.rel_pos_h", "image_encoder.blocks.18.attn.rel_pos_w", "image_encoder.blocks.18.attn.qkv.weight", "image_encoder.blocks.18.attn.qkv.bias", "image_encoder.blocks.18.attn.proj.weight", "image_encoder.blocks.18.attn.proj.bias", "image_encoder.blocks.18.norm2.weight", "image_encoder.blocks.18.norm2.bias", "image_encoder.blocks.18.mlp.lin1.weight", "image_encoder.blocks.18.mlp.lin1.bias", "image_encoder.blocks.18.mlp.lin2.weight", "image_encoder.blocks.18.mlp.lin2.bias", "image_encoder.blocks.19.norm1.weight", "image_encoder.blocks.19.norm1.bias", "image_encoder.blocks.19.attn.rel_pos_h", "image_encoder.blocks.19.attn.rel_pos_w", "image_encoder.blocks.19.attn.qkv.weight", "image_encoder.blocks.19.attn.qkv.bias", "image_encoder.blocks.19.attn.proj.weight", "image_encoder.blocks.19.attn.proj.bias", "image_encoder.blocks.19.norm2.weight", "image_encoder.blocks.19.norm2.bias", "image_encoder.blocks.19.mlp.lin1.weight", "image_encoder.blocks.19.mlp.lin1.bias", "image_encoder.blocks.19.mlp.lin2.weight", "image_encoder.blocks.19.mlp.lin2.bias", "image_encoder.blocks.20.norm1.weight", "image_encoder.blocks.20.norm1.bias", "image_encoder.blocks.20.attn.rel_pos_h", "image_encoder.blocks.20.attn.rel_pos_w", "image_encoder.blocks.20.attn.qkv.weight", "image_encoder.blocks.20.attn.qkv.bias", "image_encoder.blocks.20.attn.proj.weight", "image_encoder.blocks.20.attn.proj.bias", "image_encoder.blocks.20.norm2.weight", "image_encoder.blocks.20.norm2.bias", "image_encoder.blocks.20.mlp.lin1.weight", "image_encoder.blocks.20.mlp.lin1.bias", "image_encoder.blocks.20.mlp.lin2.weight", "image_encoder.blocks.20.mlp.lin2.bias", "image_encoder.blocks.21.norm1.weight", "image_encoder.blocks.21.norm1.bias", "image_encoder.blocks.21.attn.rel_pos_h", "image_encoder.blocks.21.attn.rel_pos_w", "image_encoder.blocks.21.attn.qkv.weight", "image_encoder.blocks.21.attn.qkv.bias", "image_encoder.blocks.21.attn.proj.weight", "image_encoder.blocks.21.attn.proj.bias", "image_encoder.blocks.21.norm2.weight", "image_encoder.blocks.21.norm2.bias", "image_encoder.blocks.21.mlp.lin1.weight", "image_encoder.blocks.21.mlp.lin1.bias", "image_encoder.blocks.21.mlp.lin2.weight", "image_encoder.blocks.21.mlp.lin2.bias", "image_encoder.blocks.22.norm1.weight", "image_encoder.blocks.22.norm1.bias", "image_encoder.blocks.22.attn.rel_pos_h", "image_encoder.blocks.22.attn.rel_pos_w", "image_encoder.blocks.22.attn.qkv.weight", "image_encoder.blocks.22.attn.qkv.bias", "image_encoder.blocks.22.attn.proj.weight", "image_encoder.blocks.22.attn.proj.bias", "image_encoder.blocks.22.norm2.weight", "image_encoder.blocks.22.norm2.bias", "image_encoder.blocks.22.mlp.lin1.weight", "image_encoder.blocks.22.mlp.lin1.bias", "image_encoder.blocks.22.mlp.lin2.weight", "image_encoder.blocks.22.mlp.lin2.bias", "image_encoder.blocks.23.norm1.weight", "image_encoder.blocks.23.norm1.bias", "image_encoder.blocks.23.attn.rel_pos_h", "image_encoder.blocks.23.attn.rel_pos_w", "image_encoder.blocks.23.attn.qkv.weight", "image_encoder.blocks.23.attn.qkv.bias", "image_encoder.blocks.23.attn.proj.weight", "image_encoder.blocks.23.attn.proj.bias", "image_encoder.blocks.23.norm2.weight", "image_encoder.blocks.23.norm2.bias", "image_encoder.blocks.23.mlp.lin1.weight", "image_encoder.blocks.23.mlp.lin1.bias", "image_encoder.blocks.23.mlp.lin2.weight", "image_encoder.blocks.23.mlp.lin2.bias", "image_encoder.blocks.24.norm1.weight", "image_encoder.blocks.24.norm1.bias", "image_encoder.blocks.24.attn.rel_pos_h", 
"image_encoder.blocks.24.attn.rel_pos_w", "image_encoder.blocks.24.attn.qkv.weight", "image_encoder.blocks.24.attn.qkv.bias", "image_encoder.blocks.24.attn.proj.weight", "image_encoder.blocks.24.attn.proj.bias", "image_encoder.blocks.24.norm2.weight", "image_encoder.blocks.24.norm2.bias", "image_encoder.blocks.24.mlp.lin1.weight", "image_encoder.blocks.24.mlp.lin1.bias", "image_encoder.blocks.24.mlp.lin2.weight", "image_encoder.blocks.24.mlp.lin2.bias", "image_encoder.blocks.25.norm1.weight", "image_encoder.blocks.25.norm1.bias", "image_encoder.blocks.25.attn.rel_pos_h", "image_encoder.blocks.25.attn.rel_pos_w", "image_encoder.blocks.25.attn.qkv.weight", "image_encoder.blocks.25.attn.qkv.bias", "image_encoder.blocks.25.attn.proj.weight", "image_encoder.blocks.25.attn.proj.bias", "image_encoder.blocks.25.norm2.weight", "image_encoder.blocks.25.norm2.bias", "image_encoder.blocks.25.mlp.lin1.weight", "image_encoder.blocks.25.mlp.lin1.bias", "image_encoder.blocks.25.mlp.lin2.weight", "image_encoder.blocks.25.mlp.lin2.bias", "image_encoder.blocks.26.norm1.weight", "image_encoder.blocks.26.norm1.bias", "image_encoder.blocks.26.attn.rel_pos_h", "image_encoder.blocks.26.attn.rel_pos_w", "image_encoder.blocks.26.attn.qkv.weight", "image_encoder.blocks.26.attn.qkv.bias", "image_encoder.blocks.26.attn.proj.weight", "image_encoder.blocks.26.attn.proj.bias", "image_encoder.blocks.26.norm2.weight", "image_encoder.blocks.26.norm2.bias", "image_encoder.blocks.26.mlp.lin1.weight", "image_encoder.blocks.26.mlp.lin1.bias", "image_encoder.blocks.26.mlp.lin2.weight", "image_encoder.blocks.26.mlp.lin2.bias", "image_encoder.blocks.27.norm1.weight", "image_encoder.blocks.27.norm1.bias", "image_encoder.blocks.27.attn.rel_pos_h", "image_encoder.blocks.27.attn.rel_pos_w", "image_encoder.blocks.27.attn.qkv.weight", "image_encoder.blocks.27.attn.qkv.bias", "image_encoder.blocks.27.attn.proj.weight", "image_encoder.blocks.27.attn.proj.bias", "image_encoder.blocks.27.norm2.weight", "image_encoder.blocks.27.norm2.bias", "image_encoder.blocks.27.mlp.lin1.weight", "image_encoder.blocks.27.mlp.lin1.bias", "image_encoder.blocks.27.mlp.lin2.weight", "image_encoder.blocks.27.mlp.lin2.bias", "image_encoder.blocks.28.norm1.weight", "image_encoder.blocks.28.norm1.bias", "image_encoder.blocks.28.attn.rel_pos_h", "image_encoder.blocks.28.attn.rel_pos_w", "image_encoder.blocks.28.attn.qkv.weight", "image_encoder.blocks.28.attn.qkv.bias", "image_encoder.blocks.28.attn.proj.weight", "image_encoder.blocks.28.attn.proj.bias", "image_encoder.blocks.28.norm2.weight", "image_encoder.blocks.28.norm2.bias", "image_encoder.blocks.28.mlp.lin1.weight", "image_encoder.blocks.28.mlp.lin1.bias", "image_encoder.blocks.28.mlp.lin2.weight", "image_encoder.blocks.28.mlp.lin2.bias", "image_encoder.blocks.29.norm1.weight", "image_encoder.blocks.29.norm1.bias", "image_encoder.blocks.29.attn.rel_pos_h", "image_encoder.blocks.29.attn.rel_pos_w", "image_encoder.blocks.29.attn.qkv.weight", "image_encoder.blocks.29.attn.qkv.bias", "image_encoder.blocks.29.attn.proj.weight", "image_encoder.blocks.29.attn.proj.bias", "image_encoder.blocks.29.norm2.weight", "image_encoder.blocks.29.norm2.bias", "image_encoder.blocks.29.mlp.lin1.weight", "image_encoder.blocks.29.mlp.lin1.bias", "image_encoder.blocks.29.mlp.lin2.weight", "image_encoder.blocks.29.mlp.lin2.bias", "image_encoder.blocks.30.norm1.weight", "image_encoder.blocks.30.norm1.bias", "image_encoder.blocks.30.attn.rel_pos_h", "image_encoder.blocks.30.attn.rel_pos_w", "image_encoder.blocks.30.attn.qkv.weight", 
"image_encoder.blocks.30.attn.qkv.bias", "image_encoder.blocks.30.attn.proj.weight", "image_encoder.blocks.30.attn.proj.bias", "image_encoder.blocks.30.norm2.weight", "image_encoder.blocks.30.norm2.bias", "image_encoder.blocks.30.mlp.lin1.weight", "image_encoder.blocks.30.mlp.lin1.bias", "image_encoder.blocks.30.mlp.lin2.weight", "image_encoder.blocks.30.mlp.lin2.bias", "image_encoder.blocks.31.norm1.weight", "image_encoder.blocks.31.norm1.bias", "image_encoder.blocks.31.attn.rel_pos_h", "image_encoder.blocks.31.attn.rel_pos_w", "image_encoder.blocks.31.attn.qkv.weight", "image_encoder.blocks.31.attn.qkv.bias", "image_encoder.blocks.31.attn.proj.weight", "image_encoder.blocks.31.attn.proj.bias", "image_encoder.blocks.31.norm2.weight", "image_encoder.blocks.31.norm2.bias", "image_encoder.blocks.31.mlp.lin1.weight", "image_encoder.blocks.31.mlp.lin1.bias", "image_encoder.blocks.31.mlp.lin2.weight", "image_encoder.blocks.31.mlp.lin2.bias".
size mismatch for image_encoder.pos_embed: copying a param with shape torch.Size([1, 64, 64, 768]) from checkpoint, the shape in current model is torch.Size([1, 64, 64, 1280]).
size mismatch for image_encoder.patch_embed.proj.weight: copying a param with shape torch.Size([768, 3, 16, 16]) from checkpoint, the shape in current model is torch.Size([1280, 3, 16, 16]).
size mismatch for image_encoder.patch_embed.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.0.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.0.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.0.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.0.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.0.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.0.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.0.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.0.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.0.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.0.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.0.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.0.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.0.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.0.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.1.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.1.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.1.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.1.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.1.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.1.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.1.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.1.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.1.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.1.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.1.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.1.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.1.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.1.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.2.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.2.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.2.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.2.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.2.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.2.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.2.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.2.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.2.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.2.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.2.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.2.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.3.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.3.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.3.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.3.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.3.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.3.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.3.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.3.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.3.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.3.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.3.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.3.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.4.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.4.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.4.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.4.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.4.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.4.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.4.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.4.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.4.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.4.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.4.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.4.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.5.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.5.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.5.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.5.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.5.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.5.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.5.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.5.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.5.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.5.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.5.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.5.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.6.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.6.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.6.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.6.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.6.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.6.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.6.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.6.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.6.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.6.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.6.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.6.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.6.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.6.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.7.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.7.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.7.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([127, 80]).
size mismatch for image_encoder.blocks.7.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([127, 80]).
size mismatch for image_encoder.blocks.7.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.7.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.7.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.7.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.7.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.7.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.7.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.7.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.7.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.7.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.8.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.8.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.8.attn.rel_pos_h: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.8.attn.rel_pos_w: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.8.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.8.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.8.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.8.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.8.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.8.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.8.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.8.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.8.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.8.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.9.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.9.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.9.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.9.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.9.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.9.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.9.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.9.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.9.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.9.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.9.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.9.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.9.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.9.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.10.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.10.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.10.attn.rel_pos_h: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.10.attn.rel_pos_w: copying a param with shape torch.Size([27, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.10.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.10.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.10.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.10.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.10.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.10.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.10.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.10.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.10.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.10.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.11.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.11.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.11.attn.rel_pos_h: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.11.attn.rel_pos_w: copying a param with shape torch.Size([127, 64]) from checkpoint, the shape in current model is torch.Size([27, 80]).
size mismatch for image_encoder.blocks.11.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([3840, 1280]).
size mismatch for image_encoder.blocks.11.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([3840]).
size mismatch for image_encoder.blocks.11.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
size mismatch for image_encoder.blocks.11.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.11.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.11.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.blocks.11.mlp.lin1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
size mismatch for image_encoder.blocks.11.mlp.lin1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for image_encoder.blocks.11.mlp.lin2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
size mismatch for image_encoder.blocks.11.mlp.lin2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1280]).
size mismatch for image_encoder.neck.0.weight: copying a param with shape torch.Size([256, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1280, 1, 1]).

Can anyone help debug this? I have tried all three of the given checkpoints.
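
The traceback above goes through build_sam_vit_h, i.e. build_sam() constructs the ViT-H architecture, while the checkpoint path points to the ViT-B weights. A minimal sketch of the registry pattern from the Getting Started section, matching the model type to the checkpoint from the error message:

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Select the builder that matches the checkpoint file instead of the default ViT-H builder.
sam = sam_model_registry["vit_b"](checkpoint="segment_anything/models/sam_vit_b_01ec64.pth")
mask_generator = SamAutomaticMaskGenerator(sam)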

Size mismatch errors in running automatic_mask_generation_example.ipynb

When I ran the cells in automatic_mask_generation_example.ipynb, I got size mismatch errors for image_encoder.blocks.
Here is the code:

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# model_type, sam_checkpoint, and device are defined in an earlier notebook cell
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

Here are the errors:

RuntimeError                              Traceback (most recent call last)
Cell In[24], line 5
      2 sys.path.append("..")
      3 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
----> 5 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
      6 sam.to(device=device)
      8 mask_generator = SamAutomaticMaskGenerator(sam)

File ~/segment-anything/notebooks/../segment_anything/build_sam.py:15, in build_sam_vit_h(checkpoint)
     14 def build_sam_vit_h(checkpoint=None):
---> 15     return _build_sam(
     16         encoder_embed_dim=1280,
     17         encoder_depth=32,
     18         encoder_num_heads=16,
     19         encoder_global_attn_indexes=[7, 15, 23, 31],
     20         checkpoint=checkpoint,
     21     )

File ~/segment-anything/notebooks/../segment_anything/build_sam.py:106, in _build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint)
    104     with open(checkpoint, "rb") as f:
    105         state_dict = torch.load(f)
--> 106     sam.load_state_dict(state_dict)
    107 return sam
...
	size mismatch for image_encoder.blocks.23.mlp.lin1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([5120, 1280]).
	size mismatch for image_encoder.blocks.23.mlp.lin1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5120]).
	size mismatch for image_encoder.blocks.23.mlp.lin2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
	size mismatch for image_encoder.blocks.23.mlp.lin2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for image_encoder.neck.0.weight: copying a param with shape torch.Size([256, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1280, 1, 1]).

My environment:
OS: Ubuntu 20.04.6 LTS x86_64
Kernel: 5.15.0-58-generic
CPU: 12th Gen Intel i5-12400
GPU: NVIDIA 3070 laptop
Driver Version: 515.86.01
CUDA Version: 11.7
Conda env:
Python Version: 3.8.16
PyTorch Version: 1.8.0
Torchaudio Version: 0.8.0
CUDA Toolkit Version: 10.2

How can I fix these errors?

Finetuning

Are there any plans to release scripts for fine-tuning the model?

Also, you did such great work! Thank you very much!

Release

Hi all,

congrats to all of you on the release announced on Twitter!

Would you mind creating a release here on the GitHub repository, too? That would be great for reproducible research, as it enables people to run and compare specific versions of the code.

Big thanks!

Best,
Robert

Mismatch in ckpt for vit_b and vit_l. vit_h works fine.

The checkpoints for vit_b (sam_vit_b_01ec64.pth) and vit_l (sam_vit_l_0b3195.pth) do not work with segment_anything.build_sam().
The following error message is received:

Missing key(s) in state_dict: "image_encoder.blocks.12.norm1.weight", "image_encoder.blocks.12.norm1.bias", "image_encoder.blocks.12.attn.rel_pos_h", "image_encoder.blocks.12.attn.rel_pos_w", "image_encoder.blocks.12.attn.qkv.weight", "image_encoder.blocks.12.attn.qkv.bias", "image_encoder.blocks.12.attn.proj.weight", "image_encoder.blocks.12.attn.proj.bias", "image_encoder.blocks.12.norm2.weight", "image_encoder.blocks.12.norm2.bias", "image_encoder.blocks.12.mlp.lin1.weight", "image_encoder.blocks.12.mlp.lin1.bias", "image_encoder.blocks.12.mlp.lin2.weight", "image_encoder.blocks.12.mlp.lin2.bias", ..., "image_encoder.blocks.31.mlp.lin2.weight", "image_encoder.blocks.31.mlp.lin2.bias".

(The full list repeats the same pattern for every parameter of image_encoder.blocks.12 through image_encoder.blocks.31: norm1, attn.rel_pos_h, attn.rel_pos_w, attn.qkv, attn.proj, norm2, mlp.lin1, and mlp.lin2 weights and biases.)

Indices not moved to correct device in batched NMS

I'm trying to increase the points per side, but when I go over 80 I get the following:

Traceback (most recent call last):
  File "/home/louis/dev/cv/segment-anything/scripts/amg.py", line 267, in <module>
    main(args)
  File "/home/louis/dev/cv/segment-anything/scripts/amg.py", line 244, in main
    masks = generator.generate(image)
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py", line 163, in generate
    mask_data = self._generate_masks(image)
  File "/home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py", line 206, in _generate_masks
    crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
  File "/home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py", line 251, in _process_crop
    keep_by_nms = batched_nms(
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py", line 73, in batched_nms
    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torch/jit/_trace.py", line 1220, in wrapper
    return fn(*args, **kwargs)
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py", line 110, in _batched_nms_vanilla
    keep_mask[curr_indices[curr_keep_indices]] = True
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

I expect this means the curr_keep_indices need to be moved to the same device as the indexed keep_mask tensor.

If we set a breakpoint at line 251 in _process_crop of automatic_mask_generator.py

        # Remove duplicates within this crop.
        breakpoint()
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros(len(data["boxes"])),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)
  • Note that torch.zeros(len(data["boxes"])),  # categories does not specify a device.

and enter the debugger

> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(253)_process_crop()
-> data["boxes"].float(),
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(254)_process_crop()
-> data["iou_preds"],
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(255)_process_crop()
-> torch.zeros(len(data["boxes"])),  # categories
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(256)_process_crop()
-> iou_threshold=self.box_nms_thresh,
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(252)_process_crop()
-> keep_by_nms = batched_nms(
(Pdb) s
--Call--
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(44)batched_nms()
-> def batched_nms(
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(68)batched_nms()
-> if not torch.jit.is_scripting() and not torch.jit.is_tracing():
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(69)batched_nms()
-> _log_api_usage_once(batched_nms)
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(72)batched_nms()
-> if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(73)batched_nms()
-> return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)

and print the devices of the variables being passed through batched_nms() to _batched_nms_vanilla():

> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(73)batched_nms()
-> return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
(Pdb) p boxes
tensor([[   0.,    0.,  968.,   45.],
        [   0.,    0.,  968.,   45.],
        [   0.,    0.,  968.,   45.],
        ...,
        [   3.,    0.,  995., 1332.],
        [   3.,    0.,  995., 1332.],
        [   0.,    4.,  995., 1332.]], device='cuda:0')
(Pdb) p scores
tensor([0.9919, 0.9900, 0.9894,  ..., 1.0067, 1.0069, 0.9800], device='cuda:0')
(Pdb) p idxs
tensor([0., 0., 0.,  ..., 0., 0., 0.])
(Pdb) p iou_threshold
0.7

We can confirm that the source of the error is the zeros tensor created without specifying a device: the boxes and scores are on CUDA, while the idxs tensor defaults to the CPU.

To remedy it, use either

torch.zeros_like(data["boxes"][:,0])

or

torch.zeros(len(data["boxes"]), device=data["boxes"].device)

but I think the first way is more elegant.

I expect the same edit should also be applied on line 360.

            torch.zeros(len(boxes)),  # categories

to

            torch.zeros_like(boxes[:,0]),  # categories
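For reference, here is a minimal sketch of how the corrected call at line 251 could look, using the same variables as in the snippet above (this mirrors the suggested fix rather than an official patch):

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories, created on the same device as the boxes
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)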

Satellite Imagery(tiny objects) Generalization

Thank you for the incredible work & congratulations!

SAM does not seem to generalize as well on satellite imagery (tiny objects). The attached satellite image shows the result of the "segment everything" option.
However, SAM works better on the same image if I manually prompt the model towards an object of interest (such as the tiny aircraft in the corner on the left-hand side) that it may have missed with the "segment everything" option.

A couple more examples with the "segment everything" option are attached as images.

Any insights would be most helpful!

Dataset

Hi 👋

Thanks for the amazing work. I was wondering if there is a way to download the dataset.

Thanks.

Best regards,

Fra

Misspelling in README.md

Getting Started

or generate masks for an entire image:
from segment_anything import build_sam, SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="</path/to/model.pth>"))
masks = mask_generator_generate(<your_image>)

should be corrected to
masks = mask_generator.generate(<your_image>)

error occurs in python scripts

1st: masks, _, _ = predictor.predict('bed')

AssertionError: point_labels must be supplied if point_coords is supplied.

2nd: from segment_anything import build_sam, SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint=pth_path))
masks = mask_generator_generate(img)
NameError: name 'mask_generator_generate' is not defined

How should I deal with these errors?
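A sketch of usage that avoids both errors (the image path and prompt values are hypothetical, for illustration only): predict takes coordinate and label arrays rather than a text string, and generate is a method on the mask generator object.

import cv2
import numpy as np
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical image path

# 1st: predict takes point/box/mask prompts, not a text string such as 'bed';
# point_labels must accompany point_coords (1 = foreground, 0 = background).
predictor = SamPredictor(sam)
predictor.set_image(img)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),  # hypothetical pixel location
    point_labels=np.array([1]),
)

# 2nd: generate is a method on the generator object, so it is called with a dot.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(img)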

Text prompt?

Amazing work! However, I didn't find text prompt support; is there any plan to release it?

the result is not as good as the result in the demo

What are the parameters used in the demo? Why can't I get the good result shown in the demo even though I used the same image? Does using different models make a big difference to the results?

How to effectively download the SA-1B dataset

Thank you very much for such an outstanding contribution. I tried to download the dataset from the official website, but the links in the provided download text could not be fetched with wget, and downloading the files one by one is very time-consuming. How can I download the dataset efficiently?
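A minimal sketch of one way to batch the downloads, assuming the links file is a tab-separated list with a file name and a URL per line (the exact format of the official links file should be checked; the file name here is hypothetical):

import csv
import urllib.request
from pathlib import Path

links_file = "sa1b_links.txt"   # hypothetical name for the downloaded links file
out_dir = Path("sa1b")
out_dir.mkdir(exist_ok=True)

with open(links_file, newline="") as f:
    reader = csv.reader(f, delimiter="\t")
    next(reader)  # skip the header row, if the file has one
    for file_name, url in reader:
        target = out_dir / file_name
        if target.exists():
            continue  # allows resuming an interrupted download
        print(f"Downloading {file_name} ...")
        urllib.request.urlretrieve(url, target)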

training on custom dataset

Thank you for your great work and contribution. I would like to train this model on my own custom dataset. Could you please explain how it is done? I could not find any training scripts for that. Your comments are really appreciated.

use mask as prompt

Thanks for sharing this great work!
I wonder if you could provide demo code that uses masks (coordinates of polygon vertices) as a prompt.

Best,
Dongyi
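For what it's worth, SamPredictor.predict accepts a mask_input argument, which expects a low-resolution mask (logits) such as the one returned by a previous call; a polygon would first need to be rasterized and resized to that grid. A rough sketch of the documented iterative-refinement pattern, with the polygon handling left as an assumption:

import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical image
predictor.set_image(image)

point = np.array([[500, 375]])   # hypothetical point prompt
label = np.array([1])

# First pass: point prompt only.
masks, scores, logits = predictor.predict(
    point_coords=point, point_labels=label, multimask_output=True
)

# Second pass: feed the best low-resolution mask back in as a mask prompt.
best = np.argmax(scores)
masks, scores, logits = predictor.predict(
    point_coords=point,
    point_labels=label,
    mask_input=logits[best][None, :, :],  # 1x256x256 mask logits from the previous call
    multimask_output=False,
)

# For a polygon prompt you would (as far as I can tell) rasterize the polygon to a
# binary mask, downscale it to the 256x256 low-resolution grid, and pass it the same
# way -- ideally converted to logits rather than a hard 0/1 mask.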

AttributeError: 'str' object has no attribute 'shape'

Successfully installed and already downloaded the ViT-H SAM model, but an error occurred when running the code. Can someone help?

code:

from segment_anything import build_sam, SamPredictor
predictor = SamPredictor(build_sam(checkpoint="sam_vit_h_4b8939.pth"))
predictor.set_image("123.jpg")
masks, _, _ = predictor.predict("dog")

# I put the .jpg and .pth under the segment-anything folder #

error:

(py-1) PS D:\Code\segment-anything> d:; cd 'd:\Code\segment-anything'; & 'C:\Users\toutouge.conda\envs\py-1\python.exe' 'c:\Users\toutouge.vscode\extensions\ms-python.python-2023.6.0\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher' '52773' '--' 'd:\Code\segment-anything\test1.py'
Traceback (most recent call last):
File "d:\Code\segment-anything\test1.py", line 3, in
predictor.set_image("123.jpg")
File "d:\Code\segment-anything\segment_anything\predictor.py", line 56, in set_image
input_image = self.transform.apply_image(image)
File "d:\Code\segment-anything\segment_anything\utils\transforms.py", line 30, in apply_image
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
AttributeError: 'str' object has no attribute 'shape'
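The immediate problem is that set_image expects an image array, not a file path (the later "dog" prompt would also fail, for the reason described in the earlier issue). A sketch of the corrected script, assuming opencv-python is installed and using a hypothetical point prompt in place of the string:

import cv2
import numpy as np
from segment_anything import build_sam, SamPredictor

predictor = SamPredictor(build_sam(checkpoint="sam_vit_h_4b8939.pth"))

# set_image expects an RGB array, not a file path.
image = cv2.cvtColor(cv2.imread("123.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# predict needs coordinate prompts, not free text such as "dog".
masks, _, _ = predictor.predict(
    point_coords=np.array([[100, 100]]),  # hypothetical pixel location
    point_labels=np.array([1]),
)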

What is post-processing masks?

The README mentions needing optional dependencies for things like "post-processing". What is mask post-processing? It's mentioned nowhere else in the README or examples, so it's hard to tell whether it's something I would need.
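As far as I can tell, this refers to the optional cleanup step in the automatic mask generator: setting min_mask_region_area to a value greater than zero removes small disconnected regions and holes from the output masks, and that step is what needs opencv. A small sketch (the threshold value is arbitrary):

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(
    sam,
    min_mask_region_area=100,  # drop disconnected regions/holes smaller than 100 px; requires opencv
)
# mask_generator.generate(image) then returns the cleaned-up masks.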

How can I express it??

result :
[{'segmentation': array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., False, False, False],
[ True, True, True, ..., False, False, False],
[False, False, False, ..., False, False, False]]), 'area': 76377, 'bbox': [0, 223, 422, 309], 'predicted_iou': 1.0341997146606445, 'point_coords': [[362.5, 258.65625]], 'stability_score': 0.9911314249038696, 'crop_box': [0, 0, 800, 534]}, {'segmentation': array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
..., (output truncated)

How can I interpret or visualize this result?
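If the question is how to turn these dictionaries into something viewable: each entry's 'segmentation' field is a boolean HxW array, so the masks can be overlaid on the image, for example with matplotlib. A minimal sketch:

import matplotlib.pyplot as plt
import numpy as np

def show_masks(image, masks):
    """Overlay each mask's boolean 'segmentation' array on the image in a random color."""
    plt.imshow(image)
    overlay = np.zeros((*image.shape[:2], 4))  # RGBA layer on top of the image
    for m in sorted(masks, key=lambda m: m["area"], reverse=True):
        color = np.concatenate([np.random.random(3), [0.5]])
        overlay[m["segmentation"]] = color
    plt.imshow(overlay)
    plt.axis("off")
    plt.show()

# image: the HxWx3 array passed to the generator; masks: the list it returned.
# show_masks(image, masks)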

How long does it take to generate masks?

For a 1702x1276 RGB image, it takes 6+ seconds to generate the masks using the default parameters.

  • OS: Win 11
  • GPU: NVIDIA 3070 8GB

Is there a way to speed up?
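A few knobs that usually help (a sketch; the exact speed/quality trade-off depends on the image and hardware): use a smaller backbone and reduce how many point prompts the automatic generator runs.

import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# The ViT-B backbone is much faster than ViT-H, at some cost in mask quality.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")

mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=16,  # default is 32; fewer prompts means fewer mask-decoder passes
    # keeping crop_n_layers at its default of 0 also avoids extra crop passes
)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical image
masks = mask_generator.generate(image)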

Great Project

Really great project. I tried it on a wide range of images and am super impressed.

Training on one class only

Is there any "tiny" or "nano" version that is used only for one class segmentation? I am currently struggling with memory on my CPU. I would like to train your model on one class and to be able to run it on CPU in real time. Is that possible?

typo

Great work!

from segment_anything import build_sam, SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="</path/to/model.pth>"))
masks = mask_generator_generate(<your_image>)

should be:
mask_generator.generate

Text prompt embedding size

Thank you for releasing the model.
The paper mentions that text prompts are encoded using a pretrained ViT-L/14@336px CLIP model. CLIP embeddings from this model are of size 768, while the Segment Anything prompt embedding is of size 256. Are there any extra steps for converting CLIP embeddings before feeding them into the model?

ONNX opset mismatch (opset 17 only supported in pytorch>=2)

The README says the code requires pytorch>=1.7, but onnx_model_example.ipynb expects opset=17. As far as I can tell, opset 17 is only supported in pytorch>=2.

  1. Should the readme be updated? Or
  2. Can I just set opset_version=16 in the notebook?
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,     #  <----  :-(
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )    

Happy to make the changes and make a PR with any desired solution.
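If option 2 is acceptable, the change is just the one argument in the export call quoted above (whether opset 16 covers every operator in SAM's decoder would still need a quick test against the target onnxruntime):

    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=16,     # lowered from 17 for older PyTorch versions
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )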

Release the web demo please

The web demo is a useful tool for doing many things. Releasing the web demo or open-sourcing it would make the project more helpful and let more people take part in it.

ONNX model in example notebook is "prebaked" to specific aspect ratio.

As noted in the title, when using the ONNX model generated by the example notebook OR the script, it appears to be restricted to the aspect ratio given in the orig_im_size dummy inputs.

I am not sure if this is intended behaviour (i.e. an unavoidable limitation of the ONNX system), or a bug.

This can be observed by changing the dummy input value to e.g. [1024, 1024] when creating the ONNX model, OR by using a different example image with a different aspect ratio. When the next section of the ONNX example notebook (onnx_model_example.ipynb) is run, the projected mask will not match up with the underlying image, but will be deformed by a certain amount. This is rather annoying: since the output mask is automatically "resized" to the target image dimensions, it is a pain to apply the inverse transform to get both to match up again, not to mention that part of the mask is permanently clipped off.

This doesn't affect the usual predict function or its inputs, and it only matters for aspect ratio; scaling both dimensions identically does not affect the result.

I'm not very familiar with the ONNX export process, but I'm going to dig around the docs to see if I can figure this one out; I'll leave this here in the meantime so anyone having similar issues knows what might be the culprit.

ONNX example is not really a full ONNX example

As I noticed in the example, the exported ONNX model is used only for mask prediction, while the image embeddings are still generated using PyTorch and Python.
This is a very limited approach if, for example, I want to deploy to a system that uses the ONNX runtime only.
Is there a way to export an ONNX model for calculating the embeddings as well?
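There is no official recipe for this in the repository as far as I can tell, but since the image encoder is an ordinary PyTorch module, a plain torch.onnx.export is one thing to try. A rough sketch, assuming the encoder accepts a preprocessed 1x3x1024x1024 tensor; exporting a ViT-H encoder this way may be slow and is not guaranteed to succeed for every opset:

import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.eval()

# The encoder expects an already-preprocessed image (resized to 1024 on the long side,
# normalized, and padded to 1024x1024) -- that preprocessing would still need to be
# reproduced on the ONNX-runtime side.
dummy_image = torch.randn(1, 3, 1024, 1024)

torch.onnx.export(
    sam.image_encoder,
    dummy_image,
    "sam_image_encoder.onnx",
    export_params=True,
    opset_version=17,
    input_names=["image"],
    output_names=["image_embeddings"],
)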

Web version support for labeling

Thanks for the fantastic research and the code. It is very beneficial as it supports zero-shot segmentation.

Can we use the pre-trained models on custom images?
Researchers could also use this for labeling their images if the web tool supported performing the labeling and exporting a JSON/COCO format file.
Any ideas?
