
zsgnet-pytorch

This is the official repository for the ICCV19 oral paper Zero-Shot Grounding of Objects from Natural Language Queries. It contains the code and datasets needed to reproduce the numbers reported for our model ZSGNet in the paper.

The code has been refactored from the original implementation and now supports distributed training (see the PyTorch docs) for significantly faster training (around a 4x speedup over PyTorch DataParallel).

The code is fairly easy to use and to extend for future work. Feel free to open an issue if you have questions.

Install

Requirements:

  • python>=3.6
  • pytorch>=1.1

To reproduce the environment you can use conda with the provided conda_env_zsg.yml file. Please refer to Miniconda for details on installing conda.

MINICONDA_ROOT=[to your Miniconda/Anaconda root directory]
conda env create -f conda_env_zsg.yml --prefix $MINICONDA_ROOT/envs/zsg_pyt
conda activate zsg_pyt
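
Optionally, here is a quick sanity check (a minimal sketch, not part of the repo) that the environment satisfies the version requirements above:

import sys
import torch

# Mirror the requirements listed above.
assert sys.version_info >= (3, 6), "python>=3.6 required"
assert tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) >= (1, 1), "pytorch>=1.1 required"
print(f"python {sys.version.split()[0]}, pytorch {torch.__version__}, cuda available: {torch.cuda.is_available()}")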

Data Preparation

Look at DATA_README.md for quick start and DATA_PREP_README.md for obtaining annotations from the parent datasets.

Training

Basic usage is python code/main_dist.py "experiment_name" --arg1=val1 --arg2=val2, where the available arguments are listed in configs/cfg.yaml. This trains using the DataParallel mode.

For distributed training use python -m torch.distributed.launch --nproc_per_node=$ngpus code/main_dist.py instead. This trains using the DistributedDataParallel mode. (Also see the caveats on distributed training below.)

An example to train on the ReferIt dataset (note: you must have prepared the ReferIt dataset first) would be:

python code/main_dist.py "referit_try" --ds_to_use='refclef' --bs=16 --nw=4

Similarly, for distributed training (set $ngpus to the number of GPUs):

python -m torch.distributed.launch --nproc_per_node=$ngpus code/main_dist.py "referit_try" --ds_to_use='refclef' --bs=16 --nw=4

Logging

Logs are stored inside the tmp/ directory. When you run the code with experiment name $exp_name, the following are stored:

  • txt_logs/$exp_name.txt: the config used and the training and validation losses after every epoch.
  • models/$exp_name.pth: the model, optimizer, scheduler, accuracy, and the number of epochs and iterations completed. Only the best model up to the current epoch is stored (see the loading sketch after this list).
  • ext_logs/$exp_name.txt: uses Python's logging module to store the logger.debug outputs. Mainly used for debugging.
  • tb_logs/$exp_name: still a work in progress; currently this just creates a directory and nothing more. Ideally this would hold the tensorboard logs.
  • predictions: the validation outputs of the current best model.
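
For reference, a saved checkpoint can be inspected with plain torch.load. The key names inside the checkpoint are an assumption here; print ckpt.keys() and check against the repo's saving code:

import torch

# Load the best checkpoint on CPU. Per the list above, it bundles model,
# optimizer, and scheduler state plus accuracy and epoch/iteration counters.
ckpt = torch.load('./tmp/models/referit_try.pth', map_location='cpu')
print(ckpt.keys())
# model.load_state_dict(ckpt['model_state_dict'])  # hypothetical key name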

Evaluation

There are two ways to evaluate.

  1. For validation, it is already computed in the training loop. If you just want to evaluate on the validation or test set with a previously trained model ($exp_name), you can do:
python code/main_dist.py $exp_name --ds_to_use='refclef' --resume=True --only_valid=True --only_test=True

Or you can use a different experiment name and pass the --resume_path argument:

python code/main_dist.py $exp_name --ds_to_use='refclef' --resume=True --resume_path='./tmp/models/referit_try.pth' 

After this, the logs would be available inside tmp/txt_logs/$exp_name.txt

  2. If you have some other model, you can write its predictions in the following structure to a pickle file, say predictions.pkl (a short writing example follows the structure below):
[
    {'id': annotation_id,
     'pred_boxes': [x1, y1, x2, y2]},
    ...
]
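
For instance, a file with this structure could be produced as follows (the ids and boxes here are placeholders):

import pickle

# Placeholder predictions; replace with your model's outputs.
predictions = [
    {'id': 0, 'pred_boxes': [10.0, 20.0, 110.0, 220.0]},
    {'id': 1, 'pred_boxes': [5.0, 5.0, 50.0, 60.0]},
]
with open('predictions.pkl', 'wb') as f:
    pickle.dump(predictions, f)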

Then you can evaluate with code/eval_script.py:

python code/eval_script.py predictions_file gt_file

For ReferIt this would be:

python code/eval_script.py ./tmp/predictions/$exp_name/val_preds_$exp_name.pkl ./data/referit/csv_dir/val.csv
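
For reference, grounding accuracy is typically the fraction of predictions whose IoU with the ground-truth box is at least 0.5. Below is a minimal sketch of that computation (independent of eval_script.py, whose exact protocol may differ):

def iou(box_a, box_b):
    """IoU of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

# Placeholder boxes; in practice these come from the predictions and gt files.
preds = [[10, 20, 110, 220], [5, 5, 50, 60]]
gts = [[12, 18, 105, 210], [40, 40, 90, 100]]
accuracy = sum(iou(p, g) >= 0.5 for p, g in zip(preds, gts)) / len(preds)
print(f"Acc@0.5 = {accuracy:.3f}")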

Caveats in DistributedDataParallel

When training with DDP, there is no easy way to gather all the validation outputs into one process (the built-in gather utilities work only for tensors). Instead, one has to save the predictions of each process separately and then read them back and combine them in the main process. The current implementation doesn't do this, for simplicity; as a result, the validation results obtained during training are slightly different from the actual results.

To get the correct results, follow the steps in Evaluation as-is (the key point is to NOT use torch.distributed.launch for evaluation). You get the correct results when simply using DataParallel.
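
For completeness, here is a minimal sketch (not what the current implementation does) of the save-and-combine approach: every process dumps its own predictions to disk, and after a barrier, rank 0 reads the shards back and merges them. The file paths and the predictions argument are hypothetical.

import pickle
import torch.distributed as dist

def gather_predictions(predictions, tmp_dir='./tmp'):
    # Each process writes its own shard of predictions to disk.
    rank = dist.get_rank()
    with open(f'{tmp_dir}/preds_rank{rank}.pkl', 'wb') as f:
        pickle.dump(predictions, f)
    dist.barrier()  # wait until every process has finished writing
    if rank != 0:
        return None
    # Rank 0 reads every shard back and concatenates them.
    merged = []
    for r in range(dist.get_world_size()):
        with open(f'{tmp_dir}/preds_rank{r}.pkl', 'rb') as f:
            merged.extend(pickle.load(f))
    # If the distributed sampler padded shards to equal sizes, deduplicate
    # by annotation id before computing metrics.
    return merged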

Pre-trained Models

The pre-trained models are available on Google Drive.

ToDo

  • Add colab demo.
  • Add installation guide.
  • Add pretrained models.
  • Add hubconf.
  • Add tensorboard support.

Acknowledgements

We thank:

  1. @yhenon for their repository on RetinaNet (https://github.com/yhenon/pytorch-retinanet).
  2. @amdegroot for their repository on SSD with VGG (https://github.com/amdegroot/ssd.pytorch).
  3. The fastai repository for helpful logging, anchor-box generation, and convolution functions.
  4. The maskrcnn-benchmark repository for many of the distributed utils and the implementation of non-maximum suppression.

Citation

If you find the code or dataset useful, please cite us:

@InProceedings{Sadhu_2019_ICCV,
  author    = {Sadhu, Arka and Chen, Kan and Nevatia, Ram},
  title     = {Zero-Shot Grounding of Objects From Natural Language Queries},
  booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
  month     = {October},
  year      = {2019}
}
