
Comments (1)

yjh0410 commented on August 11, 2024

@duxuan11 Dear friend, although the official code does not provide a bbox-prompt example, the SAM project's code shows that we only need to convert each bbox to point format and set the labels to 2 (top-left corner) and 3 (bottom-right corner). Below is an example in which I define two bboxes (xyxy format) so that EfficientSAM segments two objects. You can adapt this code to your own needs (please do not copy and paste it directly, because I slightly modified the project's file structure).

import cv2
from torchvision import transforms
import torch
import numpy as np
import argparse
import os

from models.build_efficient_sam import efficient_sam_model_registry

parser = argparse.ArgumentParser(description=("Runs EfficientSAM with bounding-box prompts on a sample image "
                                              "and saves the mask overlay as a PNG. Requires open-cv."),
                                              )

parser.add_argument("--input", type=str, required=True,
                    help="Path to either a single input image or folder of images.",
                    )

parser.add_argument("--output", type=str, required=True,
                    help=("Path to the directory where masks will be output. Output will be either a folder "
                          "of PNGs per image or a single json with COCO-style masks."),
                    )

parser.add_argument("--model-type", type=str, required=True,
                    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
                    )

parser.add_argument("--checkpoint", type=str, required=True,
                    help="The path to the SAM checkpoint to use for mask generation.",
                    )

parser.add_argument("--device", type=str, default="cuda",
                    help="The device to run generation on.")

parser.add_argument("--convert-to-rle", action="store_true",
                    help=("Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
                          "Requires pycocotools."),
                    )

parser.add_argument("--show", action="store_true",
                    help=("To show the segmentation results on the input image."),
                    )


def main(args):
    # Build the EfficientSAM model.
    model = efficient_sam_model_registry[args.model_type](checkpoint=args.checkpoint)

    # load an image
    sample_image_np = cv2.imread("data/images/ex1.jpg")
    sample_image_np = cv2.cvtColor(sample_image_np, cv2.COLOR_BGR2RGB)
    sample_image_tensor = transforms.ToTensor()(sample_image_np)

    # bboxes of the sample
    bboxes = [[ 85.7600, 196.6265, 469.7600, 543.6144],
              [236.8000,  82.8916, 325.1200, 441.4458]]
    
    # convert the bboxes into the point prompts
    num_queries = len(bboxes)
    input_points = torch.as_tensor(bboxes).unsqueeze(0)      # [bs, num_queries, 4], bs = 1
    input_points = input_points.view(-1, num_queries, 2, 2)  # [bs, num_queries, num_pts, 2]
    input_labels = torch.tensor([2, 3])  # top-left, bottom-right
    input_labels = input_labels[None, None].repeat(1, num_queries, 1) # [bs, num_queries, num_pts]

    print('Running inference using', args.model_type)
    predicted_logits, predicted_iou = model(
        sample_image_tensor[None, ...],
        input_points,
        input_labels,
    )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    # [bs, num_queries, num_candidate_masks, img_h, img_w]
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    masks = torch.ge(predicted_logits, 0).cpu().detach().numpy()
    masks = masks[0, :, 0, :, :]  # [num_queries, img_h, img_w]
    
    # Build the overlay regardless of --show, so that the save step below
    # always has an image to write (otherwise masked_image_np is undefined).
    masked_image_np = cv2.cvtColor(sample_image_np, cv2.COLOR_RGB2BGR)
    for i in range(num_queries):
        # random color for this object
        color = np.random.randint(0, 255, size=3)
        # [H, W] -> [H, W, 3]
        mask = np.repeat(masks[i][..., None], 3, axis=-1)
        mask_rgb = mask * color * 0.6
        inv_alph_mask = (1 - mask * 0.6)
        masked_image_np = (masked_image_np * inv_alph_mask + mask_rgb).astype(np.uint8)

    if args.show:
        cv2.imshow("masked image", masked_image_np)
        cv2.waitKey(0)

    # save the results
    os.makedirs("outputs/efficient_sam/", exist_ok=True)
    cv2.imwrite("outputs/efficient_sam/result.png", masked_image_np)


if __name__ == "__main__":
    args = parser.parse_args()
    np.random.seed(12)

    main(args)
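
The bbox-to-point conversion used in the script can also be isolated into a small helper. The following is only a minimal sketch in NumPy and is not part of EfficientSAM: the function name `boxes_to_point_prompts` is hypothetical, and the label values 2 and 3 simply follow the convention described in the comment above. It assumes the boxes are given in xyxy format in pixel coordinates.

```python
import numpy as np


def boxes_to_point_prompts(boxes_xyxy):
    """Turn xyxy boxes into corner-point prompts (hypothetical helper).

    Returns:
        points: float32 array of shape [1, num_queries, 2, 2]
                (top-left and bottom-right corner of each box).
        labels: int64 array of shape [1, num_queries, 2]
                (2 for the top-left corner, 3 for the bottom-right corner).
    """
    boxes = np.asarray(boxes_xyxy, dtype=np.float32)
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError("expected boxes of shape [num_queries, 4] in xyxy format")

    num_queries = boxes.shape[0]
    # [num_queries, 4] -> [1, num_queries, 2, 2]: each box becomes two (x, y) points.
    points = boxes.reshape(1, num_queries, 2, 2)
    # The same (2, 3) label pair is repeated for every box.
    labels = np.tile(np.array([2, 3], dtype=np.int64), (1, num_queries, 1))
    return points, labels


if __name__ == "__main__":
    points, labels = boxes_to_point_prompts([[10, 20, 30, 40], [5, 6, 7, 8]])
    print(points.shape, labels.shape)
```

The resulting arrays can be wrapped with `torch.from_numpy(...)` before being passed to the model, assuming it expects the same shapes as `input_points` and `input_labels` in the script.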
