ws_dan's Introduction

Weakly Supervised Data Augmentation Network

This is the official TensorFlow implementation of WS-DAN.

See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification.

Compatibility

  • The code is tested using TensorFlow r1.x under Ubuntu 16.04 with Python 2.x and Python 3.x.

  • Recommended environment: Anaconda

Requirements

$ git clone [email protected]:tau-yihouxiang/WS_DAN.git
$ cd WS_DAN
$ python setup.py install
  • opencv, tqdm
$ conda install -c menpo opencv
$ pip install tqdm

Datasets and Pre-trained models

Datasets        #Attention   Pre-trained model
CUB-200-2011    32           WS-DAN
Stanford-Cars   32           WS-DAN
FGVC-Aircraft   32           WS-DAN

Inspiration

The code is based on the TensorFlow-Slim library.

Preparing Datasets

Download the images and labels, then pre-process them into tfrecords.

convert_data.py generates a ./tfrecords folder below the provided $dataset_dir.

-Bird
   └── Data
         └─── tfrecords
         └─── images.txt
         └─── image_class_labels.txt
         └─── train_test_split.txt
         └─── images
                 └─── ****.jpg
$ python convert_data.py --dataset_name=Bird --dataset_dir=./Bird/Data
-Car
  └── Data
        └─── tfrecords
        └─── devkit
        |         └─── cars_train_annos.mat
        |         └─── cars_test_annos_withlabels.mat
        └─── cars_train
        |        └─── ****.jpg
        └─── cars_test
                 └─── ****.jpg
$ python convert_data.py --dataset_name=Car --dataset_dir=./Car/Data
-Aircraft
    └── Data
          └─── tfrecords
          └─── fgvc-aircraft-2013b
                       └─── ***
$ python convert_data.py --dataset_name=Aircraft --dataset_dir=./Aircraft/Data
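
Before running convert_data.py, it can help to confirm that $dataset_dir matches the layouts shown above. The following is a minimal sketch and not part of the repository: the helper name missing_entries and the EXPECTED table are hypothetical, and the entries are copied only from the trees listed in this section (the tfrecords folder is left out because it is generated).

```python
import os

# Entries each dataset_dir is expected to contain, copied from the trees above.
# This table is an illustrative assumption, not something convert_data.py reads.
EXPECTED = {
    "Bird": ["images.txt", "image_class_labels.txt", "train_test_split.txt", "images"],
    "Car": [
        os.path.join("devkit", "cars_train_annos.mat"),
        os.path.join("devkit", "cars_test_annos_withlabels.mat"),
        "cars_train",
        "cars_test",
    ],
    "Aircraft": ["fgvc-aircraft-2013b"],
}


def missing_entries(dataset_name, dataset_dir):
    """Return the expected files or folders that are absent from dataset_dir."""
    return [
        rel for rel in EXPECTED[dataset_name]
        if not os.path.exists(os.path.join(dataset_dir, rel))
    ]


if __name__ == "__main__":
    for name in EXPECTED:
        data_dir = os.path.join(".", name, "Data")
        print(name, "missing:", missing_entries(name, data_dir))
```

An empty list for a dataset means every listed entry was found under its Data folder.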

Running training

ImageNet pre-trained model

Download the ImageNet pre-trained model inception_v3.ckpt and put it in the ./pre_trained/ folder.

DATASET="Bird"
TRAIN_DIR="./$DATASET/WS_DAN/TRAIN/ws_dan_part_32"
MODEL_PATH='./pre_trained/inception_v3.ckpt'

python train_sample.py --learning_rate=0.001 \
                            --dataset_name=$DATASET \
                            --dataset_dir="./$DATASET/Data/tfrecords" \
                            --train_dir=$TRAIN_DIR \
                            --checkpoint_path=$MODEL_PATH \
                            --max_number_of_steps=80000 \
                            --weight_decay=1e-5 \
                            --model_name='inception_v3_bap' \
                            --checkpoint_exclude_scopes="InceptionV3/bilinear_attention_pooling" \
                            --batch_size=12 \
                            --train_image_size=448 \
                            --num_clones=1 \
                            --gpus="3" \
                            --feature_maps="Mixed_6e" \
                            --attention_maps="Mixed_7a_b0" \
                            --num_parts=32

Running testing

DATASET="Bird"
TRAIN_DIR="./$DATASET/WS_DAN/TRAIN/ws_dan_part_32"
TEST_DIR="./$DATASET/WS_DAN/TEST/ws_dan_part_32"

python eval_sample.py --checkpoint_path=$TRAIN_DIR \
                         --dataset_name=$DATASET \
                         --dataset_split_name='test' \
                         --dataset_dir="./$DATASET/Data/tfrecords" \
                         --eval_dir=$TEST_DIR \
                         --model_name='inception_v3_bap' \
                         --batch_size=16 \
                         --eval_image_size=448 \
                         --gpus="2" \
                         --feature_maps="Mixed_6e" \
                         --attention_maps="Mixed_7a_b0" \
                         --num_parts=32

Visualization

$ tensorboard --logdir=/path/to/model_dir --port=8081

Contact

Email: [email protected]

Other Re-implementation

WS-DAN.PyTorch

License

MIT

ws_dan's People

Contributors

tau-yihouxiang, toddwyl

ws_dan's Issues

Different claims for the paper and the code on attention regularization

Hi there,

Thanks for the contribution! After reading the code, I am a bit confused about the attention regularization part. Please correct me if I have misunderstood something.

From the code, my understanding of the center loss is that every class (label) has a center for its features, and that those same features, multiplied by a scale of 100, are also used for softmax classification. The paper, however, claims that the center loss is used for attention regularization, which assigns a center to each attention feature in the feature matrix. The center loss equation in the paper is a sum of distances over those attention features (note the distinct M in the equation).

Is there an explanation for this difference?

advisory

Hello, author. May I ask which TensorFlow version you used?
Thanks.

consult

Hello, author.
I have a question: is the data normalized when the tfrecords are generated, or when they are read? I could not find the relevant code.
The parameters of the preprocess_for_train function are documented as follows:
Args:
  image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
    [0, 1], otherwise it will be converted to tf.float32 assuming that the range
    is [0, MAX], where MAX is the largest positive representable number for
    int(8/16/32) data type (see tf.image.convert_image_dtype for details).
  height: integer
  width: integer
  bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
    where each coordinate is [0, 1) and the coordinates are arranged
    as [ymin, xmin, ymax, xmax].
  fast_mode: Optional boolean, if True avoids slower transformations (i.e.
    bi-cubic resizing, random_hue or random_contrast).
  scope: Optional scope for name_scope.
  add_image_summaries: Enable image summaries.
Returns:
  3-D float Tensor of distorted image used for training with range [-1, 1].
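
The docstring above implies a two-step rescaling: integer pixels are first mapped to [0, 1], and the returned tensor lies in [-1, 1]. The NumPy sketch below illustrates that arithmetic only. It assumes the usual TF-Slim Inception convention of subtracting 0.5 and multiplying by 2, which this page does not confirm for the repository, and the function name to_inception_range is hypothetical.

```python
import numpy as np


def to_inception_range(image_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    # Step 1: [0, 255] -> [0, 1], as tf.image.convert_image_dtype does for uint8.
    image = image_uint8.astype(np.float32) / 255.0
    # Step 2: [0, 1] -> [-1, 1] (assumed convention: subtract 0.5, multiply by 2).
    return (image - 0.5) * 2.0


pixels = np.array([0, 255], dtype=np.uint8)
print(to_inception_range(pixels))
```

Under these assumptions the darkest pixel maps to -1 and the brightest to 1, so no separate mean/std normalization step would be needed.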

Some performance issues in the programs

Hello, I found a small performance issue with tf.random_uniform in WS_DAN/nets/nasnet/nasnet_test.py. If the function is called many times inside a loop, program execution becomes less efficient, so I think the tf.random_uniform op should be created once, before the loop. There are several similar places elsewhere in the code where inputs is created this way.

Citation

Hello, where was this paper finally published? I would like to cite it, but I cannot find a BibTeX entry other than the arXiv one.

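Error when batch size is set to 4 or 8

The call np.random.choice(np.arange(0, num_parts), 1, p=part_weights) in train_sample.py raises the error "probabilities contain NaN".
With batchsize=1, 12, 16, or 32 there is no error. Why is that? Does the algorithm place a special restriction on batch size?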
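
np.random.choice raises this error whenever the p vector contains NaN, for example when weights are normalized by a sum of zero. The sketch below shows one defensive normalization that falls back to a uniform distribution. The helper name safe_part_probabilities is hypothetical, and whether a zero or invalid weight vector is the actual cause in train_sample.py is an assumption that this page does not confirm.

```python
import numpy as np


def safe_part_probabilities(part_weights):
    """Return a valid probability vector even if the weights are degenerate."""
    w = np.array(part_weights, dtype=np.float64)  # copy, so the input is untouched
    w[~np.isfinite(w)] = 0.0       # zero out NaN and +/-inf entries
    w = np.clip(w, 0.0, None)      # probabilities cannot be negative
    total = w.sum()
    if total <= 0.0:
        # Nothing usable left: fall back to a uniform distribution.
        return np.full(w.shape, 1.0 / w.size)
    return w / total


num_parts = 32
weights = np.zeros(num_parts)      # normalizing this naively would give 0 / 0 = NaN
probs = safe_part_probabilities(weights)
selected = np.random.choice(np.arange(0, num_parts), 1, p=probs)
```

This only hides the symptom; if the weights are NaN because the network itself has diverged, the training loss needs to be inspected as well.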

consult

Hello, may I ask a question?
During testing, whenever the batch size is less than 64, the memory usage is 8707M. What causes this? I don't understand.
Thank you.

FileWriter Warning

After the message "Finished training! Saving model to disk.", saving the model produces this warning: "Attempting to use a closed FileWriter. The operation will be a noop unless the FileWriter is explicitly reopened."

something wrong when converting Aircraft

$ python convert_data.py --dataset_name=Aircraft --dataset_dir=./Aircraft/Data
Traceback (most recent call last):
  File "convert_data.py", line 76, in <module>
    tf.app.run()
  File "/mnt/disk0/home/mahailong/anaconda3/envs/CenterNet/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "convert_data.py", line 65, in main
    convert_aircraft.run(FLAGS.dataset_dir)
  File "/mnt/disk0/home/mahailong/WS_DAN/datasets/convert_aircraft.py", line 196, in run
    train_dataset, test_dataset = generate_datasets(dataset_dir)
  File "/mnt/disk0/home/mahailong/WS_DAN/datasets/convert_aircraft.py", line 155, in generate_datasets
    train_info = np.loadtxt(os.path.join(data_root, 'fgvc-aircraft-2013b/data', 'images_variant_trainval.txt'), str)
  File "/mnt/disk0/home/mahailong/anaconda3/envs/CenterNet/lib/python3.6/site-packages/numpy/lib/npyio.py", line 1141, in loadtxt
    for x in read_data(_loadtxt_chunksize):
  File "/mnt/disk0/home/mahailong/anaconda3/envs/CenterNet/lib/python3.6/site-packages/numpy/lib/npyio.py", line 1065, in read_data
    % line_num)
ValueError: Wrong number of columns at line 1235
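
np.loadtxt splits every line on whitespace and requires the same number of columns in each row, so a label that itself contains a space produces this ValueError. A possible workaround, assuming each line of images_variant_trainval.txt is an image id followed by a variant name that may contain spaces, is to split on the first run of whitespace only. The function read_id_label_pairs is a hypothetical helper, not part of convert_aircraft.py, and the sample lines are made up for illustration.

```python
import io


def read_id_label_pairs(lines):
    """Parse 'image_id label' lines, keeping spaces inside the label."""
    pairs = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split once on whitespace: everything after the id is the label.
        image_id, label = line.split(None, 1)
        pairs.append((image_id, label))
    return pairs


# Made-up sample: the second label contains a space, which breaks np.loadtxt.
sample = io.StringIO("0000001 ModelA\n0000002 Model B\n")
pairs = read_id_label_pairs(sample)
```

With a real file, an open file object can be passed in place of the StringIO sample, since the function only iterates over lines.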

advisory

Hello, author.
Did you use TensorFlow 1.12, or another version after 1.9?
Thank you!

Embedding dimension

The depth of the feature_maps, i.e. the depth of Mixed_6e from Inception_v3, is 768, and by default 32 attention_maps are generated. After the BAP module the width and height of the tensor are reduced away, leaving a tensor of shape (N, 32, 768), right?
It is then normalized and reshaped to (N, 32*768) as the embeddings. What confuses me is this: isn't that a bit too large for an embedding? I have read other papers on metric learning, and most of them do not generate an embedding larger than 512.
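
The shape arithmetic in this question can be checked with a small NumPy sketch. It assumes bilinear attention pooling averages the product of each attention map with the feature maps over the spatial positions, and it uses plain L2 normalization as a stand-in for whatever normalization the repository applies. The function name and the tiny 4x4 spatial size are illustrative only.

```python
import numpy as np


def bilinear_attention_pooling(feature_maps, attention_maps):
    """Pool (N, H, W, C) features with (N, H, W, M) attentions into (N, M * C)."""
    n, h, w, c = feature_maps.shape
    m = attention_maps.shape[-1]
    # For each attention map, average attention * feature over all H * W positions.
    parts = np.einsum("nhwm,nhwc->nmc", attention_maps, feature_maps) / (h * w)
    embedding = parts.reshape(n, m * c)  # flatten (N, M, C) -> (N, M * C)
    # L2-normalize each embedding (an assumed normalization, see the lead-in).
    norm = np.linalg.norm(embedding, axis=1, keepdims=True)
    return embedding / np.maximum(norm, 1e-12)


rng = np.random.default_rng(0)
features = rng.random((2, 4, 4, 768))    # depth 768, as stated for Mixed_6e
attentions = rng.random((2, 4, 4, 32))   # 32 attention maps (num_parts=32)
embedding = bilinear_attention_pooling(features, attentions)
print(embedding.shape)  # (2, 24576)
```

With 32 attention maps and 768 channels, each embedding therefore has 32 * 768 = 24576 dimensions, which is the size the question refers to.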
