
tinysam's Introduction

TinySAM

TinySAM: Pushing the Envelope for Efficient Segment Anything Model

Han Shu, Wenshuo Li, Yehui Tang, Yiman Zhang, Yihao Chen, Houqiang Li, Yunhe Wang, Xinghao Chen

arXiv 2023



Updates

  • 2024/01/06: The TinySAM demo is now available on OpenXLab. Thanks for the GPU grant.
  • 2023/12/27: TinySAM models and demo are now available on Hugging Face. Thanks to merveenoyan.
  • 2023/12/27: Pre-trained models and code for Q-TinySAM (the quantized variant) are released.
  • 2023/12/27: Evaluation code for the zero-shot instance segmentation task on COCO is released.
  • 2023/12/22: Pre-trained models and code for TinySAM are released in both PyTorch and MindSpore.

Overview

We propose a framework for obtaining a tiny segment anything model (TinySAM) while maintaining strong zero-shot performance. We first propose a full-stage knowledge distillation method with an online hard prompt sampling strategy to distill a lightweight student model. We also adapt post-training quantization to the promptable segmentation task, further reducing the computational cost. Moreover, we propose a hierarchical segmenting everything strategy that accelerates everything-mode inference with almost no performance degradation. With all these methods, TinySAM achieves an orders-of-magnitude reduction in computation and pushes the envelope for the efficient segment anything task. Extensive experiments on various zero-shot transfer tasks demonstrate that TinySAM significantly outperforms its counterpart methods.
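To make the "online hard prompt sampling" idea more concrete, here is a minimal, framework-free sketch. It assumes (as a reading of the description above, not the released training code) that harder point prompts are drawn from pixels where the teacher's and the student's binary masks disagree, labelled with the teacher's answer. The names `disagreement_pixels` and `sample_hard_point` and the list-of-lists mask format are hypothetical.

```python
import random
from typing import List, Optional, Tuple

Mask = List[List[int]]  # binary mask stored as rows of 0/1 values


def disagreement_pixels(teacher: Mask, student: Mask) -> List[Tuple[int, int]]:
    """Return (x, y) coordinates where the two binary masks differ."""
    if len(teacher) != len(student) or any(len(a) != len(b) for a, b in zip(teacher, student)):
        raise ValueError("masks must have the same shape")
    return [
        (x, y)
        for y, (row_t, row_s) in enumerate(zip(teacher, student))
        for x, (t, s) in enumerate(zip(row_t, row_s))
        if t != s
    ]


def sample_hard_point(teacher: Mask, student: Mask,
                      rng: random.Random) -> Optional[Tuple[Tuple[int, int], int]]:
    """Sample a new point prompt from the teacher/student disagreement region.

    Returns ((x, y), label) with label 1 (foreground) if the teacher marks the
    pixel as object and 0 (background) otherwise, or None if the masks agree.
    """
    candidates = disagreement_pixels(teacher, student)
    if not candidates:
        return None
    x, y = rng.choice(candidates)
    return (x, y), teacher[y][x]


teacher = [[0, 1, 1],
           [0, 1, 1],
           [0, 0, 0]]
student = [[0, 1, 0],
           [0, 1, 0],
           [0, 0, 0]]
point = sample_hard_point(teacher, student, random.Random(0))
print(point)
```

The student missed the right-hand column, so the sampled prompt is a foreground point in that column: exactly the region where an extra prompt is most informative for distillation.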


Figure 1: Overall framework and zero-shot results of TinySAM.


Figure 2: Our hierarchical strategy for everything mode.
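The two-round idea behind the hierarchical everything mode can be sketched as follows: query a sparse point grid first, keep the masks that are judged high-confidence, and then query only the dense-grid points those masks do not already cover. The grid sizes, the 0.9 confidence threshold, and the `segment` callback are hypothetical placeholders rather than the values used in the TinySAM code, and a real implementation would also merge duplicate masks (e.g. with NMS).

```python
from typing import Callable, List, Set, Tuple

Point = Tuple[float, float]
Pixels = Set[Tuple[int, int]]
# Hypothetical model call: one point prompt -> (mask pixels, confidence score).
SegmentFn = Callable[[Point], Tuple[Pixels, float]]


def point_grid(n_per_side: int, width: int, height: int) -> List[Point]:
    """n_per_side x n_per_side points, each at the centre of its grid cell."""
    return [
        ((i + 0.5) * width / n_per_side, (j + 0.5) * height / n_per_side)
        for j in range(n_per_side)
        for i in range(n_per_side)
    ]


def hierarchical_everything(segment: SegmentFn, width: int, height: int,
                            sparse: int = 8, dense: int = 32,
                            conf_thresh: float = 0.9) -> Tuple[List[Pixels], int]:
    """Two-round 'everything' inference; returns (kept masks, number of model calls)."""
    masks: List[Pixels] = []
    covered: Pixels = set()
    calls = 0
    # Round 1: sparse grid; keep only high-confidence masks and remember their area.
    for p in point_grid(sparse, width, height):
        pixels, conf = segment(p)
        calls += 1
        if conf >= conf_thresh:
            masks.append(pixels)
            covered |= pixels
    # Round 2: dense grid, skipping points that fall inside an already-confident mask.
    for p in point_grid(dense, width, height):
        if (int(p[0]), int(p[1])) in covered:
            continue
        pixels, _ = segment(p)
        calls += 1
        masks.append(pixels)
    return masks, calls


# Toy 64x64 "image": the left half is one easy object, the right half is ambiguous.
LEFT_HALF: Pixels = {(x, y) for x in range(32) for y in range(64)}


def toy_segment(p: Point) -> Tuple[Pixels, float]:
    if p[0] < 32:
        return LEFT_HALF, 0.95
    return {(int(p[0]), int(p[1]))}, 0.5


masks, calls = hierarchical_everything(toy_segment, 64, 64)
print(f"{calls} model calls instead of {32 * 32} for the dense grid alone")
```

In the toy case the confident left half removes half of the dense queries, which is where the speed-up of this kind of strategy comes from.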


Figure 3: Visualization results of TinySAM.

Requirements

The code requires python>=3.7. We use torch==1.10.2 and torchvision==0.11.3. To visualize the results, matplotlib>=3.5.1 is also required.

  • python >= 3.7
  • pytorch == 1.10.2
  • torchvision == 0.11.3
  • matplotlib >= 3.5.1

Usage

  1. Download the checkpoints into the weights directory.

  2. Run the demo code for a single point or box prompt.

     python demo.py

  3. Run the demo code for the hierarchical segment everything strategy.

     python demo_hierachical_everything.py

  4. Run the demo code for quantized inference.

     python demo_quant.py

Evaluation

We follow the setting of the original SAM paper and evaluate zero-shot instance segmentation on the COCO and LVIS datasets. The results are listed below.

Model        FLOPs (G)   COCO AP (%)   LVIS AP (%)
SAM-H        2976        46.6/46.5*    44.7
SAM-L        1491        46.2/45.5*    43.5
SAM-B        487         43.4/41.0*    40.8
FastSAM      344         37.9          34.5
MobileSAM    42.0        41.0          37.0
TinySAM      42.0        41.9          38.6
Q-TinySAM    20.3        41.3          37.2

* Results of single output (multimask_output=False).
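As a quick sanity check of the "orders of magnitude" claim, the computational reduction can be read straight off the FLOPs column. This is plain arithmetic on the reported numbers; `reduction_vs` is just an illustrative helper.

```python
# FLOPs (G) copied from the table above.
flops = {
    "SAM-H": 2976, "SAM-L": 1491, "SAM-B": 487,
    "FastSAM": 344, "MobileSAM": 42.0,
    "TinySAM": 42.0, "Q-TinySAM": 20.3,
}


def reduction_vs(baseline: str, model: str) -> float:
    """How many times fewer FLOPs `model` needs than `baseline`."""
    return flops[baseline] / flops[model]


for name in ("TinySAM", "Q-TinySAM"):
    print(f"{name}: {reduction_vs('SAM-H', name):.1f}x fewer FLOPs than SAM-H")
```

This gives roughly 71x for TinySAM and 147x for Q-TinySAM relative to SAM-H, i.e. about two orders of magnitude.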

First, download the detection boxes produced by the ViTDet model (coco_instances_results_vitdet.json) and the ground-truth instance segmentation labels (instances_val2017.json), and put them into eval/json_files. The corresponding json files for the LVIS dataset are lvis_instances_results_vitdet.json and lvis_v1_val.json.
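The detection file is used as a source of box prompts. A minimal sketch of that step, assuming the file follows the standard COCO results format (a list of dicts with `image_id`, `category_id`, `bbox` as `[x, y, width, height]`, and `score`) and that the model expects XYXY boxes as the original SAM predictor does. The helper name and the score threshold are illustrative; whether the evaluation script filters by score is not stated here.

```python
import json
import os
import tempfile
from collections import defaultdict
from typing import Dict, List


def load_box_prompts(results_json: str, score_thresh: float = 0.0) -> Dict[int, List[List[float]]]:
    """Group COCO-style detections by image_id and convert each box to XYXY."""
    with open(results_json) as f:
        detections = json.load(f)
    prompts: Dict[int, List[List[float]]] = defaultdict(list)
    for det in detections:
        if det["score"] < score_thresh:
            continue
        x, y, w, h = det["bbox"]  # COCO stores [x, y, width, height]
        prompts[det["image_id"]].append([x, y, x + w, y + h])
    return dict(prompts)


# Usage with a tiny hand-written results file.
sample = [
    {"image_id": 1, "category_id": 1, "bbox": [10.0, 20.0, 30.0, 40.0], "score": 0.9},
    {"image_id": 1, "category_id": 3, "bbox": [0.0, 0.0, 5.0, 5.0], "score": 0.2},
]
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump(sample, tmp)
prompts = load_box_prompts(tmp.name, score_thresh=0.5)
os.remove(tmp.name)
print(prompts)  # {1: [[10.0, 20.0, 40.0, 60.0]]}
```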

Run the following command to evaluate zero-shot instance segmentation on the COCO dataset.

cd eval; sh eval_coco.sh

The results should be:

Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.419
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.683
Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.436
Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.260
Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.456
Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.583
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.325
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.511
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.532
Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.390
Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.577
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.671
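The first line of this summary (AP at IoU=0.50:0.95, all areas) is the number reported as COCO AP in the table: 0.419 corresponds to 41.9%. When comparing runs, it can help to parse the summary that pycocotools' COCOeval prints instead of reading it by eye; the regex and the helper name below are an illustrative sketch.

```python
import re
from typing import Dict, Tuple

LINE = re.compile(
    r"Average (Precision|Recall)\s+\((AP|AR)\) @\[ IoU=(\S+)\s*\| area=\s*(\w+)"
    r" \| maxDets=\s*(\d+) \] = ([\d.]+)"
)


def parse_coco_summary(text: str) -> Dict[Tuple[str, str, str, int], float]:
    """Map (metric, iou, area, maxDets) -> value for every summary line found."""
    out = {}
    for m in LINE.finditer(text):
        _, metric, iou, area, max_dets, value = m.groups()
        out[(metric, iou, area, int(max_dets))] = float(value)
    return out


summary = """\
Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.419
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.683
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.325
"""
stats = parse_coco_summary(summary)
ap = stats[("AP", "0.50:0.95", "all", 100)]
print(f"COCO AP = {ap * 100:.1f}%")
```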

Acknowledgements

We thank the following projects: SAM, MobileSAM, TinyViT.

Citation

@article{tinysam,
  title={TinySAM: Pushing the Envelope for Efficient Segment Anything Model},
  author={Shu, Han and Li, Wenshuo and Tang, Yehui and Zhang, Yiman and Chen, Yihao and Li, Houqiang and Wang, Yunhe and Chen, Xinghao},
  journal={arXiv preprint arXiv:2312.13789},
  year={2023}
}

License

This project is licensed under Apache License 2.0. Redistribution and use should follow this license.


tinysam's Issues

LVIS related json file

Your paper is great. We noticed that you provide the coco_instances_results_vitdet.json file for COCO. Could you also provide the LVIS json file generated by ViTDet for evaluation? Thank you.

Questions about latency


I have a question, does this time represent the runtime of a single image in SegEvery mode?

MLP block weights for mask_tokens 0 in mask_decoder are almost all zeros

I just noticed that the MLP block weights for mask_tokens 0 in mask_decoder are almost all zeros. Is this intended?

import torch
ss = torch.load("tinysam.pth", map_location=torch.device('cpu'))

ww = ss["mask_decoder.output_hypernetworks_mlps.0.layers.0.weight"]
print(ww[torch.where(ww.abs() > 1e-6)])
# tensor([-0.1436, -0.0390,  0.3668,  0.2065,  0.1118, -0.0201,  0.1688])

ww = ss["mask_decoder.output_hypernetworks_mlps.0.layers.1.weight"]
print(ww[torch.where(ww.abs() > 1e-6)])
# tensor([0.1090, 0.0203, 0.8415, 0.0125, 0.2405, 0.1774])

ww = ss["mask_decoder.output_hypernetworks_mlps.0.layers.2.weight"]
print(ww[torch.where(ww.abs() > 1e-6)])
# tensor([])
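For context: in the original SAM MaskDecoder, output token 0 is used only for single-mask output (`multimask_output=False`), while tokens 1 to 3 produce the three multimask outputs. If, as the MaskDecoder.forward() error in the fine-tuning issue below suggests, TinySAM's decoder does not take a `multimask_output` flag, token 0 may simply be unused, which would be consistent with near-zero weights; that is an inference, not something confirmed by the authors. The sketch mirrors original SAM's selection logic on placeholder values; the helper names are hypothetical.

```python
from typing import Sequence, Tuple


def select_mask_outputs(masks: Sequence, iou_preds: Sequence[float],
                        multimask_output: bool) -> Tuple[list, list]:
    """Mimic original SAM's choice among its 4 mask tokens.

    Index 0 is the single-mask token; indices 1..3 are the multimask tokens.
    """
    if len(masks) != len(iou_preds):
        raise ValueError("one IoU prediction per mask is expected")
    mask_slice = slice(1, None) if multimask_output else slice(0, 1)
    return list(masks[mask_slice]), list(iou_preds[mask_slice])


def best_mask(masks: Sequence, iou_preds: Sequence[float]):
    """Pick the candidate with the highest predicted IoU."""
    i = max(range(len(iou_preds)), key=iou_preds.__getitem__)
    return masks[i], iou_preds[i]


# Four placeholder "masks" stand in for the decoder's 4 output tokens.
tokens = ["m0", "m1", "m2", "m3"]
ious = [0.10, 0.80, 0.95, 0.70]
multi, multi_iou = select_mask_outputs(tokens, ious, multimask_output=True)
single, _ = select_mask_outputs(tokens, ious, multimask_output=False)
print(multi, single)
print(best_mask(multi, multi_iou))
```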

Finetune TinySAM on custom dataset

Hello, I'm trying to fine-tune the mask decoder of TinySAM on a custom dataset while freezing the weights of the image_encoder and prompt_encoder. I'm having an issue in my training loop: sam.forward() requires a "multimask_output" argument, but MaskDecoder.forward() doesn't accept a "multimask_output" argument.

I'm not an ML engineer, so I don't know much about the underlying code. If anyone with more knowledge than me has some insight into how I can resolve this issue, I would appreciate it, thanks!

Here is how I'm freezing the image encoder and prompt encoder to keep their original weights:

for name, param in sam_model.named_parameters():
  if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

I am also providing bounding box prompts as the input. Here is my custom class for Dataset creation:

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize, to_tensor

class SAMDataset(Dataset):
    """
    Dataset class for the SAM model, serving images with their associated bounding boxes and masks.
    """
    def __init__(self, dataset, bbox_mapping, sam_model, device='cuda'):
        self.dataset = dataset
        self.bbox_mapping = bbox_mapping
        self.sam_model = sam_model
        self.device = device
        self.target_size = (1024, 1024)  # Adjusted to the expected input size of the model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Assuming dataset[idx] returns a dict with 'image' and 'label' keys
        pil_image = self.dataset[idx]['image']
        pil_mask = self.dataset[idx]['label']

        image_tensor = to_tensor(np.array(pil_image)).to(self.device)
        # Binarize the (single-channel) mask explicitly: to_tensor divides uint8 arrays by 255,
        # which would turn 0/1 labels into 0/0.004
        mask = (np.array(pil_mask) > 0).astype(np.float32)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).to(self.device)  # 1 x H x W

        # Resize image and mask to the target size; nearest-neighbour keeps the mask binary
        image_tensor = resize(image_tensor, self.target_size)
        mask_tensor = resize(mask_tensor, self.target_size, interpolation=InterpolationMode.NEAREST)

        # Fetch bounding boxes directly without padding
        bboxes = self.bbox_mapping.get(idx + 1, [])  # Adjust index if necessary
        bboxes_tensor = torch.tensor(bboxes, dtype=torch.float, device=self.device)

        return {
            'image': image_tensor,
            'bboxes': bboxes_tensor,
            'mask': mask_tensor
        }
        
### Create a DataLoader instance for the training dataset
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset,shuffle=True, drop_last=False)

image torch.Size([1, 3, 1024, 1024])
bboxes torch.Size([1, 1, 4])
mask torch.Size([1, 1, 1024, 1024])
### Training Loop

from statistics import mean

import torch
from tqdm import tqdm

num_epochs = 1
device = "cuda"
sam_model.to(device)
sam_model.train()

# seg_loss and optimizer are assumed to be defined earlier in the notebook
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Prepare batched_input according to TinySAM's expected input format.
        # Note: the original SAM's Sam.forward() reads box prompts from a 'boxes' key and
        # also needs 'original_size'; check which keys TinySAM's Sam class actually reads.
        batched_input = [{
            'image': batch['image'].squeeze(0).to(device),
            'bboxes': batch['bboxes'].squeeze(0).to(device),
        }]
        # forward pass
        outputs_list = sam_model(batched_input, multimask_output=True)

        # Assuming outputs_list is a list of dictionaries, aggregate the masks for loss computation.
        # Adapt this to match the structure of your outputs.
        predicted_masks = torch.stack([output['pred_mask'] for output in outputs_list]).squeeze(0)
        ground_truth_masks = batch["mask"].float().squeeze(1).to(device)

        loss = seg_loss(predicted_masks, ground_truth_masks)

        # backward pass (compute gradients of parameters)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
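`seg_loss` is not defined in the snippet above. A common choice for mask fine-tuning is the soft Dice loss, often combined with binary cross-entropy. The framework-free sketch below shows only the Dice term on flattened probabilities, to make the formula concrete; in PyTorch the same expression would be computed on tensors. The function name and the smoothing constant `eps` are illustrative assumptions.

```python
from statistics import mean
from typing import Sequence


def dice_loss(probs: Sequence[float], target: Sequence[float], eps: float = 1.0) -> float:
    """Soft Dice loss: 1 - (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps).

    `probs` are predicted foreground probabilities (e.g. sigmoid of the mask logits),
    `target` is the binary ground-truth mask; both are flattened to 1-D.
    """
    if len(probs) != len(target):
        raise ValueError("prediction and target must have the same number of pixels")
    intersection = sum(p * g for p, g in zip(probs, target))
    denom = sum(probs) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)


perfect = dice_loss([1, 0, 1, 0], [1, 0, 1, 0])   # 1 - (4 + 1) / (4 + 1)
disjoint = dice_loss([1, 1, 0, 0], [0, 0, 1, 1])  # 1 - (0 + 1) / (4 + 1)
print(perfect, disjoint, mean([perfect, disjoint]))
```

The smoothing term keeps the loss defined when both masks are empty; note that the loss expects probabilities, so raw decoder logits should go through a sigmoid first.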


Error when I don't provide multimask_output:

TypeError                                 Traceback (most recent call last)
<ipython-input-108-f41ebba752d9> in <cell line: 12>()
     21 
     22         # forward pass
---> 23         outputs_list = sam_model(batched_input)
     24 
     25         # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

2 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

TypeError: Sam.forward() missing 1 required positional argument: 'multimask_output'

Error when I do provide the multimask_output argument:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-9d874c2eda3d> in <cell line: 12>()
     19     }]
     20         # forward pass
---> 21         outputs_list = sam_model(batched_input, multimask_output = True)
     22 
     23         # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

TypeError: MaskDecoder.forward() got an unexpected keyword argument 'multimask_output'
 

Have weights available

Hello 👋 Thanks a lot for open-sourcing the code!
I wanted to try TinySAM, but the demo and README refer to a weights folder that I couldn't find.

In any case, would you like to have the weights available on Hugging Face Hub and load them from there?

Which part is knowledge distillation

Hello,

Your paper is very nice, thanks for sharing. However, I can't find which part of the code implements the knowledge distillation.

TinySAM multimask_output parameter

Hello, I'm trying to use TinySAM with the output of RTMDet as bounding box prompts.

I ran into a problem: the 'multimask_output' parameter of mask_predictor's 'predict' method is not recognized. Is this parameter named differently, or does it not need to be specified? I believe it is set to True by default in the foundation model.

TypeError                              Traceback (most recent call last)
<ipython-input-28-1a43b775f4cb> in <cell line: 50>()
     48 logging.basicConfig(level=logging.INFO)
     49 image_path = "/content/drive/MyDrive/YOLOv5_Images/4lqw42gnjq1c1.webp"  # Replace with your image path
---> 50 segment_and_display_objects(image_path, rtmdet_model, mask_predictor, threshold=0.5, mask_threshold=0.3)

<ipython-input-28-1a43b775f4cb> in segment_and_display_objects(image_path, rtmdet_model, mask_predictor, threshold, mask_threshold)
     24         if score > threshold:
     25             mask_predictor.set_image(original_img_bgr)
---> 26             masks, _, _ = mask_predictor.predict(box=bbox, multimask_output=True)
     27 
     28             if len(masks) >= 3:

TypeError: SamPredictor.predict() got an unexpected keyword argument 'multimask_output'
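Judging from this TypeError, TinySAM's `SamPredictor.predict()` has no `multimask_output` parameter, so the simplest fix is to drop the argument. If the same calling code has to work with both the original SAM predictor and TinySAM, one option is to pass only the keyword arguments the target function actually accepts, using `inspect.signature`. The helper below is a hypothetical sketch, and the two stand-in `predict_*` functions only imitate the signatures implied by the error.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call fn, dropping keyword arguments that its signature does not accept."""
    params = inspect.signature(fn).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return fn(*args, **kwargs)  # fn takes **kwargs, so pass everything through
    accepted = {p.name for p in params}
    dropped = sorted(set(kwargs) - accepted)
    if dropped:
        print(f"ignoring unsupported arguments: {dropped}")
    return fn(*args, **{k: v for k, v in kwargs.items() if k in accepted})


def predict_sam(box=None, multimask_output=True):  # stand-in for the original SAM predictor
    return ("sam", multimask_output)


def predict_tinysam(box=None):  # stand-in for a predictor without multimask_output
    return ("tinysam", None)


print(call_with_supported_kwargs(predict_sam, box=[0, 0, 10, 10], multimask_output=True))
print(call_with_supported_kwargs(predict_tinysam, box=[0, 0, 10, 10], multimask_output=True))
```

Silently dropping arguments can hide mistakes, which is why the helper prints what it ignored; in code that targets only TinySAM, simply omitting the argument is clearer.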

Loading Q-TinySAM

Hello!
I came across two issues loading Q-TinySAM.
The first is that running demo_quant.py looks for the regular checkpoint, because the model type is vit_t. Is this the proper type for loading quantized weights? Because of this, the quantization layer is not looked up during model loading, and loading fails.
Demo error:

Traceback (most recent call last):
  File "/content/TinySAM/./demo_quant.py", line 7, in <module>
    from demo import show_mask, show_points, show_box
  File "/content/TinySAM/demo.py", line 31, in <module>
    sam = sam_model_registry[model_type](checkpoint="./weights/tinysam.pth")
  File "/content/TinySAM/tinysam/build_sam.py", line 90, in build_sam_vit_t
    with open(checkpoint, "rb") as f:
FileNotFoundError: [Errno 2] No such file or directory: './weights/tinysam.pth'

When I try to run inference myself without the demo, I get the following error:

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-7-04f0ec6a92d3> in <cell line: 19>()
     17 
     18 cpt_path = "./tinysam/tinysam_w8a8.pth"
---> 19 quant_sam = torch.load(cpt_path)
     20 
     21 device = "cuda" if torch.cuda.is_available() else "cpu"

2 frames
/usr/local/lib/python3.10/dist-packages/torch/serialization.py in find_class(self, mod_name, name)
   1413                     pass
   1414             mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1415             return super().find_class(mod_name, name)
   1416 
   1417     # Load the data (which may in turn use `persistent_load` to load tensors)

ModuleNotFoundError: No module named 'quantization_layer'
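This error follows from how pickle works, and `torch.load` uses pickle for checkpoints that store whole module objects (the `find_class` frame in the traceback is the unpickler): classes are recorded by module path, not by code, so the module that defined them must be importable at load time. The self-contained demo below reproduces the same error with a throwaway package (`demo_quant_pkg` is a hypothetical stand-in for `quantization_layer`) and shows the fix. For Q-TinySAM this presumably means running from the directory that contains the `quantization_layer` package, or adding that directory to `sys.path`, before calling `torch.load`; that is an inference from the traceback, not a documented instruction.

```python
import importlib
import os
import pickle
import shutil
import sys
import tempfile

# A throwaway package written to disk stands in for quantization_layer/.
root = tempfile.mkdtemp()
pkg_dir = os.path.join(root, "demo_quant_pkg")
os.makedirs(pkg_dir)
with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
    f.write("class QuantLayer:\n    def __init__(self, bits):\n        self.bits = bits\n")

sys.path.insert(0, root)
importlib.invalidate_caches()
QuantLayer = importlib.import_module("demo_quant_pkg").QuantLayer
# pickle records the class as "demo_quant_pkg.QuantLayer", not its code.
blob = pickle.dumps(QuantLayer(8))

# Simulate loading from another working directory: the package is not importable.
sys.path.remove(root)
del sys.modules["demo_quant_pkg"]
try:
    pickle.loads(blob)
    load_failed = False
except ModuleNotFoundError as err:
    load_failed = True
    print("without the package on sys.path:", err)

# Fix: put the directory that *contains* the package on sys.path, then load again.
sys.path.insert(0, root)
restored = pickle.loads(blob)
print("restored bits:", restored.bits)

# Clean up the temporary package.
sys.path.remove(root)
shutil.rmtree(root)
```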

Benchmarking Batch Inference

Hello 🙌🏼 I'm running some benchmarks on TinySAM, trying to benchmark batch inference. However, I hit a wall during batch inference. All inputs are torch tensors with the shapes expected by docstrings in inference code:

point_prompt.shape  # torch.Size([4, 1, 2]), B x N x 2
input_label.shape  # torch.Size([4, 1]), B x N
batched_image.shape  # torch.Size([4, 3, 1024, 1024]), B x C x H x W
predictor.set_torch_image(batched_image, original_image_size=batched_image[0, 0, :, :].shape)  # goes well

# this fails
with torch.no_grad():
    _, _, _ = predictor.predict_torch(
        point_coords=point_prompt,
        point_labels=input_label)

I don't have much time to debug this, as I have already tried a couple of things. I feel like I'm missing a step; can you let me know if so? I can post the full traceback if you want, but I suspect I'm missing a step and that is why it errors out.
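One likely cause, reasoning from the original SAM code (`segment_anything`), on which TinySAM's predictor appears to be based (that TinySAM's decoder is identical is an assumption): the predictor is built around one image per call. Inside the mask decoder, the image embedding is repeated once per prompt with `repeat_interleave`, while the dense prompt embedding has one entry per prompt. With 4 images and 4 prompts, the two tensors end up with batch sizes 16 and 4, which cannot be added. The arithmetic is sketched below with hypothetical helper names; a practical workaround is to loop over images, calling `set_torch_image` with a single image and `predict_torch` with that image's prompts.

```python
from typing import Tuple


def decoder_batch_sizes(num_images: int, num_prompts: int) -> Tuple[int, int]:
    """Batch sizes that meet in `src + dense_prompt_embeddings` (original SAM decoder).

    src: image embeddings repeated once per prompt (repeat_interleave along dim 0).
    dense: one dense prompt embedding per prompt.
    """
    src_batch = num_images * num_prompts
    dense_batch = num_prompts
    return src_batch, dense_batch


def broadcast_ok(a: int, b: int) -> bool:
    """Broadcasting along one dimension: sizes must match, or one of them must be 1."""
    return a == b or a == 1 or b == 1


for images, prompts in [(1, 4), (4, 4)]:
    src, dense = decoder_batch_sizes(images, prompts)
    print(f"{images} image(s), {prompts} prompts -> src={src}, dense={dense}, "
          f"ok={broadcast_ok(src, dense)}")
```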

Unexpected warning of register_model

UserWarning: Overwriting tiny_vit_5m_224 in registry with tinysam.modeling.tiny_vit_sam.tiny_vit_5m_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
return register_model(fn_wrapper)

I got several warnings like this when running demo.py; my output is shown below. Is this result normal, and can the warning be ignored? My PyTorch version is higher than the required one.
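The message comes from timm's `register_model`: TinySAM registers a model called `tiny_vit_5m_224`, and the installed timm apparently already ships one under the same name (the environment listed in another issue on this page uses timm==0.9.12). If only TinySAM's version is used, the overwrite is presumably harmless, but that is an assumption. To silence just this message, a narrow filter can be installed before `tinysam` is imported; the helper name below is hypothetical, and the check afterwards confirms that other warnings still get through.

```python
import warnings


def ignore_registry_overwrite_warning() -> None:
    """Ignore only timm's 'Overwriting tiny_vit_5m_224 in registry' UserWarning.

    `message` is a regex matched case-insensitively against the start of the text.
    """
    warnings.filterwarnings(
        "ignore",
        message=r"Overwriting tiny_vit_5m_224 in registry",
        category=UserWarning,
    )


ignore_registry_overwrite_warning()  # call this before `import tinysam`

# Check that the filter is narrow: unrelated UserWarnings are still reported.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")       # record every warning by default...
    ignore_registry_overwrite_warning()   # ...except this one (newest filter wins)
    warnings.warn("Overwriting tiny_vit_5m_224 in registry with "
                  "tinysam.modeling.tiny_vit_sam.tiny_vit_5m_224.", UserWarning)
    warnings.warn("an unrelated warning", UserWarning)

print([str(w.message) for w in caught])
```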


Tiny-Sam vs HQ-Sam

With four points, TinySAM doesn't segment correctly, while HQ-SAM works well even for pictures with heavy noise.

how to quantize the lightweight SAM model?

Hi, nice work. I'm trying to apply your method in an application and have some questions about the quantization.

I have carefully looked at the code in demo_quant.py and layer.py, but the model in demo_quant.py is loaded directly from the quantized weights. How can I quantize a model instance created from an existing pre-trained SAM model using your quantization method?

Since I don't know how you quantized your lightweight SAM model with the quantization method in layer.py, could you provide a reference example of how you did it? Thank you very much!

Here is the demo I wrote. It runs successfully, but the test result after quantization is close to 0. Does it need retraining, or am I approaching this incorrectly? Hoping for your reply!

import torch
import torch.nn as nn
from quantization_layer.layers import InferQuantConv2d, InferQuantConvTranspose2d
# sam_model_registry and SamPredictor come from the SAM package in use (import not shown originally)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = 'vit_b'
checkpoint = 'checkpoints/sam_vit_b_01ec64.pth'
model = sam_model_registry[model_type](checkpoint=checkpoint)
model.to(device)
model.eval()
predictor = SamPredictor(model)

w_bit = 8
a_bit = 8
input_size = (1, 3, 1024, 1024)
n_V = input_size[2]
n_H = input_size[3]
a_interval = torch.tensor(0.1)
a_bias = torch.tensor(0.0)
w_interval = torch.tensor(0.01)

def set_submodule(root, dotted_name, new_module):
    # setattr(model, "a.b.c", m) would only add an attribute literally named "a.b.c";
    # the layer has to be replaced on its direct parent module instead
    parent_name, _, child_name = dotted_name.rpartition(".")
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, new_module)

# Quantize the convolution and transposed-convolution layers in the model
def replace_with_quantized_layers(model):
    layers_to_replace = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            layers_to_replace.append((name, module))
    for name, module in layers_to_replace:
        if isinstance(module, nn.Conv2d):
            quantized_module = InferQuantConv2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                bias=module.bias is not None,
                mode='quant_forward',
                w_bit=w_bit,
                a_bit=a_bit
            )
        else:  # nn.ConvTranspose2d
            quantized_module = InferQuantConvTranspose2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                output_padding=module.output_padding,
                groups=module.groups,
                bias=module.bias is not None,
                mode='quant_forward',
                w_bit=w_bit,
                a_bit=a_bit
            )
        quantized_module.get_parameter(n_V=n_V,
                                       n_H=n_H,
                                       a_interval=a_interval,
                                       a_bias=a_bias,
                                       w_interval=w_interval)
        # Carry over the pre-trained weights: a freshly constructed layer starts from random
        # initialization (assumes the InferQuant* layers expose .weight/.bias like nn.Conv2d)
        quantized_module.weight.data.copy_(module.weight.data)
        if module.bias is not None:
            quantized_module.bias.data.copy_(module.bias.data)
        set_submodule(model, name, quantized_module.to(module.weight.device))
    return model

quan_model = replace_with_quantized_layers(model)
print(quan_model)
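In post-training quantization, step sizes such as `a_interval` and `w_interval` above are normally calibrated per layer, from that layer's weights and from activations collected on a few calibration images, rather than fixed to global constants. The sketch below uses plain symmetric min-max calibration, a generic technique and not necessarily the procedure TinySAM's quantization code uses, to show how a hand-picked interval can clip values heavily. The helper names and the sample activations are hypothetical.

```python
from typing import Sequence


def calibrate_interval(samples: Sequence[float], bits: int = 8) -> float:
    """Symmetric min-max calibration: the step size that maps max |x| to the top signed level."""
    max_abs = max(abs(x) for x in samples)
    levels = 2 ** (bits - 1) - 1  # 127 for 8 bits
    return max_abs / levels if max_abs > 0 else 1.0


def fake_quant(x: float, interval: float, bits: int = 8) -> float:
    """Round to the nearest level, clamp to the signed integer range, then dequantize."""
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    q = max(lo, min(hi, round(x / interval)))
    return q * interval


def max_error(samples: Sequence[float], interval: float) -> float:
    return max(abs(x - fake_quant(x, interval)) for x in samples)


# Hypothetical activations whose range exceeds what a fixed 0.1 step can represent (+-12.8).
acts = [-25.0, -3.2, 0.0, 0.7, 4.4, 31.75]
calibrated = calibrate_interval(acts)  # 31.75 / 127 = 0.25
print("fixed 0.1 step, max error:", max_error(acts, 0.1))
print("calibrated step, max error:", max_error(acts, calibrated))
```

With the calibrated step the error is bounded by half a step, whereas the fixed step clips the large activations; in practice calibration methods often also search for a clipping range that trades outlier error against resolution.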

how to evaluate on LVIS?

Hello, I found your work and it's great.

I'm a newbie and I'm having some difficulty reproducing the LVIS validation results. I've looked up a lot of information but can't figure it out, so I would like to ask for your help. Specifically:

  1. Your README doesn't include instructions and results for reproducing LVIS. Would it be convenient to add them? Thanks.

  2. When I use your script with the shared lvis_instances_results_vitdet.json and lvis_v1_val.json for validation, it always reports that certain images can't be found, and I found that they actually belong to the COCO train set. However, the LVIS val images I downloaded from the LVIS website are exactly the same as COCO val, i.e. they don't contain any COCO train images. How exactly is the LVIS validation set composed? Or is there a recommended download link? Thanks.
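For reference, LVIS v1 reuses COCO 2017 images, and its validation annotations reference images from both the COCO train2017 and val2017 folders. Each image entry in the LVIS json carries a `coco_url` whose second-to-last path component names the split folder, so local paths can be resolved from it (detectron2's LVIS loader does essentially this). A minimal sketch, assuming a local COCO 2017 layout with `train2017/` and `val2017/` under one root; the helper name and root path are hypothetical.

```python
import os
from typing import Dict


def lvis_image_path(img: Dict, coco_root: str) -> str:
    """Resolve an LVIS image entry to a local COCO 2017 file.

    The split folder (train2017 or val2017) is taken from the entry's coco_url.
    """
    split_folder, file_name = img["coco_url"].split("/")[-2:]
    return os.path.join(coco_root, split_folder, file_name)


val_entry = {"id": 397133, "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg"}
train_entry = {"id": 155379, "coco_url": "http://images.cocodataset.org/train2017/000000155379.jpg"}
print(lvis_image_path(val_entry, "datasets/coco"))
print(lvis_image_path(train_entry, "datasets/coco"))
```

In other words, evaluating on LVIS val requires both COCO 2017 image folders to be available locally, not only val2017.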

demo_hierachical_everything.py Produces No Masks

I can't get demo_hierachical_everything.py to work. There are no errors, but the output of hierarchical_generate is an empty list.

Here is the image it makes:

Env

  • python 3.9
  • pip install torch==1.10.2 torchvision==0.11.3 "matplotlib>=3.5.1" opencv-python==4.8.1.78 timm==0.9.12

run apk demo

Thank you for your useful work.
I noticed that there is an apk demo in your README. How can I get it?
If the apk demo is not open source, could you tell me which deployment framework you used, e.g. NCNN or MNN?
Looking forward to your reply!
