lsvrd-crude.pytorch

Visual Relationship Encoder - PyTorch Implementation

Environment

python==3.6, pytorch==0.4.0 (Anaconda recommended)

Install additional requirements with:

pip install -r requirements.txt

Clone RoIAlign.pytorch (anywhere you like):

git clone https://github.com/longcw/RoIAlign.pytorch.git
cd RoIAlign.pytorch

Modify the -arch argument in install.sh to suit your GPU. See https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/ for more information.

sh install.sh
sh test.sh # to see if it's correctly installed
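
The -arch value follows from the GPU's CUDA compute capability: capability 6.1 corresponds to sm_61, 7.0 to sm_70, and so on. The sketch below formats that value and, if PyTorch and a GPU are available, queries it for device 0. Both `arch_flag` and `detect_arch_flag` are hypothetical helper names for illustration; they are not part of RoIAlign.pytorch.

```python
def arch_flag(major, minor):
    """Format a CUDA compute capability (major, minor) as an nvcc -arch value."""
    return "sm_{}{}".format(major, minor)


def detect_arch_flag():
    """Return the -arch value for GPU 0, or None if it cannot be queried."""
    try:
        import torch  # assumed to be installed as described under Environment
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability(0)
    return arch_flag(major, minor)
```

Calling `detect_arch_flag()` on the training machine gives a candidate value to put after -arch in install.sh; check it against the table linked above before building.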

Data

Download the GloVe 6B word embeddings from http://nlp.stanford.edu/data/glove.6B.zip, unzip them into a folder, and make a soft link to that folder under data/.

Download the GQA dataset from https://cs.stanford.edu/people/dorarad/gqa/download.html and make a soft link to it under data/.
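
The soft links can be created with `ln -s`, or from Python as in the minimal sketch below. `link_dataset` is a hypothetical helper, and the source paths you pass to it are wherever you unpacked the downloads.

```python
import os


def link_dataset(source_dir, data_root, name):
    """Create data_root/name as a symbolic link to source_dir and return its path."""
    source_dir = os.path.abspath(source_dir)
    if not os.path.isdir(source_dir):
        raise FileNotFoundError("no such directory: " + source_dir)
    os.makedirs(data_root, exist_ok=True)
    link_path = os.path.join(data_root, name)
    if os.path.islink(link_path):
        os.remove(link_path)  # replace a stale link from an earlier run
    # If a real file or directory already sits at link_path, this raises
    # FileExistsError instead of overwriting it.
    os.symlink(source_dir, link_path)
    return link_path
```

For example, `link_dataset("/path/to/glove", "data", "glove")` and `link_dataset("/path/to/gqa", "data", "gqa")`, run from the repository root, produce the two links.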

Make sure you have:

cd $LSVRD_ROOT
tree data/gqa -L 1
data/gqa
├── images - all the image files extracted in this folder
├── objects - object features (.h5 files and the .json)
├── questions12 - v1.2 question json files
└── scene_graphs - scene graph json files
tree data/glove -L 1
data/glove
└── glove.6B.300d.txt
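
A quick way to confirm this layout before preprocessing is to check each expected entry programmatically. The sketch below is an illustration only: `EXPECTED` and `missing_entries` are hypothetical names, and the entries are copied from the listing above.

```python
import os

# Entries taken from the directory listing above.
EXPECTED = {
    "gqa": ["images", "objects", "questions12", "scene_graphs"],
    "glove": ["glove.6B.300d.txt"],
}


def missing_entries(data_root):
    """Return the expected paths that do not exist under data_root."""
    missing = []
    for subdir, names in sorted(EXPECTED.items()):
        for name in names:
            path = os.path.join(data_root, subdir, name)
            # os.path.exists follows symbolic links, so linked folders count.
            if not os.path.exists(path):
                missing.append(path)
    return missing
```

Running `missing_entries("data")` from the repository root returns an empty list when everything is in place, and otherwise lists the paths that still need to be provided.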

Preprocess the data:

cd $LSVRD_ROOT
sh scripts/prepare_data.sh

Structure

$LSVRD_ROOT
├── cache
├── configs # configuration files
├── data
│   ├── glove -> /path/to/glove/
│   └── gqa -> /path/to/gqa/
├── lib
│   ├── data
│   │   ├── dataset.py
│   │   ├── h5io.py
│   │   └── sym_dict.py
│   ├── evaluate.py # evaluation function
│   ├── loss
│   │   ├── consistency_loss.py
│   │   ├── triplet_loss.py
│   │   └── triplet_softmax_loss.py
│   ├── model
│   │   ├── language_model.py
│   │   ├── loss_model.py
│   │   └── vision_model.py
│   ├── module
│   │   ├── backbone.py
│   │   ├── entity_net.py
│   │   ├── relation_net.py
│   │   └── similarity_model.py
│   ├── train.py
│   └── utils.py
├── main
│   └── train.py
└── scripts

Training

You can optionally use scripts/pre_extract_features.py to extract ResNet-101 feature maps as preprocessing.

Example:

python main/train.py \
--config configs/resnet101-512-14-7-7-GRU-300d-1layer-5-0-256-128-0.2-0.2-1.0-1001-gt-311-100000-1e-4-0.8.json \
--n_epochs 5

Run python main/train.py --help for other options.
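
When launching several runs, it can be convenient to assemble the training command in Python instead of typing it by hand. The sketch below uses only the two flags shown above (--config and --n_epochs); `build_train_cmd` and `run_training` are hypothetical helpers, not part of this repository, and they assume the working directory is the repository root.

```python
import subprocess
import sys


def build_train_cmd(config_path, n_epochs=None):
    """Assemble the argument list for main/train.py from the flags shown above."""
    cmd = [sys.executable, "main/train.py", "--config", config_path]
    if n_epochs is not None:
        cmd += ["--n_epochs", str(n_epochs)]
    return cmd


def run_training(config_path, n_epochs=None):
    """Run training; raises subprocess.CalledProcessError if the script fails."""
    subprocess.check_call(build_train_cmd(config_path, n_epochs))
```

Passing the arguments as a list avoids shell quoting problems with the long config file names.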

Inference

To extract features on GQA, use main/infer.py. Example:

python main/infer.py \
--config configs/ \
--lckpt path/to/language/model.pth \
--vckpt path/to/vision/model.pth \
--dataset gqa
