
MAIR

Improve and Evaluate Robustness

License: MIT | Code style: black

"Make your AI Robust."

MAIR is a PyTorch-based adversarial training framework. The goal of MAIR is to (1) provide easy-to-use implementations of adversarial training methods and (2) make it easier to evaluate the adversarial robustness of deep learning models.

Adversarial training has become the de facto standard method for improving the robustness of models against adversarial examples. However, while writing our paper, we realized that there was no framework integrating adversarial training methods. Therefore, to promote reproducibility and transparency in deep learning, we integrated the algorithms, tools, and pre-trained models into a single framework.
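For reference, adversarial training in the sense of Madry et al. (2018) solves the min-max problem below, where $f_\theta$ is the model, $\mathcal{L}$ the training loss, $\mathcal{D}$ the data distribution, and $\epsilon$ the perturbation budget (written here for the $\ell_\infty$ threat model). This is the general formulation rather than MAIR-specific code; methods such as TRADES and MART change the losses used in the inner and outer problems but keep this structure, and in practice the inner maximization is approximated with PGD.

$$
\min_{\theta}\; \mathbb{E}_{(x,y)\sim\mathcal{D}}\left[\max_{\|\delta\|_\infty \le \epsilon} \mathcal{L}\big(f_\theta(x+\delta),\, y\big)\right]
$$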

Citation: If you use this package, please cite the following BibTeX entry (Google Scholar):

@inproceedings{
    kim2023fantastic,
    title={Fantastic Robustness Measures: The Secrets of Robust Generalization},
    author={Hoki Kim and Jinseong Park and Yujin Choi and Jaewook Lee},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023},
    url={https://openreview.net/forum?id=AGVBqJuL0T}
}

Benchmarks of several adversarially trained models are available on our Notion page.

Installation and usage

Installation

pip install git+https://github.com/Harry24k/MAIR.git

Usage

import mair

How to train a model?

Step 1. Load a model as follows:

model = ...
rmodel = mair.RobModel(model, n_classes=10).cuda()

Step 2. Set up the trainer as follows:

from mair.defenses import AT
# Set adversarial training method: [Standard, AT, TRADES, MART].
trainer = AT(rmodel, eps=EPS, alpha=ALPHA, steps=STEPS)
# Set which robustness metrics to record during training.
trainer.record_rob(train_loader, val_loader, eps=EPS, alpha=2/255, steps=10, std=0.1)
# Set detailed training options.
trainer.setup(optimizer="SGD(lr=0.1, momentum=0.9)",
              scheduler="Step(milestones=[100, 150], gamma=0.1)",
              scheduler_type="Epoch",
              minimizer=None, # or "AWP(rho=5e-3)",
              n_epochs=200
             )
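MAIR's AT trainer operates on PyTorch models, and the sketch below is not its implementation. It is a pure-Python illustration, on a hypothetical binary logistic-regression model with inputs scaled to [0, 1], of what eps, alpha, and steps usually control in PGD-based adversarial training: PGD takes `steps` signed-gradient ascent steps of size `alpha` and projects back onto the L-inf ball of radius `eps` (inner maximization), and the weights are then updated on the resulting adversarial examples (outer minimization). All function names are hypothetical, and the random start that PGD usually uses is omitted so the example is deterministic.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logit(w, b, x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def bce_loss(w, b, x, y):
    """Binary cross-entropy of the toy logistic-regression model (y in {0, 1})."""
    p = sigmoid(logit(w, b, x))
    tiny = 1e-12  # avoid log(0)
    return -(y * math.log(p + tiny) + (1 - y) * math.log(1.0 - p + tiny))


def sign(v):
    return (v > 0) - (v < 0)


def pgd_attack(w, b, x, y, eps, alpha, steps):
    """L-inf PGD without random start: ascend the loss with signed-gradient steps
    of size `alpha`, projecting onto the eps-ball around x and onto [0, 1]."""
    x_adv = list(x)
    for _ in range(steps):
        g = sigmoid(logit(w, b, x_adv)) - y  # dLoss/dlogit for BCE
        grad_x = [g * wi for wi in w]        # dLoss/dx = dLoss/dlogit * w
        x_adv = [xa + alpha * sign(gx) for xa, gx in zip(x_adv, grad_x)]
        x_adv = [min(max(xa, xi - eps), xi + eps) for xa, xi in zip(x_adv, x)]
        x_adv = [min(max(xa, 0.0), 1.0) for xa in x_adv]
    return x_adv


def at_step(w, b, batch, eps, alpha, steps, lr):
    """One adversarial-training step: craft PGD examples with the current weights,
    then take a gradient-descent step on the mean loss over those examples."""
    grad_w, grad_b = [0.0] * len(w), 0.0
    for x, y in batch:
        x_adv = pgd_attack(w, b, x, y, eps, alpha, steps)
        g = sigmoid(logit(w, b, x_adv)) - y
        grad_w = [gw + g * xa for gw, xa in zip(grad_w, x_adv)]
        grad_b += g
    n = len(batch)
    new_w = [wi - lr * gw / n for wi, gw in zip(w, grad_w)]
    return new_w, b - lr * grad_b / n
```

For example, with eps=8/255, alpha=2/255, and steps=10 (the values used in record_rob above), the projection keeps every input coordinate within 8/255 of the clean input no matter how many steps are taken.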

Step 3. Fit the model as follows:

trainer.fit(train_loader=train_loader,
            n_epochs=200,
            save_path='./models/', 
            save_best={"Clean(Val)":"HBO", "PGD(Val)":"HB"},
            # 'save_best': the model with the highest PGD(Val) accuracy is chosen;
            # among models with similar PGD accuracy, the one with the highest Clean(Val) accuracy is selected.
            save_type="Epoch", 
            save_overwrite=False, 
            record_type="Epoch"
           )
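The comment on save_best describes a two-level rule: prefer the highest PGD(Val) accuracy and, among near-ties, prefer the highest Clean(Val) accuracy. A minimal sketch of such a rule, assuming per-epoch records stored as dicts and a hypothetical tolerance `tol` (in percentage points) for "similar"; MAIR's actual handling of the "HBO"/"HB" codes may differ, and the accuracies below are made-up values for illustration only.

```python
def select_best(records, primary="PGD(Val)", secondary="Clean(Val)", tol=0.1):
    """Return the record with the best `primary` metric, breaking near-ties
    (within `tol` of the best value) by the `secondary` metric."""
    if not records:
        raise ValueError("no records to select from")
    top = max(r[primary] for r in records)
    candidates = [r for r in records if top - r[primary] <= tol]
    return max(candidates, key=lambda r: r[secondary])


# Illustrative (made-up) per-epoch validation accuracies in %.
history = [
    {"epoch": 1, "Clean(Val)": 80.0, "PGD(Val)": 50.00},
    {"epoch": 2, "Clean(Val)": 81.0, "PGD(Val)": 52.10},
    {"epoch": 3, "Clean(Val)": 84.0, "PGD(Val)": 52.05},
    {"epoch": 4, "Clean(Val)": 86.0, "PGD(Val)": 51.00},
]
print(select_best(history)["epoch"])  # prints 3
```

Epoch 3 wins because its PGD accuracy is within 0.1 of the best (epoch 2) and its clean accuracy is higher; with `tol=0.0` the rule reduces to "highest PGD accuracy" and picks epoch 2.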

How to evaluate a model?

Step 1. Wrap the model with RobModel as follows:

model = ...
rmodel = mair.RobModel(model, n_classes=10).cuda()

Step 2. Evaluate the model as follows:

rmodel.eval_accuracy(test_loader)  # clean accuracy
rmodel.eval_rob_accuracy_gn(test_loader)  # Gaussian noise accuracy
rmodel.eval_rob_accuracy_fgsm(test_loader, eps)  # FGSM accuracy
rmodel.eval_rob_accuracy_pgd(test_loader, eps, alpha, steps)  # PGD accuracy
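The eval_* methods above work on PyTorch models and data loaders. As a conceptual sketch only (not MAIR's code), the snippet below shows what clean, Gaussian-noise, and FGSM accuracy measure, using a hypothetical toy linear binary classifier on inputs in [0, 1]; PGD accuracy repeats FGSM-style steps with a projection, as in the training sketch earlier. All names here are hypothetical.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logit(w, b, x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def clip01(x):
    return [min(max(xi, 0.0), 1.0) for xi in x]


def sign(v):
    return (v > 0) - (v < 0)


def accuracy(w, b, data, perturb=None):
    """Accuracy (%) of the toy classifier, optionally on perturbed inputs."""
    correct = 0
    for x, y in data:
        x_eval = perturb(x, y) if perturb else x
        correct += int((logit(w, b, x_eval) > 0) == bool(y))
    return 100.0 * correct / len(data)


def gaussian_noise(std, rng):
    """Perturbation that adds N(0, std^2) noise to every feature."""
    return lambda x, y: clip01([xi + rng.gauss(0.0, std) for xi in x])


def fgsm(w, b, eps):
    """Single-step FGSM: x + eps * sign(dLoss/dx) for the BCE loss."""
    def attack(x, y):
        g = sigmoid(logit(w, b, x)) - y  # dLoss/dlogit
        return clip01([xi + eps * sign(g * wi) for xi, wi in zip(x, w)])
    return attack


w, b = [1.0, 1.0], -1.0
test = [([0.8, 0.8], 1), ([0.2, 0.2], 0), ([0.52, 0.52], 1), ([0.48, 0.48], 0)]
print(accuracy(w, b, test))                                      # 100.0
print(accuracy(w, b, test, gaussian_noise(0.1, random.Random(0))))
print(accuracy(w, b, test, fgsm(w, b, 8 / 255)))                 # 50.0
```

The two samples close to the decision boundary are classified correctly on clean inputs but flipped by an 8/255 FGSM step, which is why robust accuracy is reported separately from clean accuracy.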

Please refer to the demo for details.

Adversarial Benchmarks & Pre-trained models

Below is a selected benchmark of popular adversarial training techniques.

Note that all robustness values (i.e., robust accuracies) are measured on the CIFAR-10 test set against PGD10 (10-step PGD). Therefore, be aware that some models may exhibit over-estimated robustness, which should be further verified with stronger attacks such as AutoAttack.
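Why can a single attack over-estimate robustness? A test sample should count as robust only if no attack finds an adversarial example for it, so evaluating several attacks and taking the per-sample worst case (the idea behind attack ensembles such as AutoAttack) can only lower, never raise, the reported number. A minimal sketch with toy per-sample outcomes (not real results); the function name is hypothetical and this is not AutoAttack's implementation.

```python
def worst_case_accuracy(correct_by_attack):
    """Robust accuracy (%) when a sample counts as robust only if it is classified
    correctly under every attack in `correct_by_attack` (name -> list of bools)."""
    columns = list(correct_by_attack.values())
    n = len(columns[0])
    if any(len(c) != n for c in columns):
        raise ValueError("all attacks must be evaluated on the same samples")
    robust = sum(all(per_sample) for per_sample in zip(*columns))
    return 100.0 * robust / n


# Toy per-sample outcomes for four test images.
results = {
    "PGD10": [True, True, True, False],
    "stronger attack": [True, False, True, True],
}
for name, correct in results.items():
    print(name, 100.0 * sum(correct) / len(correct))  # 75.0 for each attack
print("worst case", worst_case_accuracy(results))     # 50.0
```

Each attack alone reports 75%, but only two samples survive both, so the worst-case robust accuracy is 50%.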

More detailed benchmarks are available on our Notion page.

ResNet

Method   Architecture   AWP   Extra Data   Best Robustness (%)   Remark
AT       ResNet18       -     -            52.73
MART     ResNet18       -     -            54.73
TRADES   ResNet18       -     -            53.47
AT       ResNet18       ✔️    -            55.52
MART     ResNet18       ✔️    -            57.64
TRADES   ResNet18       ✔️    -            55.91
AT       ResNet18       ✔️    ✔️           56.52
MART     ResNet18       ✔️    ✔️           57.93                 👑
TRADES   ResNet18       ✔️    ✔️           55.51

WRN28-10

Method   Architecture   AWP   Extra Data   Best Robustness (%)   Remark
AT       WRN28-10       -     -            56.00
MART     WRN28-10       -     -            57.69
TRADES   WRN28-10       -     -            56.81
AT       WRN28-10       ✔️    -            58.70
MART     WRN28-10       ✔️    -            60.39
TRADES   WRN28-10       ✔️    -            59.50
AT       WRN28-10       ✔️    ✔️           62.65
MART     WRN28-10       ✔️    ✔️           63.51                 👑
TRADES   WRN28-10       ✔️    ✔️           61.81

WRN34-10

Method   Architecture   AWP   Extra Data   Best Robustness (%)   Remark
AT       WRN34-10       -     -            56.19
MART     WRN34-10       -     -            57.56
TRADES   WRN34-10       -     -            56.67
AT       WRN34-10       ✔️    -            59.63
MART     WRN34-10       ✔️    -            10.00
TRADES   WRN34-10       ✔️    -            59.50
AT       WRN34-10       ✔️    ✔️           63.30
MART     WRN34-10       ✔️    ✔️           64.04                 👑
TRADES   WRN34-10       ✔️    ✔️           62.07

Building on the benchmarks on our Notion page, we provide a hub module that lets you load our pre-trained models directly.

from mair.hub import load_pretrained
rmodel = load_pretrained("CIFAR10_ResNet18_AT(eps=8, alpha=2, steps=10)", flag='Best', save_dir="./")
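The name passed to load_pretrained appears to encode the dataset, architecture, training method, and attack settings. Assuming that pattern, and assuming eps=8 and alpha=2 are given in units of 1/255 (consistent with alpha=2/255 in the record_rob call above), a hypothetical helper for reading such names might look like this. It is not part of MAIR's API.

```python
import re

# Hypothetical pattern: <dataset>_<arch>_<method>(<key=value, ...>)
_NAME = re.compile(
    r"(?P<dataset>[^_]+)_(?P<arch>[^_]+)_(?P<method>[^(]+)\((?P<args>[^)]*)\)"
)


def _number(text):
    try:
        return int(text)
    except ValueError:
        return float(text)


def parse_model_name(name):
    """Split a name like 'CIFAR10_ResNet18_AT(eps=8, alpha=2, steps=10)' into parts."""
    match = _NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"unrecognized model name: {name!r}")
    params = {}
    for item in filter(None, (s.strip() for s in match["args"].split(","))):
        key, value = item.split("=", 1)
        params[key.strip()] = _number(value.strip())
    return {"dataset": match["dataset"], "arch": match["arch"],
            "method": match["method"], "params": params}


def to_unit_scale(params, keys=("eps", "alpha")):
    """Assumption: eps/alpha in model names are in units of 1/255."""
    return {k: (v / 255 if k in keys else v) for k, v in params.items()}
```

For example, `to_unit_scale(parse_model_name("CIFAR10_ResNet18_AT(eps=8, alpha=2, steps=10)")["params"])` gives eps=8/255, alpha=2/255, and steps=10 under these assumptions.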

Please refer to the demo for details.

Alternatively, you can download the models from Google Drive.

Each folder contains four files:

  • log.txt: the log recorded during training.
  • last.pth: the model at the end of the final epoch.
  • init.pth: the model at initialization, before the first epoch.
  • best.pth: the best model, selected according to the save_best argument of trainer.fit.

To load a model:

rmodel.load_dict('./models/.../best.pth')
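The flag='Best' argument of load_pretrained presumably selects one of the checkpoint files listed above. A minimal sketch of that mapping, assuming each flag is the capitalized file stem; only 'Best' appears in the example above, so 'Init' and 'Last' and the helper itself are hypothetical.

```python
from pathlib import Path

# Hypothetical mapping from a checkpoint flag to the files described above.
CHECKPOINT_FILES = {"Init": "init.pth", "Last": "last.pth", "Best": "best.pth"}


def checkpoint_path(model_dir, flag="Best"):
    """Return the path of the checkpoint selected by `flag` inside `model_dir`."""
    try:
        filename = CHECKPOINT_FILES[flag]
    except KeyError:
        raise ValueError(
            f"unknown flag {flag!r}; expected one of {sorted(CHECKPOINT_FILES)}"
        ) from None
    return Path(model_dir) / filename
```

The resulting path could then be passed as a string, e.g. `rmodel.load_dict(str(checkpoint_path(model_dir, "Best")))`, where `model_dir` is the downloaded folder.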

We are excited to share our models with the community, but we have run into a storage limit on Google Drive. Any help would be greatly appreciated!

Contribution

We welcome all contributions to MAIR 😃. In particular, we are looking for adversarial training methods beyond AT, TRADES, and MART.

Future work

  • Merge robustness measures.
  • Generalize attacks gathered from torchattacks.
  • ...

Contributors

  • harry24k
