
GEM

Walid Bousselham, Felix Petersen, Vittorio Ferrari, Hilde Kuehne

Vision-Language foundation models have shown remarkable performance in various zero-shot settings such as image retrieval, classification, or captioning. But so far, those models seem to fall behind when it comes to zero-shot localization of referential expressions and objects in images.

GEM enables a training-free adaptation of Vision-Language models (e.g., CLIP, ...) to perform zero-shot open-vocabulary segmentation. Because the adaptation is training-free, it fully preserves the vocabulary learned by the Vision-Language model during its pretraining, thus allowing the segmentation of uncommon classes (e.g. Elon Musk, Mark Zuckerberg, Jeff Bezos).

🔨 Installation

The gem library can be installed via pip:

$ pip install gem_torch
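
To quickly check that the installation succeeded, you can import the library and list the available pretrained models (a minimal sanity check; see the Usage section below for details):

import gem
print(gem.available_models())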

Demo

To run the gradio app locally, first install gradio and then run app.py:

$ pip install gradio
$ python app.py

Usage

To see which pretrained models are available, use the following code snippet:

import gem
gem.available_models()

Single Image

To process a single image and multiple text prompts use the following code snippet:

import torch
import gem
import requests
from PIL import Image

model_name = 'ViT-B/16'  # 'ViT-B-16-quickgelu'
pretrained = 'openai'  # 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
preprocess = gem.get_gem_img_transform()
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained, 
                                 device=device)
# load image and text
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = preprocess(Image.open(requests.get(url, stream=True).raw)).unsqueeze(0).to(device)
text = ['cat', 'remote control']

with torch.no_grad():
    logits = gem_model(image, text)  # [B, num_prompt, W, H]
    gem_model.visualize(image, text, logits)  # (optional visualization)
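
If the returned maps are at a lower resolution than the input image, they can be upsampled with torch.nn.functional.interpolate before further use. A minimal post-processing sketch, assuming the variables from the snippet above are still in scope (the argmax simply assigns each pixel to its best-matching prompt and does not handle a background class; see the semantic segmentation example below for that):

import torch.nn.functional as F

# upsample the per-prompt heatmaps to the spatial size of the input image
logits_up = F.interpolate(logits, size=image.shape[-2:], mode='bilinear', align_corners=False)
# assign each pixel to its best-matching prompt (0 = 'cat', 1 = 'remote control')
pred = logits_up.argmax(dim=1)  # [B, H, W]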

Batched Inference

To process a batch of images with a different number of prompts per image, use the batched_forward() method of gem_model:

import torch
import gem
import requests
from PIL import Image

model_name = 'ViT-B/16'  # 'ViT-B-16-quickgelu'
pretrained = 'openai'  # 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained, 
                                 device=device)
preprocess = gem.get_gem_img_transform()

# load image and text
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://cdn.vietnambiz.vn/171464876016439296/2021/7/11/headshots16170695297430-1626006880779826347793.jpg",
    "https://preview.redd.it/do-you-think-joker-should-be-unpredictable-enough-to-put-up-v0-6a2ax4ngtlaa1.jpg?auto=webp&s=f8762e6a1b40642bcae5900bac184fc597131503",
    ]
texts = [
    ['remote control', 'cat'],
    ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'],
    ['batman', 'joker', 'shoe', 'belt', 'purple suit'],
    ]  # note that the number of prompts per image can differ

# download images + convert to PIL.Image
images_pil = [Image.open(requests.get(url, stream=True).raw) for url in urls]
images = torch.stack([preprocess(img) for img in images_pil]).to(device)

with torch.no_grad():
    # returns a list with logits of size [1, num_prompt, W, H]
    logits_list = gem_model.batched_forward(images, texts)
    
    for i, logits in enumerate(logits_list):  # (optional visualization)
        gem_model.visualize(images[i], texts[i], logits)
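
Instead of displaying the heatmaps, they can also be written to disk via the save_path argument documented in the API section below. A sketch, assuming gem_model.visualize accepts the same arguments as gem.visualize and reusing logits_list from the loop above (the exact output layout produced by save_path depends on the library):

import os

os.makedirs("heatmaps", exist_ok=True)
for i, logits in enumerate(logits_list):
    # write the heatmaps for image i to disk instead of displaying them
    gem_model.visualize(images[i], texts[i], logits, save_path=f"heatmaps/image_{i}.png")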

API

The library provides the following methods:

  • gem.create_gem_model(model_name, pretrained, device, ...):
    • Returns the model_name Vision-Language model with the pretrained weights loaded and GEM applied. One can also specify the gem_depth, ss_attn_iter and ss_attn_temp parameters to control, respectively, GEM's depth, the number of self-self attention iterations, and the self-self attention temperature (see the paper for more details).
  • gem.get_gem_img_transform(img_size):
    • Returns an image transform that takes a PIL.Image and returns a torch.Tensor, which can be used as input to the model.
  • gem.visualize(image, prompts, logits, alpha=0.6, save_path=None):
    • Takes a PIL.Image or a torch.Tensor, together with the list of text prompts and the logits output by GEM, and plots GEM's heatmap for each prompt. Alternatively, the heatmaps can be saved by specifying a saving path via save_path. The transparency of the heatmaps can be adjusted via the alpha=0.6 argument.

By default, models created by gem.create_gem_model() return the logits produced by GEM, but they can also return the logits of the original Vision-Language model, which can be useful for visualization. To do so, set return_ori=True.
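
For illustration, a hedged sketch of overriding these parameters at creation time (the values below are arbitrary placeholders, not recommended settings):

import torch
import gem

device = "cuda" if torch.cuda.is_available() else "cpu"
# the parameter values below are arbitrary placeholders, not recommended settings
gem_model = gem.create_gem_model(model_name='ViT-B/16',
                                 pretrained='openai',
                                 gem_depth=7,       # GEM's depth (see paper)
                                 ss_attn_iter=1,    # number of self-self attention iterations
                                 ss_attn_temp=0.7,  # self-self attention temperature
                                 device=device)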

More Examples

Semantic Segmentation

For the semantic segmentation task, given a list of foreground class names, one must predict a 2D map where each location contains the id of the predicted class. Depending on the dataset, we may also want to predict a background class. However, a textual description such as "a photo of a background" does not describe what the background is actually composed of. Hence, we propose the following approach:

import torch
import gem
import requests
from PIL import Image

model_name = 'ViT-B/16'  # 'ViT-B-16-quickgelu'
pretrained = 'openai'  # 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained, 
                                 device=device)
preprocess = gem.get_gem_img_transform()

predict_background = True  # whether the background is predicted
if predict_background:
    threshold = 0.85  # the threshold depends on the number of classes

# load image and text
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
class_names = ['airplane', 'cat', 'dog', '...']  # foreground class names

with torch.no_grad():
    logits = gem_model(image, class_names)  # [1, num_class, W, H]

pred = logits.argmax(dim=1)
if predict_background:
    pred = pred + 1  # we assume the background's index is 0
    probs = logits.softmax(dim=1)
    max_prob = probs.max(dim=1)[0]
    pred[max_prob < threshold] = 0  # if the max prob is lower than the threshold, predict the background

Note that threshold depends on the number of classes and should be determined via a hyperparameter sweep.
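
A minimal sketch of such a sweep on a single validation image, assuming the logits from the snippet above are in scope and that a ground-truth label map gt_mask of matching spatial size is available (gt_mask, the candidate grid, and the pixel-accuracy metric are placeholder assumptions; in practice one would sweep over a validation set with the dataset's own metric, e.g. mIoU):

# placeholder candidate values; gt_mask is assumed to be a [1, W, H] ground-truth label
# map at the same resolution as the logits, with 0 denoting the background class
candidate_thresholds = [0.5, 0.6, 0.7, 0.8, 0.85, 0.9]

probs = logits.softmax(dim=1)
max_prob, fg_pred = probs.max(dim=1)
fg_pred = fg_pred + 1  # shift so that 0 is reserved for the background

best_t, best_acc = None, -1.0
for t in candidate_thresholds:
    pred = fg_pred.clone()
    pred[max_prob < t] = 0  # predict background where confidence is low
    acc = (pred == gt_mask).float().mean().item()  # pixel accuracy as a simple stand-in metric
    if acc > best_acc:
        best_t, best_acc = t, acc
print(f"best threshold: {best_t} (pixel accuracy {best_acc:.3f})")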

Dataset

gem can also be used with a regular PyTorch dataset.

import torch
import gem
from PIL import Image
from torchvision.datasets import VOCSegmentation

model_name = 'ViT-B/16'  # 'ViT-B-16-quickgelu'
pretrained = 'openai'  # 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained, 
                                 device=device)
preprocess = gem.get_gem_img_transform()

predict_background = True  # whether the background is predicted
if predict_background:
    threshold = 0.85  # the threshold depends on the number of classes

# load dataset
root = './data'  # path to save the dataset
dataset = VOCSegmentation(root=root, image_set='val', download=True, transform=preprocess)

# the default collate function cannot batch the PIL segmentation masks, so keep them in a list
def collate_fn(batch):
    images = torch.stack([image for image, _ in batch])
    targets = [target for _, target in batch]
    return images, targets

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
class_names = ['airplane', 'cat', 'dog', '...']  # foreground class names

with torch.no_grad():
    for (image, _) in dataloader:
        logits = gem_model(image.to(device), class_names)  # [B, num_class, W, H]
    
        pred = logits.argmax(dim=1)
        if predict_background:
            pred = pred + 1  # we assume the background's index is 0
            probs = logits.softmax(dim=1)
            max_prob = probs.max(dim=1)[0]
            pred[max_prob < threshold] = 0  # if the max prob is lower than the threshold, predict the background
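
For evaluation against the dataset's ground-truth masks, the predictions typically have to be brought back to the masks' resolution; a minimal sketch using bilinear upsampling (the target size below is a placeholder, in practice it comes from each sample's ground-truth mask):

import torch.nn.functional as F

# upsample the class logits to the resolution of the ground-truth masks before the argmax
target_size = (500, 500)  # placeholder; use each sample's ground-truth mask size
logits_up = F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)
pred_up = logits_up.argmax(dim=1)  # [B, H, W] label map at the target resolution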

⭐ Acknowledgement

This code is built as a wrapper around the OpenCLIP library from LAION; visit their repo for more vision-language models. This project takes inspiration from CLIP and CLIPSurgery; please visit their repositories. This repo also uses einops.

📚 Citation

If you find this repository useful, please consider citing our work 📝 and giving a star 🌟:

@article{bousselham2023gem,
  title={Grounding Everything: Emerging Localization Properties in Vision-Language Transformers},
  author={Bousselham, Walid and Petersen, Felix and Ferrari, Vittorio and Kuehne, Hilde},
  journal={arXiv preprint arXiv:2312.00878},
  year={2023}
}

