DADA: Differentiable Automatic Data Augmentation

Contact us at [email protected] or [email protected].

Introduction

This is the official code for our ECCV 2020 paper DADA: Differentiable Automatic Data Augmentation. DADA is at least one order of magnitude faster than state-of-the-art data augmentation (DA) policy search algorithms while achieving comparable accuracy. Our training code is based on fast-autoaugment.

License

The project is free for academic research purposes only; commercial use requires authorization. For commercial permission, please contact [email protected].

Citation

If you use our code or models, please consider citing our ECCV 2020 paper DADA: Differentiable Automatic Data Augmentation [arXiv] [ECCV].

@inproceedings{li2020dada,
  author    = {Yonggang Li and
               Guosheng Hu and
               Yongtao Wang and
               Timothy M. Hospedales and
               Neil Martin Robertson and
               Yongxin Yang},
  title     = {{DADA:} Differentiable Automatic Data Augmentation},
  booktitle = {The European Conference on Computer Vision (ECCV)},
  year      = {2020}
}

Model

We provide the checkpoints on BaiduDrive (extraction code: sgap) and GoogleDrive.

CIFAR-10

Search: 0.1 GPU hours, Wide-ResNet-40-2 on Reduced CIFAR-10

Test error (%):

Dataset   Model                   Baseline  Cutout  AA   PBA  Fast AA  DADA
CIFAR-10  Wide-ResNet-40-2        5.3       4.1     3.7  -    3.6      3.6
CIFAR-10  Wide-ResNet-28-10       3.9       3.1     2.6  2.6  2.7      2.7
CIFAR-10  Shake-Shake(26 2x32d)   3.6       3.0     2.5  2.5  2.7      2.7
CIFAR-10  Shake-Shake(26 2x96d)   2.9       2.6     2.0  2.0  2.0      2.0
CIFAR-10  Shake-Shake(26 2x112d)  2.8       2.6     1.9  2.0  2.0      2.0
CIFAR-10  PyramidNet+ShakeDrop    2.7       2.3     1.5  1.5  1.8      1.7

CIFAR-100

Search: 0.2 GPU hours, Wide-ResNet-40-2 on Reduced CIFAR-100

Test error (%):

Dataset    Model                   Baseline  Cutout  AA    PBA   Fast AA  DADA
CIFAR-100  Wide-ResNet-40-2        26.0      25.2    20.7  -     20.7     20.9
CIFAR-100  Wide-ResNet-28-10       18.8      18.4    17.1  16.7  17.3     17.5
CIFAR-100  Shake-Shake(26 2x96d)   17.1      16.0    14.3  15.3  14.9     15.3
CIFAR-100  PyramidNet+ShakeDrop    14.0      12.2    10.7  10.9  11.9     11.2

SVHN

Search: 0.1 GPU hours, Wide-ResNet-28-10 on Reduced SVHN

Test error (%):

Dataset  Model                   Baseline  Cutout  AA   PBA  Fast AA  DADA
SVHN     Wide-ResNet-28-10       1.5       1.3     1.1  1.2  1.1      1.2
SVHN     Shake-Shake(26 2x96d)   1.4       1.2     1.0  1.1  -        1.1

ImageNet

Search: 1.3 GPU hours, ResNet-50 on Reduced ImageNet

Top-1 / top-5 test error (%):

Dataset   Baseline    AA           Fast AA     OHL AA      DADA
ImageNet  23.7 / 6.9  ~22.4 / 6.2  22.4 / 6.3  21.1 / 5.7  22.5 / 6.5

Installation

Environment

  1. Ubuntu 16.04 LTS
  2. CUDA 10.0
  3. PyTorch 1.2.0
  4. TorchVision 0.4.0

Install

a. Create a conda virtual environment and activate it.

conda create -n dada-env python=3.6.10
source activate dada-env # or conda activate dada-env

b. Install PyTorch and torchvision following the official instructions, e.g.,

conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit==10.0

c. Install the other Python packages required by DADA and fast-autoaugment, e.g.,

# for training and inference
pip install -r fast-autoaugment/requirements.txt

# for searching
pip install -r requirements.txt

Getting Started

Prepare Datasets

With the default settings, all datasets except ImageNet are downloaded automatically. Put the data in ./data with the following layout (covering CIFAR-10, CIFAR-100, SVHN, and ImageNet):

# CIFAR-10
./data/cifar-10-python.tar.gz

# CIFAR-100
./data/cifar-100-python.tar.gz

# SVHN
./data/train_32x32.mat
./data/extra_32x32.mat
./data/test_32x32.mat

# ImageNet
./data/imagenet-pytorch/
./data/imagenet-pytorch/meta.bin
./data/imagenet-pytorch/train
./data/imagenet-pytorch/val

Inference

Download the provided model .pth files (see the Model section above) and put them in ./fast-autoaugment/weights:

cd fast-autoaugment
sh inference.sh

For example, you can test the provided wresnet40x2 model trained on CIFAR-10 as below:

# TITAN
GPUS=0
SEED=0
DATASET=cifar10
CONF=confs/wresnet40x2_cifar10_b512_test.yaml
GENOTYPE=CIFAR10
SAVE=weights/`basename ${CONF} .yaml`_${GENOTYPE}_${DATASET}_${SEED}/test.pth
CUDA_VISIBLE_DEVICES=${GPUS} python FastAutoAugment/train.py -c ${CONF} --dataset ${DATASET} --genotype ${GENOTYPE} --save ${SAVE} --seed ${SEED} --only-eval --batch 32
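The commands above build the checkpoint path with a shell backtick expression, so it is worth checking that the downloaded .pth sits where SAVE points. For illustration only, here is a small Python equivalent of the SAVE line; checkpoint_path is a hypothetical helper, not part of the repository.

```python
import os


def checkpoint_path(conf, genotype, dataset, seed, filename="test.pth"):
    """Mirror SAVE=weights/`basename ${CONF} .yaml`_${GENOTYPE}_${DATASET}_${SEED}/test.pth."""
    stem = os.path.basename(conf)       # drop the directory, like `basename`
    if stem.endswith(".yaml"):
        stem = stem[: -len(".yaml")]    # `basename X .yaml` also strips this suffix
    return f"weights/{stem}_{genotype}_{dataset}_{seed}/{filename}"
```

With the values in the example (CONF=confs/wresnet40x2_cifar10_b512_test.yaml, GENOTYPE=CIFAR10, DATASET=cifar10, SEED=0) this gives weights/wresnet40x2_cifar10_b512_test_CIFAR10_cifar10_0/test.pth, relative to the fast-autoaugment directory.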

Train

A training script covering most experiments in our paper is provided:

cd fast-autoaugment
sh train.sh

For example, you can train a wresnet40x2 model on CIFAR-10 as below:

# TITAN
GPUS=0
SEED=0
DATASET=cifar10
CONF=confs/wresnet40x2_cifar10_b512_test.yaml
GENOTYPE=CIFAR10
SAVE=weights/`basename ${CONF} .yaml`_${GENOTYPE}_${DATASET}_${SEED}/test.pth
CUDA_VISIBLE_DEVICES=${GPUS} python FastAutoAugment/train.py -c ${CONF} --dataset ${DATASET} --genotype ${GENOTYPE} --save ${SAVE} --seed ${SEED}

Search

A search script is provided for CIFAR-10, CIFAR-100, SVHN, and ImageNet:

cd search_relax
sh train_paper.sh

For example, you can search for a DA policy on the reduced CIFAR-10 dataset with a wresnet40-2 model as follows:

# you can change the hyper-parameters as below:
GPU=0
DATASET=reduced_cifar10
MODEL=wresnet40_2
EPOCH=20
BATCH=128
LR=0.1
WD=0.0002
AWD=0.0
ALR=0.005
CUTOUT=16
TEMPERATE=0.5
SAVE=CIFAR10
python train_search_paper.py --unrolled --report_freq 1 --num_workers 0 --epoch ${EPOCH} --batch_size ${BATCH} --learning_rate ${LR} --dataset ${DATASET} --model_name ${MODEL} --save ${SAVE} --gpu ${GPU} --arch_weight_decay ${AWD} --arch_learning_rate ${ALR} --weight_decay ${WD} --cutout --cutout_length ${CUTOUT} --temperature ${TEMPERATE}
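The search takes a temperature (TEMPERATE=0.5, passed as --temperature). As we understand it, a temperature of this kind relaxes the discrete decision of whether each operation is applied, so that the decision becomes differentiable with respect to its probability. The plain-Python sketch below shows a standard relaxed (Concrete) Bernoulli sample; it illustrates the general technique under that assumption and is not the code in train_search_paper.py (relaxed_bernoulli is a hypothetical name).

```python
import math
import random


def relaxed_bernoulli(p, temperature, u=None):
    """Relaxed (Concrete) Bernoulli sample in [0, 1].

    p: probability of applying the operation, 0 < p < 1.
    temperature: relaxation temperature; lower values push samples toward 0 or 1.
    u: optional uniform noise in (0, 1), exposed for reproducibility.
    """
    if u is None:
        u = random.random()
        while u == 0.0:  # random() can return 0.0, and log(0) is undefined
            u = random.random()
    logit = math.log(p) - math.log1p(-p)   # log(p / (1 - p))
    noise = math.log(u) - math.log1p(-u)   # logistic noise
    z = (logit + noise) / temperature
    # numerically stable sigmoid(z)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)
```

Lowering the temperature makes samples closer to a hard 0/1 decision, at the cost of higher-variance gradients.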

The code for DADA with Gumbel-Softmax is also included in this repository:

cd search_gumbel
sh train_paper.sh
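Gumbel-Softmax turns a sample from a categorical distribution (for example, choosing one candidate sub-policy) into a differentiable, approximately one-hot vector. The plain-Python sketch below shows the generic technique only; it assumes the learned selection weights act as logits and is not the code in search_gumbel (gumbel_softmax here is our own illustrative function; PyTorch also ships torch.nn.functional.gumbel_softmax).

```python
import math
import random


def gumbel_softmax(logits, temperature, gumbels=None):
    """Relaxed one-hot sample over len(logits) categories.

    logits: unnormalized log-probabilities of each category.
    temperature: lower values give samples closer to one-hot.
    gumbels: optional Gumbel(0, 1) noise, exposed for reproducibility.
    """
    if gumbels is None:
        gumbels = []
        for _ in logits:
            u = random.random()
            while u == 0.0:  # avoid log(0)
                u = random.random()
            gumbels.append(-math.log(-math.log(u)))
    scores = [(logit + g) / temperature for logit, g in zip(logits, gumbels)]
    m = max(scores)  # subtract the max before exponentiating, for stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

The output always sums to 1; as the temperature decreases, most of the mass concentrates on the category with the largest perturbed score.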

Found Policy

We release the data augmentation policies found by DADA on CIFAR-10, CIFAR-100, SVHN, and ImageNet below; each table entry is (operation, probability, magnitude). The original policies are included in fast-autoaugment/FastAutoAugment/genotype.py, where you can find the genotypes used in our paper:

vim fast-autoaugment/FastAutoAugment/genotype.py
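In AutoAugment-style pipelines such as fast-autoaugment, a policy like the ones below is typically applied by picking one sub-policy at random per image and then applying each of its operations with its probability at its magnitude. The sketch assumes that convention and that magnitudes are normalized to [0, 1]; apply_policy and the ops mapping are hypothetical names, not the repository's API.

```python
import random

# Two rows of the CIFAR-10 table: each op is (name, probability, magnitude).
EXAMPLE_POLICY = [
    [("TranslateX", 0.52, 0.58), ("Rotate", 0.57, 0.53)],
    [("ShearX", 0.50, 0.46), ("Sharpness", 0.50, 0.54)],
]


def apply_policy(img, policy, ops, rng=random):
    """Apply one randomly chosen sub-policy to img.

    ops maps an operation name to a function fn(img, magnitude) -> img.
    Each operation in the chosen sub-policy fires independently with its probability.
    """
    sub_policy = rng.choice(policy)
    for name, prob, magnitude in sub_policy:
        if rng.random() < prob:
            img = ops[name](img, magnitude)
    return img
```

In practice, ops would wrap real image transforms (for example, PIL-based ones) that map a magnitude in [0, 1] to an operation-specific range; that mapping is not shown here.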

CIFAR-10

Sub-policy Operation 1 Operation 2
sub-policy 0 (TranslateX, 0.52, 0.58) (Rotate, 0.57, 0.53)
sub-policy 1 (ShearX, 0.50, 0.46) (Sharpness, 0.50, 0.54)
sub-policy 2 (Brightness, 0.56, 0.56) (Sharpness, 0.52, 0.47)
sub-policy 3 (ShearY, 0.62, 0.48) (Brightness, 0.47, 0.46)
sub-policy 4 (ShearX, 0.44, 0.58) (TranslateY, 0.40, 0.51)
sub-policy 5 (Rotate, 0.40, 0.52) (Equalize, 0.38, 0.36)
sub-policy 6 (AutoContrast, 0.44, 0.48) (Cutout, 0.49, 0.50)
sub-policy 7 (AutoContrast, 0.56, 0.48) (Color, 0.45, 0.61)
sub-policy 8 (Rotate, 0.42, 0.64) (AutoContrast, 0.60, 0.58)
sub-policy 9 (Invert, 0.40, 0.50) (Color, 0.50, 0.44)
sub-policy 10 (Posterize, 0.56, 0.50) (Brightness, 0.53, 0.48)
sub-policy 11 (TranslateY, 0.42, 0.51) (AutoContrast, 0.38, 0.57)
sub-policy 12 (ShearX, 0.38, 0.50) (Contrast, 0.49, 0.52)
sub-policy 13 (ShearY, 0.54, 0.60) (Rotate, 0.31, 0.56)
sub-policy 14 (Posterize, 0.42, 0.50) (Color, 0.45, 0.56)
sub-policy 15 (TranslateX, 0.41, 0.45) (TranslateY, 0.36, 0.48)
sub-policy 16 (TranslateX, 0.57, 0.50) (Brightness, 0.54, 0.48)
sub-policy 17 (TranslateX, 0.53, 0.51) (Cutout, 0.69, 0.49)
sub-policy 18 (ShearX, 0.46, 0.44) (Invert, 0.42, 0.40)
sub-policy 19 (Rotate, 0.50, 0.42) (Contrast, 0.49, 0.42)
sub-policy 20 (Rotate, 0.43, 0.47) (Solarize, 0.50, 0.42)
sub-policy 21 (TranslateY, 0.74, 0.51) (Color, 0.39, 0.57)
sub-policy 22 (Equalize, 0.42, 0.53) (Sharpness, 0.40, 0.43)
sub-policy 23 (Solarize, 0.73, 0.42) (Cutout, 0.51, 0.46)
sub-policy 24 (ShearX, 0.58, 0.56) (TranslateX, 0.48, 0.49)

CIFAR-100

Sub-policy Operation 1 Operation 2
sub-policy 0 (ShearY, 0.56, 0.28) (Sharpness, 0.49, 0.22)
sub-policy 1 (Rotate, 0.36, 0.19) (Contrast, 0.56, 0.31)
sub-policy 2 (TranslateY, 0.00, 0.41) (Brightness, 0.47, 0.52)
sub-policy 3 (AutoContrast, 0.80, 0.44) (Color, 0.44, 0.37)
sub-policy 4 (Color, 0.94, 0.25) (Brightness, 0.68, 0.45)
sub-policy 5 (TranslateY, 0.63, 0.40) (Equalize, 0.82, 0.30)
sub-policy 6 (Equalize, 0.46, 0.71) (Posterize, 0.50, 0.72)
sub-policy 7 (Color, 0.52, 0.48) (Sharpness, 0.19, 0.40)
sub-policy 8 (Sharpness, 0.42, 0.38) (Cutout, 0.55, 0.24)
sub-policy 9 (ShearX, 0.74, 0.56) (TranslateX, 0.48, 0.67)
sub-policy 10 (Invert, 0.36, 0.59) (Brightness, 0.50, 0.23)
sub-policy 11 (TranslateX, 0.36, 0.36) (Posterize, 0.80, 0.32)
sub-policy 12 (TranslateX, 0.48, 0.36) (Cutout, 0.64, 0.67)
sub-policy 13 (Posterize, 0.31, 0.04) (Contrast, 1.00, 0.08)
sub-policy 14 (Contrast, 0.42, 0.26) (Cutout, 0.00, 0.44)
sub-policy 15 (Equalize, 0.16, 0.69) (Brightness, 0.73, 0.18)
sub-policy 16 (Contrast, 0.45, 0.34) (Sharpness, 0.59, 0.28)
sub-policy 17 (TranslateX, 0.13, 0.54) (Invert, 0.33, 0.48)
sub-policy 18 (Rotate, 0.50, 0.58) (Posterize, 1.00, 0.74)
sub-policy 19 (TranslateX, 0.51, 0.43) (Rotate, 0.46, 0.48)
sub-policy 20 (ShearX, 0.58, 0.46) (TranslateY, 0.33, 0.31)
sub-policy 21 (Rotate, 1.00, 0.00) (Equalize, 0.51, 0.37)
sub-policy 22 (AutoContrast, 0.26, 0.57) (Cutout, 0.34, 0.35)
sub-policy 23 (ShearX, 0.56, 0.55) (Color, 0.50, 0.50)
sub-policy 24 (ShearY, 0.46, 0.09) (Posterize, 0.55, 0.34)

SVHN

Sub-policy Operation 1 Operation 2
sub-policy 0 (Solarize, 0.61, 0.53) (Brightness, 0.64, 0.50)
sub-policy 1 (ShearY, 0.56, 0.54) (Sharpness, 0.67, 0.50)
sub-policy 2 (AutoContrast, 0.64, 0.50) (Posterize, 0.49, 0.42)
sub-policy 3 (Invert, 0.43, 0.62) (Equalize, 0.30, 0.53)
sub-policy 4 (Contrast, 0.49, 0.55) (Color, 0.51, 0.58)
sub-policy 5 (ShearX, 0.58, 0.50) (Brightness, 0.56, 0.54)
sub-policy 6 (Rotate, 0.43, 0.50) (Contrast, 0.47, 0.42)
sub-policy 7 (Brightness, 0.51, 0.57) (Cutout, 0.48, 0.50)
sub-policy 8 (TranslateY, 0.65, 0.46) (Rotate, 0.43, 0.46)
sub-policy 9 (ShearY, 0.41, 0.43) (Contrast, 0.48, 0.49)
sub-policy 10 (ShearY, 0.52, 0.37) (Brightness, 0.43, 0.37)
sub-policy 11 (ShearY, 0.26, 0.49) (Posterize, 0.52, 0.56)
sub-policy 12 (TranslateX, 0.67, 0.38) (TranslateY, 0.45, 0.42)
sub-policy 13 (Posterize, 0.64, 0.43) (Sharpness, 0.63, 0.54)
sub-policy 14 (Rotate, 0.47, 0.50) (Sharpness, 0.40, 0.45)
sub-policy 15 (ShearX, 0.47, 0.46) (Cutout, 0.58, 0.50)
sub-policy 16 (Rotate, 0.58, 0.53) (Solarize, 0.41, 0.43)
sub-policy 17 (Color, 0.37, 0.44) (Brightness, 0.52, 0.41)
sub-policy 18 (TranslateX, 0.49, 0.47) (Posterize, 0.49, 0.52)
sub-policy 19 (TranslateY, 0.50, 0.49) (Solarize, 0.50, 0.42)
sub-policy 20 (TranslateY, 0.27, 0.50) (Invert, 0.56, 0.29)
sub-policy 21 (ShearY, 0.64, 0.57) (Rotate, 0.49, 0.57)
sub-policy 22 (Invert, 0.49, 0.55) (Contrast, 0.41, 0.50)
sub-policy 23 (ShearX, 0.57, 0.49) (Color, 0.60, 0.50)
sub-policy 24 (Rotate, 0.54, 0.53) (Equalize, 0.52, 0.50)

ImageNet

Sub-policy Operation 1 Operation 2
sub-policy 0 (TranslateY, 0.85, 0.64) (Contrast, 0.70, 0.47)
sub-policy 1 (ShearX, 0.69, 0.64) (Brightness, 0.58, 0.46)
sub-policy 2 (Solarize, 0.33, 0.53) (Contrast, 0.36, 0.40)
sub-policy 3 (ShearY, 0.54, 0.81) (Color, 0.65, 0.67)
sub-policy 4 (Rotate, 0.52, 0.28) (Invert, 0.55, 0.46)
sub-policy 5 (ShearY, 0.76, 0.55) (AutoContrast, 0.64, 0.46)
sub-policy 6 (TranslateX, 0.32, 0.67) (Sharpness, 0.45, 0.61)
sub-policy 7 (Brightness, 0.28, 0.54) (Cutout, 0.29, 0.53)
sub-policy 8 (TranslateY, 0.26, 0.39) (Brightness, 0.30, 0.57)
sub-policy 9 (ShearX, 0.46, 0.62) (Rotate, 0.51, 0.59)
sub-policy 10 (TranslateY, 0.63, 0.38) (Invert, 0.40, 0.33)
sub-policy 11 (TranslateY, 0.49, 0.32) (Equalize, 0.43, 0.26)
sub-policy 12 (TranslateX, 0.31, 0.46) (AutoContrast, 0.40, 0.00)
sub-policy 13 (ShearY, 0.57, 0.35) (Equalize, 0.45, 0.16)
sub-policy 14 (Solarize, 0.78, 0.61) (Brightness, 0.57, 0.80)
sub-policy 15 (Color, 0.75, 0.40) (Cutout, 0.54, 0.47)
sub-policy 16 (ShearX, 0.51, 0.67) (Cutout, 0.37, 0.45)
sub-policy 17 (TranslateX, 0.68, 0.39) (Rotate, 0.47, 0.16)
sub-policy 18 (Rotate, 0.64, 0.55) (Sharpness, 0.66, 0.80)
sub-policy 19 (TranslateY, 0.47, 0.75) (Sharpness, 0.64, 0.52)
sub-policy 20 (AutoContrast, 0.29, 0.54) (Posterize, 0.35, 0.70)
sub-policy 21 (Invert, 0.55, 0.49) (Equalize, 0.44, 0.76)
sub-policy 22 (TranslateX, 0.86, 0.29) (Contrast, 0.41, 0.60)
sub-policy 23 (Invert, 0.28, 0.45) (Posterize, 0.42, 0.34)
sub-policy 24 (Posterize, 0.15, 0.33) (Color, 0.50, 0.59)

dada's Issues

Why you put two "elif dataset == 'reduced'":

Hello, thank you for your excellent work. I noticed that there are two elif dataset == 'reduced_imagenet': branches in your dataset.py, which means the second one can never be reached. Is this a mistake?

About DifferentiableAugment

Hi,

I don't understand the DifferentiableAugment class in the implementation. What does it do? Does it just subtract and add the magnitude to the images? Why did you adopt this approach? Is there a specific reason for it?

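import torch.nn as nn

class DifferentiableAugment(nn.Module):
    def __init__(self, sub_policy):
        super(DifferentiableAugment, self).__init__()
        self.sub_policy = sub_policy

    def forward(self, origin_images, probability_b, magnitude):
        images = origin_images
        adds = 0
        for i in range(len(self.sub_policy)):
            # only operations sampled as applied (probability_b[i] != 0) contribute
            if probability_b[i].item() != 0.0:
                images = images - magnitude[i]
                adds = adds + magnitude[i]
        # the subtraction and addition cancel numerically; after detach(),
        # gradients flow only through adds, i.e. to the applied magnitudes
        images = images.detach() + adds
        return images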
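One way to read this snippet, offered as our interpretation rather than the authors' answer: write $b_i \in \{0, 1\}$ for probability_b[i], $m_i$ for magnitude[i], $x$ for origin_images, and $\operatorname{sg}(\cdot)$ for stop-gradient (.detach() in the code). For every pixel $(j, k)$ the returned tensor then satisfies:

```latex
\hat{x} = \operatorname{sg}\!\Big(x - \sum_i b_i m_i\Big) + \sum_i b_i m_i,
\qquad
\hat{x} = x \ \text{(forward value)},
\qquad
\frac{\partial \hat{x}_{jk}}{\partial m_i} = b_i,
\qquad
\frac{\partial \hat{x}_{jk}}{\partial x} = 0 .
```

So the forward value is unchanged, while each pixel receives a gradient of 1 with respect to every applied magnitude, which matches the $\partial x_{ij} / \partial m = 1$ approximation discussed in the Equation 10 issue.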

Can't reproduce the augmentation policy (genotype) with the search code

Thanks for the great work and code!
I want to reproduce the augmentation policy (called genotype in your code) with the provided search code. I followed the description in README.md and searched for an augmentation policy on reduced ImageNet with ResNet-50. However, the policy I get differs from the one in genotype.py, so I want to know whether I did something wrong in reproducing the result. Here are two possible causes I suspect may affect the search results:
1. In the search code, the default random seed in train_search_paper.py is 2. Is this the same random seed you used to obtain the final result?
2. During search, the augmentations are inserted after ColorJitter, but in the training code the augmentation policy is inserted after RandomHorizontalFlip and before ColorJitter (line 95 in fast-autoaugment/FastAutoAugment/data.py), so the order is inconsistent between training and search.
Do these two factors affect the search? Or are there other details of the search process that I missed? I look forward to your reply, thanks.

Explanation of Equation 10

Hi,

Thank you for your great work. I have a question about Equation 10 in your paper, where you approximate the gradient of the transformed image with respect to the magnitude of the transformation, ∂x_ij/∂m, as 1. I don't understand the reasoning behind this, since pixel values do not always increase (a +1 gradient) when the magnitude increases (e.g., for ShearX and ShearY).

Thanks.

dataloader num_workers=0

Hi, it seems that in DADA the dataloader num_workers has to be 0 to avoid a mismatch between the gradients and the actual augmentation parameters.
But with num_workers == 0, DADA's speed advantage over PBA cannot be demonstrated. Do you have a trick to deal with this issue?

Imagenet

Hi. I am unable to reproduce the results with search_relax: I only got 65.23% top-1 accuracy on ImageNet.

Does it support DataParallel and multiple GPUs?

As the title says, I tried using DataParallel and commenting out set_device for multi-GPU training, but it does not seem to work.

The script keeps showing
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
RuntimeError: CUDA error: initialization error

Thanks!

Reduced ImageNet split

Hi,

Thank you for sharing this great work! Could you share the ImageNet split you used in the experiment? Thank you!

Extract the found policy

Hi,
After checking and testing the code, I am looking for a way to extract the found policy.

Could you also explain how to reuse the found policy?

Two Questions about some details of the paper

Hello! I have some minor questions about certain details of training the network itself.

  1. How are the results in the paper obtained?

The paper says:
"Following [3, 10, 15], we search the DA policies on the reduced datasets and evaluate on the full datasets. Furthermore, we split half of the reduced datasets as training set, and the remaining half as validation set for the data augmentation search."
So what is the workflow for training a neural network with DADA? Do we search on the dataset using train_search_paper and then transfer the policies to training? If so, where is the code that transfers the searched policies to training? If not, how is the validation data used? It seems that only half of the data is used to train the network (train_portion = 0.5).

  2. Was any other sub-policy depth or sub-policy count considered in the search?
  3. Why is ColorJitter used in conjunction with the DADA policy? The sub-policies seem able to include it anyway.
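The quoted passage describes splitting each reduced dataset in half: one half trains the network weights and the other half serves as the validation set for updating the augmentation parameters. A minimal sketch of such a split, assuming train_portion plays the role described in the issue; split_indices is a hypothetical helper, and the repository may or may not shuffle before splitting.

```python
import random


def split_indices(num_train, train_portion=0.5, seed=None):
    """Split range(num_train) into (train_idx, valid_idx).

    train_idx: used to update the network weights.
    valid_idx: used to update the augmentation parameters.
    If seed is given, the indices are shuffled before splitting.
    """
    indices = list(range(num_train))
    if seed is not None:
        random.Random(seed).shuffle(indices)
    split = int(train_portion * num_train)
    return indices[:split], indices[split:]
```

With train_portion = 0.5, each half of the reduced dataset sees only one kind of update, which is why the network in the search phase is trained on half of the data.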

Final DA policy: use all or 25?

Hi,

After you find the optimal DA policies, do you use all of them (e.g., 105 sub-policies for CIFAR-10) or choose the top 25, as shown in the README?

Thanks!
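The number 105 in this question is consistent with pairing the 15 operations that appear in the policy tables: there are C(15, 2) = 105 unordered pairs of distinct operations. The sketch enumerates those candidates and keeps the 25 with the largest weights. Ranking by learned weight is an assumption made for illustration (it is exactly what the issue asks about), and top_k is a hypothetical helper.

```python
from itertools import combinations

# The 15 operations that appear in the policy tables above.
OPS = ["ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
       "AutoContrast", "Invert", "Equalize", "Solarize", "Posterize",
       "Contrast", "Color", "Brightness", "Sharpness", "Cutout"]

# All unordered pairs of distinct operations: C(15, 2) = 105 candidate sub-policies.
CANDIDATES = list(combinations(OPS, 2))


def top_k(candidates, weights, k=25):
    """Return the k candidates with the largest weights, highest weight first."""
    order = sorted(range(len(candidates)), key=lambda i: weights[i], reverse=True)
    return [candidates[i] for i in order[:k]]
```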

How are the weights, probabilities, and magnitudes added to the forward computation so that their gradients can be calculated?

I have the following two questions. Is the code in No. 1 used to calculate the gradients of the magnitude and probability during backward? And is the second code block used to calculate the gradient of ops_weights?

1.
def forward(self, origin_images, trans_images, probability, probability_index, magnitude):
    # index of the apply/skip pattern that was actually sampled (bit i set => op i applied)
    index = sum(p_i.item() << i for i, p_i in enumerate(probability_index))
    com_image = 0
    images = origin_images
    adds = 0

    for selection in range(2**len(self.sub_policy)):
        trans_probability = 1
        for i in range(len(self.sub_policy)):
            if selection & (1 << i):
                trans_probability = trans_probability * probability[i]
                if selection == index:
                    images = images - magnitude[i]
                    adds = adds + magnitude[i]
            else:
                trans_probability = trans_probability * (1 - probability[i])
        if selection == index:
            images = images.detach() + adds
            com_image = com_image + trans_probability * images
        else:
            com_image = com_image + trans_probability
    return com_image

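2.
def forward(self, origin_images, trans_images_list, probabilities, probabilities_index, magnitudes, weights, weights_index):
    total = 0
    for i, (p, p_i, m, w, op, trans_images) in enumerate(
            zip(probabilities, probabilities_index, magnitudes, weights, self._ops, trans_images_list)):
        if weights_index.item() == i:
            # only the sampled operation is evaluated on the images
            total = total + w * op(origin_images, trans_images, p, p_i, m)
        else:
            # the other operations contribute their bare weight
            total = total + w
    return total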
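In the first forward() above, the loop over selection enumerates all 2^k apply/skip patterns of a k-operation sub-policy, and trans_probability is the product of p_i or (1 - p_i) for each bit. The sketch isolates that probability computation in plain Python so it can be checked on its own; selection_probabilities is a hypothetical name.

```python
def selection_probabilities(probs):
    """Probability of every apply/skip pattern of one sub-policy.

    probs[i] is the probability that operation i is applied. Bit i of
    `selection` is 1 when operation i is applied, mirroring the
    `selection & (1 << i)` test in the forward() quoted above.
    """
    result = {}
    for selection in range(2 ** len(probs)):
        p = 1.0
        for i, p_i in enumerate(probs):
            p *= p_i if selection & (1 << i) else 1.0 - p_i
        result[selection] = p
    return result
```

For a two-operation sub-policy with probabilities 0.52 and 0.57, the four patterns have probabilities 0.48 * 0.43, 0.52 * 0.43, 0.48 * 0.57, and 0.52 * 0.57, which sum to 1.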

IDADA

Could you give me the IDADA source code?

Using the one-step unrolled validation loss

When using the one-step unrolled validation loss, i.e., setting unrolled = True, new_model is needed to obtain the virtual gradients, and the memory usage is too high. So I used unrolled=False directly, but then only the magnitude is updated, while the weights and probabilities are not. Is there a problem with the code here?

def _backward_step(self, input_valid, target_valid):
    loss = self.model._loss(input_valid, target_valid)
    loss.backward()

Low accuracy while searching

set -x

# cifar100

GPU=1
DATASET=cifar100
MODEL=resnet50
EPOCH=20
BATCH=128
LR=0.1
WD=0.0002
AWD=0.0
ALR=0.005
CUTOUT=16
TEMPERATE=0.5

which python
python train_search_paper.py --unrolled --report_freq 1 --num_workers 0 --epoch ${EPOCH} --batch_size ${BATCH} --learning_rate ${LR} --dataset ${DATASET} --model_name ${MODEL} --gpu ${GPU} --arch_weight_decay ${AWD} --arch_learning_rate ${ALR} --weight_decay ${WD} --cutout --cutout_length ${CUTOUT} --temperature ${TEMPERATE}

Hello, I used the ResNet-50 network to search for the augmentation policy. During the search, I noticed that both the training and validation accuracy are very low.

04/27 03:59:04 PM valid 187 2.398671e+00 38.285406 70.545213
04/27 03:59:04 PM valid 188 2.397762e+00 38.289517 70.568783
04/27 03:59:04 PM valid 189 2.397031e+00 38.297697 70.575658
04/27 03:59:04 PM valid 190 2.395328e+00 38.326243 70.590641
04/27 03:59:04 PM valid 191 2.396694e+00 38.309733 70.576986
04/27 03:59:04 PM valid 192 2.396227e+00 38.337921 70.575615
04/27 03:59:05 PM valid 193 2.396536e+00 38.321521 70.574259
04/27 03:59:05 PM valid 194 2.396318e+00 38.325321 70.584936
04/27 03:59:05 PM valid 195 2.395298e+00 38.336000 70.588000
04/27 03:59:05 PM valid_acc 38.336000

Is this OK?

Ops weights and probabilities not getting updated

As a part of my personal research, I am working on studying various automated data augmentation techniques.
While trying to reproduce your results, I found that the probabilities and op_weights are not being updated. Only the magnitude changes over the epochs; the probabilities and ops_weights remain constant throughout the run at 0.5 and 0.0095, respectively.
Could you kindly help me fix this issue?

Thanking you.

Does the search-phase training loss converge?

I tried to search for a reduced-ImageNet policy with the original ImageNet search script, but found that it runs for only 20 epochs.
It seems that the training loss does not converge within 20 epochs.
If the training loss does not converge, how can the performance of the data augmentation policy be validated?
Also, could you share the final training loss, validation accuracy, and training accuracy of the search phase?

A question about the gradients of sampling

Hello, thank you for your great work. I have a question about how you update the sampling parameters (arch parameters). You update them with the validation set, but you do not augment the validation set with the sampled augmentation. That means the gradient of the loss with respect to the sampling parameters through the validation data is None. How, then, do you update the sampling parameters using the validation set?

No genotypes module

Dear Authors,

I followed the instructions to install the required modules, but when I run train.py in search_relax I get an error saying there is no module named genotypes. How can I install this module?

Best

How to get NetworkCIFAR?

Hi!
I want to run search_relax/train.py, but there is no model/NetworkCIFAR.

I don't know how to get it. Could you upload the NetworkCIFAR.py file?
Thanks!

Error in search_gumbel

Hi,

I tried running the Gumbel-Softmax model with the following parameters:

GPU=0
DATASET=cifar10
MODEL=wresnet40_2
EPOCH=200
BATCH=128
LR=0.1
WD=0.0002
AWD=0.0
ALR=0.001
CUTOUT=16
SAVE=CIFAR10

python train_search_paper.py --unrolled --report_freq 1 --num_workers 0 --epoch ${EPOCH} --batch_size ${BATCH} --learning_rate ${LR} --dataset ${DATASET} --model_name ${MODEL} --save ${SAVE} --gpu ${GPU} --arch_weight_decay ${AWD} --arch_learning_rate ${ALR} --cutout --cutout_length ${CUTOUT}

and I am getting the following error:

Traceback (most recent call last):
  File "train_search_paper.py", line 284, in
    main()
  File "train_search_paper.py", line 175, in main
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr)
  File "train_search_paper.py", line 223, in train
    loss.backward()
  File "/scratch/clear/jmarrie/miniconda3/envs/env/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/scratch/clear/jmarrie/miniconda3/envs/env/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [105]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Do you have any idea where it comes from?

Thanks,

Juliette

About method DifferentiableAugment

Hi, thanks for the good work. I have a question I need your help with. I am confused by the DifferentiableAugment class in https://github.com/VDIGPKU/DADA/blob/master/search_relax/model_search.py#L43:

import torch.nn as nn

class DifferentiableAugment(nn.Module):
    def __init__(self, sub_policy):
        super(DifferentiableAugment, self).__init__()
        self.sub_policy = sub_policy

    def forward(self, origin_images, probability_b, magnitude):
        images = origin_images
        adds = 0
        for i in range(len(self.sub_policy)):
            if probability_b[i].item() != 0.0:
                images = images - magnitude[i]
                adds = adds + magnitude[i]
        images = images.detach() + adds
        return images

It seems that the image processing operations are not actually called to preprocess the training images during the search stage; for each sub-policy, the code just subtracts the magnitude. If I have misunderstood something, please tell me, thanks.

How to choose the final policy in the search phase?

Hi!
I set the learning rate to 0.001, but during the search the validation metric curve is unstable.
For now I choose the top-25 sub-policies at epoch 36 (1 epoch = 500 iterations) because the validation metric curve looks good there, and the second training stage then converges. I don't know why.
On the other hand, when I choose the sub-policies at a point where the curve oscillates, the second training stage is unstable and the model does not converge.
The validation metric oscillates throughout the entire search phase. Is this normal?

I don't know how to choose the final policy; it feels like I am selecting sub-policies randomly. Thanks!

[Figure: search stage, validation metric (y-axis) vs. epoch (x-axis)]
