Unlabeled Data Improves Adversarial Robustness

This repository contains code for reproducing data and models from the NeurIPS 2019 paper Unlabeled Data Improves Adversarial Robustness by Yair Carmon, Aditi Raghunathan, Ludwig Schmidt, Percy Liang and John C. Duchi.

CIFAR-10 unlabeled data and trained models

Below are links to files containing our unlabeled data and pretrained models:

Additional files:

Dependencies

To create a conda environment called semisup-adv containing all the dependencies, run

conda env create -f environment.yml  

Note: We tested this code on 2 GPUs in parallel, each with 12GB of memory. Running on CPUs or GPUs with less memory might require adjustments.

The code in this repo is based on code from the following sources:

Running robust self-training

To run robust self-training you will need a pickle file containing pseudo-labeled data. You can download ti_500K_pseudo_labeled.pickle, which contains our 500K pseudo-labeled TinyImages, or you can generate one from scratch using the instructions below.

Adversarial training with TRADES

The following command performs adversarial training and produces a model equivalent to RST_adv(50K+500K) described in the paper.

python robust_self_training.py --aux_data_filename ti_500K_pseudo_labeled.pickle --distance l_inf --epsilon 0.031 --model_dir rst_adv 
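TRADES combines two terms. The first is a natural cross-entropy term. The second is a robustness term: the KL divergence between the model's predictions on a clean input and on an adversarially perturbed copy, scaled by a weight. The pure-Python sketch below computes that objective for a single example. It assumes `--beta` is this weight; the vanilla run in step four passes `--beta 0`, which under this reading reduces the objective to plain cross-entropy. The default `beta=6.0` and all function names are illustrative assumptions. The inner maximization that produces `adv_logits` is omitted.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    return -math.log(softmax(logits)[label])


def kl_div(p, q):
    """KL(p || q) for two probability vectors (q strictly positive)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def trades_loss(clean_logits, adv_logits, label, beta=6.0):
    """Natural loss plus beta * KL(clean prediction || adversarial prediction)."""
    natural = cross_entropy(clean_logits, label)
    robust = kl_div(softmax(clean_logits), softmax(adv_logits))
    return natural + beta * robust
```

Because the robustness term never uses the label, the same term can also be computed on unlabeled images. The pseudo-labels are needed only for the natural term.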

When the script finishes running, there will be a checkpoint file called rst_adv/checkpoint-epoch200.pt. The following command runs a PGD attack (PGD_Ours from the paper) on the model

python attack_evaluation.py --model_path rst_adv/checkpoint-epoch200.pt --attack pgd --output_suffix pgd_ours  
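For intuition, here is a minimal L-infinity PGD loop on a flat list of pixel values. It is not the repo's attack. Each step moves every pixel by `step_size` in the direction of the gradient sign, then projects back into the epsilon-ball around the original input and into the valid pixel range [0, 1]. `epsilon=0.031` matches the training command above. `step_size`, `num_steps` and the caller-supplied `grad_fn` are illustrative assumptions. Random initialization, which many PGD variants use, is omitted.

```python
def sign(v):
    # Returns 1, -1 or 0
    return (v > 0) - (v < 0)


def pgd_linf(x, grad_fn, epsilon=0.031, step_size=0.007, num_steps=10):
    """Sketch of L-inf PGD: ascend the loss, then project.

    grad_fn(x_adv) must return the gradient of the attack loss with respect
    to the input, as a list the same length as x (hypothetical interface).
    """
    x_adv = list(x)
    for _ in range(num_steps):
        g = grad_fn(x_adv)
        # Signed gradient ascent step
        x_adv = [xa + step_size * sign(gi) for xa, gi in zip(x_adv, g)]
        # Project into the L-inf ball of radius epsilon around x
        x_adv = [min(max(xa, xi - epsilon), xi + epsilon) for xa, xi in zip(x_adv, x)]
        # Keep pixels in the valid range
        x_adv = [min(max(xa, 0.0), 1.0) for xa in x_adv]
    return x_adv
```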

To run the Carlini-Wagner attack on 1000 randomly selected images from the test set, use

python attack_evaluation.py --model_path rst_adv/checkpoint-epoch200.pt --attack cw --output_suffix cw --num_eval_batches 5 --shuffle_testset  
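Two details of this command can be illustrated. Attacks in the Carlini-Wagner style optimize a margin objective: the largest wrong-class logit minus the true-class logit, which becomes positive once the input is misclassified. In addition, 1000 images spread over `--num_eval_batches 5` implies evaluation batches of 200 images. The sketch below shows both ideas. `cw_margin` is the standard margin objective, not necessarily the repo's exact loss. `sample_eval_indices` is a hypothetical stand-in for what `--shuffle_testset` might do: shuffle the test indices, then keep the first `num_eval_batches * batch_size`.

```python
import random


def cw_margin(logits, label):
    """Best wrong-class logit minus the true-class logit (> 0 means misclassified)."""
    best_wrong = max(z for j, z in enumerate(logits) if j != label)
    return best_wrong - logits[label]


def sample_eval_indices(num_test, num_eval_batches, batch_size, seed=0):
    """Hypothetical: shuffle test indices and keep the first few batches' worth."""
    indices = list(range(num_test))
    random.Random(seed).shuffle(indices)
    return indices[: num_eval_batches * batch_size]
```

For example, `sample_eval_indices(10000, 5, 200)` returns 1000 distinct indices into the 10,000-image CIFAR-10 test set.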

Stability training

The following command performs stability training and produces a model equivalent to RST_stab(50K+500K) described in the paper.

python robust_self_training.py --aux_data_filename ti_500K_pseudo_labeled.pickle --distance l_2 --epsilon 0.25 --model_dir rst_stab --epochs 800
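Stability training swaps the adversarial perturbation of TRADES for random Gaussian noise. The model is penalized when its prediction on a noisy copy of the input drifts from its prediction on the clean input. The sketch below assumes that `--epsilon 0.25` plays the role of the noise standard deviation, which matches the `--sigma 0.25` used for certification below. It also assumes that the penalty is a KL term weighted like TRADES. `stability_loss`, the `model` callable and `beta=6.0` are illustrative assumptions, not the repo's code.

```python
import math
import random


def softmax(logits):
    # Subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_div(p, q):
    """KL(p || q) for two probability vectors (q strictly positive)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def stability_loss(model, x, label, sigma=0.25, beta=6.0, rng=None):
    """Cross-entropy on the clean input plus beta * KL(clean || noisy).

    ASSUMPTION: isotropic Gaussian noise with standard deviation sigma.
    model is any callable mapping an input list to a list of logits.
    """
    rng = rng if rng is not None else random.Random(0)
    p_clean = softmax(model(x))
    # Perturb every coordinate with independent N(0, sigma^2) noise
    x_noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
    p_noisy = softmax(model(x_noisy))
    natural = -math.log(p_clean[label])
    return natural + beta * kl_div(p_clean, p_noisy)
```

A model whose output ignores small input changes pays only the natural term. That is the behavior randomized smoothing later relies on.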

When the script finishes running, there will be a checkpoint file called rst_stab/checkpoint-epoch800.pt. The following command runs randomized smoothing certification on the model, as described in the paper.

python smoothing_evaluation.py --model_path rst_stab/checkpoint-epoch800.pt --sigma 0.25  
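Randomized smoothing certifies an L2 radius from a lower confidence bound on the probability that the base classifier outputs the top class under Gaussian noise. In the formulation of Cohen et al. (2019), the radius is sigma times the inverse standard normal CDF of that bound, and the classifier abstains when the bound is at most 1/2. The sketch below computes only this final step. It assumes the bound `p_a_lower` has already been estimated, for example with a Clopper-Pearson interval over Monte Carlo samples, which is omitted here. We have not verified that smoothing_evaluation.py follows this exact procedure.

```python
from statistics import NormalDist


def certified_radius(p_a_lower, sigma=0.25):
    """L2 radius certified by randomized smoothing, or None to abstain.

    p_a_lower: lower confidence bound on P(base classifier returns the top
    class) under N(0, sigma^2 I) input noise.
    """
    if not 0.0 < p_a_lower < 1.0:
        raise ValueError("p_a_lower must lie strictly between 0 and 1")
    if p_a_lower <= 0.5:
        return None  # no certificate: abstain
    return sigma * NormalDist().inv_cdf(p_a_lower)
```

With `sigma=0.25`, a bound of 0.99 certifies a radius of roughly 0.58. The radius grows linearly in sigma, which is why larger noise levels can certify larger radii at the cost of clean accuracy.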

Creating the unlabeled data from scratch

Note: creating the unlabeled data from scratch takes a while; plan for at least three days.

Step zero: Downloading data

Create a data directory that has the following files:

Step one: Tiny Image preliminaries

In this step, we perform two preliminary tasks:

  1. Compute distances from all the TinyImages to CIFAR-10 test set, in order to ensure we do not add any images from the test set to the unlabeled data sourced from TinyImages.
  2. Create train/test data for selection model (See Appendix B.6)
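The first task amounts to a nearest-neighbor search. For every candidate TinyImage, compute its L2 distance to the closest CIFAR-10 test image, and exclude candidates that fall within some threshold. The brute-force sketch below treats images as flat lists of pixel values. The threshold value and function names are hypothetical. At the scale of TinyImages (tens of millions of images), a real implementation needs batched, GPU-accelerated distance computations rather than this O(N x M) loop.

```python
import math


def min_l2_distance(image, test_set):
    """Distance from one flattened image to its nearest test-set image."""
    return min(math.dist(image, t) for t in test_set)


def filter_near_duplicates(candidates, test_set, threshold):
    """Keep only candidates whose nearest test image is farther than threshold."""
    return [c for c in candidates if min_l2_distance(c, test_set) > threshold]
```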

Note that the data directory should contain the following files: tiny_images.bin, cifar10_keywords_unique_v7.json, tinyimage_subset_indices_v7.json and tinyimage_subset_data_v7.pickle.

Here is an example run.

python tinyimages_preliminaries.py --data_dir ../data/ --output_dir ../data

Step two: Train a selection model

Here we train the data selection model described in Appendix B.6 of the paper. Note that data_dir should contain the files tiny_images.bin and ti_vs_cifar_inds.pickle (produced in step one).

Here is an example run.

python train_cifar10_vs_ti.py --output_dir ../cifar10-vs-ti/ --data_dir ../data/  

Step three: Selecting unlabeled data and removing CIFAR-10 test set

We apply the model trained above on TinyImages and select images based on the predictions, while making sure to remove all images that are close (in l2 distance) to the CIFAR-10 test set.

python tinyimages_prediction.py --model_path ../cifar10-vs-ti/model_state_epoch520.pth --data_dir ../data --output_dir ../data/ --output_filename ti_500K_unlabeled.pickle  
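This README does not spell out the exact selection rule, so the sketch below illustrates one plausible rule under stated assumptions. Each TinyImage gets a predicted class and a confidence from the selection model. Images predicted as one of the 10 CIFAR-10 classes are grouped by class, and the `k` most confident images in each group are kept. Any other predicted class, such as an assumed extra "not CIFAR-10" class, is dropped. If the rule keeps k = 50,000 per class, 10 classes yield at most 500K images, consistent with the ti_500K file name. The removal of test-set near-duplicates described above is not repeated here. `select_top_k_per_class` and the `(index, class, confidence)` input format are hypothetical.

```python
from collections import defaultdict


def select_top_k_per_class(predictions, k, num_classes=10):
    """Keep the k most confident images for each class in range(num_classes).

    predictions: iterable of (image_index, predicted_class, confidence).
    Images predicted as any class outside range(num_classes) are dropped.
    """
    by_class = defaultdict(list)
    for idx, cls, conf in predictions:
        if 0 <= cls < num_classes:
            by_class[cls].append((conf, idx))
    selected = []
    for cls in range(num_classes):
        # Highest confidence first; keep at most k per class
        ranked = sorted(by_class[cls], reverse=True)[:k]
        selected.extend(idx for _, idx in ranked)
    return selected
```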

Step four: Training a vanilla model on CIFAR-10

We now train a model (Wide ResNet 28-10) on the CIFAR-10 training set.

python robust_self_training.py --distance l_2 --beta 0 --unsup_fraction 0 --model_dir vanilla  
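Judging by its name and by this step's goal of training on CIFAR-10 alone, `--unsup_fraction` appears to set the share of each training batch drawn from the unlabeled (pseudo-labeled) pool. Setting it to 0 would then mean every example comes from the labeled set. The sketch below illustrates that reading. `split_batch` and the batch size of 256 in the usage note are hypothetical and not taken from the repo.

```python
def split_batch(batch_size, unsup_fraction):
    """Return (num_labeled, num_unlabeled) examples per training batch.

    ASSUMPTION: unsup_fraction is the share of each batch taken from the
    unlabeled pool, rounded to the nearest whole example.
    """
    if not 0.0 <= unsup_fraction <= 1.0:
        raise ValueError("unsup_fraction must lie in [0, 1]")
    num_unlabeled = int(round(batch_size * unsup_fraction))
    return batch_size - num_unlabeled, num_unlabeled
```

Under this reading, a batch of 256 with `unsup_fraction=0.5` splits into 128 labeled and 128 pseudo-labeled examples, while `unsup_fraction=0` yields a purely supervised batch.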

Step five: Generating pseudo-labels

As a final step, we generate pseudo-labels by applying the classifier from Step 4 on the unlabeled data sourced in Step 3.

python generate_pseudolabels.py --model_dir ../vanilla  --model_epoch 200 --data_dir ../data/ --data_filename ti_500K_unlabeled.pickle --output_dir ../data/ --output_filename ti_500K_pseudo_labeled.pickle  

Reference

@inproceedings{carmon2019unlabeled,  
author = {Yair Carmon and Aditi Raghunathan and Ludwig Schmidt and Percy Liang and John Duchi},  
title = {Unlabeled Data Improves Adversarial Robustness},  
year = 2019,  
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},  
}  
