ClipPrompt

A PyTorch implementation of ClipPrompt, based on the CVPR 2023 paper "CLIP for All Things Zero-Shot Sketch-Based Image Retrieval, Fine-Grained or Not".

Network Architecture

Requirements

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install torchmetrics
pip install opencv-python

Dataset

This repo uses the Sketchy Extended and TU-Berlin Extended datasets. You can download them from their official websites or from Google Drive. The data directory is structured as follows:

├── sketchy
│   ├── train
│   │   ├── sketch
│   │   │   ├── airplane
│   │   │   │   ├── n02691156_58-1.jpg
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── photo
│   │       (same structure as sketch)
│   └── val
│       (same structure as train)
└── tuberlin
    (same structure as sketchy)
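Each image's class label comes from its parent folder in the layout above. The following is a minimal sketch of indexing one split, assuming the path pattern <data_root>/<data_name>/<split>/<domain>/<class>/<image> implied by the tree and by the default --query_name path. The helper index_split is hypothetical and not part of the repo:

```python
import os
from typing import List, Tuple

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def index_split(data_root: str, data_name: str, split: str, domain: str) -> List[Tuple[str, str]]:
    """Return sorted (image_path, class_name) pairs for one split and domain.

    Hypothetical helper. Assumes split is 'train' or 'val' and domain is
    'sketch' or 'photo', following the directory tree in this README.
    """
    domain_dir = os.path.join(data_root, data_name, split, domain)
    samples = []
    for class_name in sorted(os.listdir(domain_dir)):
        class_dir = os.path.join(domain_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the class-folder level
        for file_name in sorted(os.listdir(class_dir)):
            if file_name.lower().endswith(IMAGE_EXTS):
                samples.append((os.path.join(class_dir, file_name), class_name))
    return samples
```

For example, index_split('/home/data', 'sketchy', 'val', 'sketch') would list the validation sketches, each paired with its class folder name such as 'cow'.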

Usage

To train a model on the Sketchy Extended dataset, run:

python main.py --mode train --data_name sketchy

To test a model on the Sketchy Extended dataset, run:

python main.py --mode test --data_name sketchy --query_name <query image path>
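In test mode the query sketch is compared against a gallery, and the closest retrieval_num images are returned. The sketch below ranks precomputed embeddings by the distance 1.0 - cosine similarity quoted in the Benchmarks section. It assumes the embeddings were already produced by the trained model and that the gallery is the photo set of the same split; cosine_distance and retrieve are hypothetical names, and plain Python lists stand in for tensors so the ranking logic is visible:

```python
import math
from typing import List, Sequence, Tuple


def cosine_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """1 - cosine similarity, mirroring 1.0 - F.cosine_similarity(x, y)."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / max(norm, 1e-8)  # guard against zero-length vectors


def retrieve(query: Sequence[float],
             gallery: List[Tuple[str, Sequence[float]]],
             retrieval_num: int = 8) -> List[Tuple[str, float]]:
    """Return the retrieval_num gallery entries closest to the query embedding."""
    scored = [(path, cosine_distance(query, emb)) for path, emb in gallery]
    scored.sort(key=lambda item: item[1])  # smaller distance = better match
    return scored[:retrieval_num]
```

Sorting the whole gallery is fine for a sketch; for large galleries, heapq.nsmallest over the scored list avoids a full sort.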

common arguments:

--data_root                   Datasets root path [default value is '/home/data']
--data_name                   Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin'])
--prompt_num                  Number of prompt embeddings [default value is 3]
--save_root                   Root path for saved results [default value is 'result']
--mode                        Mode of the script [default value is 'train'](choices=['train', 'test'])

train arguments:

--batch_size                  Number of images in each mini-batch [default value is 64]
--epochs                      Number of training epochs [default value is 60]
--triplet_margin              Margin of the triplet loss [default value is 0.3]
--encoder_lr                  Learning rate of the encoder [default value is 1e-4]
--prompt_lr                   Learning rate of the prompt embeddings [default value is 1e-3]
--cls_weight                  Weight of the classification loss [default value is 0.5]
--seed                        Random seed (-1 for no manual seed) [default value is -1]
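The --triplet_margin and --cls_weight options imply a training objective with a triplet term and a weighted classification term. This is a minimal single-sample sketch using the 1.0 - cosine similarity distance from the Benchmarks section. ASSUMPTION: the two terms are combined as a plain weighted sum; the README documents only the weight, not the exact formula, and all function names here are hypothetical:

```python
import math
from typing import Sequence


def cosine_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """1 - cosine similarity, the distance quoted in the Benchmarks section."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / max(norm, 1e-8)  # guard against zero-length vectors


def triplet_loss(anchor: Sequence[float], positive: Sequence[float],
                 negative: Sequence[float], margin: float = 0.3) -> float:
    """Hinge loss: zero once the positive is closer than the negative by at least margin."""
    return max(0.0, cosine_distance(anchor, positive) - cosine_distance(anchor, negative) + margin)


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Numerically stable -log(softmax(logits)[target])."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def total_loss(anchor, positive, negative, logits, target,
               triplet_margin: float = 0.3, cls_weight: float = 0.5) -> float:
    # ASSUMPTION: weighted sum of the two terms; the repo may combine them differently.
    return (triplet_loss(anchor, positive, negative, triplet_margin)
            + cls_weight * cross_entropy(logits, target))
```

In batched PyTorch code, the triplet term would typically be torch.nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), margin=0.3) and the classification term F.cross_entropy, both averaged over the mini-batch.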

test arguments:

--query_name                  Query image path [default value is '/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--retrieval_num               Number of retrieved images [default value is 8]
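The options listed above can be collected into a single parser. This is a hypothetical argparse reconstruction built only from the documented names, defaults, and choices; the actual main.py may declare them differently. The -1 seed convention is handled by a hypothetical maybe_seed helper:

```python
import argparse
import random


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction of the documented CLI; not the repo's own code.
    parser = argparse.ArgumentParser(description="ClipPrompt (sketch of the documented options)")
    # common arguments
    parser.add_argument("--data_root", type=str, default="/home/data", help="Datasets root path")
    parser.add_argument("--data_name", type=str, default="sketchy", choices=["sketchy", "tuberlin"])
    parser.add_argument("--prompt_num", type=int, default=3, help="Number of prompt embeddings")
    parser.add_argument("--save_root", type=str, default="result", help="Root path for saved results")
    parser.add_argument("--mode", type=str, default="train", choices=["train", "test"])
    # train arguments
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--triplet_margin", type=float, default=0.3)
    parser.add_argument("--encoder_lr", type=float, default=1e-4)
    parser.add_argument("--prompt_lr", type=float, default=1e-3)
    parser.add_argument("--cls_weight", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=-1, help="-1 for no manual seed")
    # test arguments
    parser.add_argument("--query_name", type=str,
                        default="/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg")
    parser.add_argument("--retrieval_num", type=int, default=8)
    return parser


def maybe_seed(seed: int) -> None:
    # -1 means "no manual seed"; a PyTorch script would also call torch.manual_seed here.
    if seed != -1:
        random.seed(seed)
```

Calling build_parser().parse_args(["--mode", "test", "--data_name", "tuberlin"]) would leave every other option at the default shown in the lists above.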

Benchmarks

The models were trained on a single NVIDIA GeForce RTX 3090 (24 GB) GPU with seed 42, prompt_lr 1e-3, and the distance function 1.0 - F.cosine_similarity(x, y); all other hyperparameters use their default values.

Dataset              Prompt Num   mAP@200   mAP@all   P@100   P@200   Download
Sketchy Extended     3            71.9      64.3      70.8    68.1    MEGA
TU-Berlin Extended   3            75.3      66.0      73.9    69.7    MEGA
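The table reports standard retrieval metrics. P@K is the fraction of the top K results that share the query's class, and mAP@K averages each query's average precision over its top K results (mAP@all uses the full ranking). Below is a minimal sketch over boolean relevance lists, ASSUMING AP@K is normalized by the number of relevant items inside the cutoff. SBIR codebases differ on this convention, and the repo (which lists torchmetrics as a dependency) may compute it differently:

```python
from typing import List, Optional, Sequence


def precision_at_k(relevant: Sequence[bool], k: int) -> float:
    """Fraction of the top-k ranked items that are relevant (same class as the query)."""
    return sum(relevant[:k]) / k


def average_precision(relevant: Sequence[bool], k: Optional[int] = None) -> float:
    """AP over the top-k ranking, or over the whole ranking when k is None.

    Normalized by the number of relevant items inside the cutoff (an assumed
    convention); returns 0.0 when the cutoff contains no relevant item.
    """
    ranked = relevant if k is None else relevant[:k]
    hits, precision_sum = 0, 0.0
    for rank, is_rel in enumerate(ranked, start=1):
        if is_rel:
            hits += 1
            precision_sum += hits / rank  # precision at this relevant position
    return precision_sum / hits if hits else 0.0


def mean_average_precision(rankings: List[Sequence[bool]], k: Optional[int] = None) -> float:
    """Mean of per-query AP; each ranking is one query's sorted gallery relevance."""
    return sum(average_precision(r, k) for r in rankings) / len(rankings)
```

For a ranking [True, False, True, False], the relevant hits at ranks 1 and 3 give AP = (1/1 + 2/3) / 2 = 5/6, while P@2 = 0.5.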

Results

[Figure: results visualization]
