

ManifoldMixup

Unofficial PyTorch implementation of ManifoldMixup (Proceedings of ICML 2019) with support for Interpolated Adversarial Training.
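Manifold Mixup draws a mixing coefficient λ from Beta(α, α) (presumably what the alpha argument of ManifoldMixupModel controls), interpolates the hidden representations of two examples at a chosen layer as λ·h_a + (1 - λ)·h_b, and weights the losses against the two original labels by λ and 1 - λ. The minimal sketch below shows only that arithmetic in plain Python; the function names are hypothetical and this is not the package's own implementation.

```python
import random


def sample_lambda(alpha, rng=random):
    """Draw the mixing coefficient lambda ~ Beta(alpha, alpha)."""
    return rng.betavariate(alpha, alpha)


def mix_hidden(h_a, h_b, lam):
    """Interpolate two hidden vectors: lam * h_a + (1 - lam) * h_b."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(h_a, h_b)]


def mixed_loss(loss_fn, prediction, y_a, y_b, lam):
    """Weight the loss against each original label by its share of the mix."""
    return lam * loss_fn(prediction, y_a) + (1.0 - lam) * loss_fn(prediction, y_b)
```

With a small alpha such as 0.2, most sampled λ values lie close to 0 or 1, so most mixed examples stay near one of the two originals.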

Results of adversarial training on MNIST

Dataset    Adversarial training    Normal training
MNIST      0.732                   0.0069

Adversarial training was done using the Interpolated Adversarial Training framework with an FGSM attack (eps=0.3) and ManifoldMixup.
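FGSM (Fast Gradient Sign Method) moves every input feature by eps in the direction of the sign of the loss gradient, x_adv = clip(x + eps · sign(∇x L(x, y)), 0, 1), so with eps=0.3 and inputs scaled to [0, 1] (an assumption here) each pixel can shift by up to 0.3. The sketch below applies one FGSM step to a hypothetical toy logistic-regression model, whose input gradient has the closed form (p - y)·w; it is an illustration, not this repository's FGSM class.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def input_gradient(x, y, w, b):
    """Gradient of binary cross-entropy w.r.t. x for p = sigmoid(w.x + b).

    For this toy model, dL/dx_i = (p - y) * w_i.
    """
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    return [(p - y) * wi for wi in w]


def sign(v):
    """Return -1, 0 or 1."""
    return (v > 0) - (v < 0)


def fgsm(x, y, w, b, eps=0.3, lo=0.0, hi=1.0):
    """One FGSM step: shift each feature by eps along the gradient sign, then clip."""
    grad = input_gradient(x, y, w, b)
    return [min(hi, max(lo, xi + eps * sign(gi))) for xi, gi in zip(x, grad)]
```

Because only the sign of the gradient is used, every feature moves by exactly eps (before clipping), which is what makes FGSM cheap: one gradient evaluation per batch.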

Update

Following a discussion with the author, these features were added:

  1. Module-level control, which lets the user decide exactly which layers to consider for mixup
  2. A warning is raised if a module is called more than once in the forward pass; mixup is done at the first call only
  3. If you cannot decide in advance which modules to use for mixup, pass mixup_all=True when creating the ManifoldMixupModel instance
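The first-call-only behaviour in point 2 can be pictured as a wrapper that mixes its output on the first call within a forward pass and only warns on later calls. This plain-Python sketch illustrates that behaviour under assumed semantics; `FirstCallMixup`, its `reset()` method and the warning text are hypothetical and are not MixupModule's actual code.

```python
import warnings


class FirstCallMixup:
    """Wrap a layer so mixup is applied only the first time it runs per forward pass."""

    def __init__(self, layer, mix_fn):
        self.layer = layer      # any callable, e.g. a layer's forward function
        self.mix_fn = mix_fn    # applied to the layer output on the first call only
        self.calls = 0

    def reset(self):
        """Call at the start of every forward pass."""
        self.calls = 0

    def __call__(self, x):
        out = self.layer(x)
        self.calls += 1
        if self.calls == 1:
            return self.mix_fn(out)
        # Later calls in the same pass: no mixing, just a warning.
        warnings.warn("module called more than once in forward pass; "
                      "mixup applied at first call only")
        return out
```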

Usage

ManifoldMixup training

import torch.nn as nn
from manifold_mixup import ManifoldMixupDataset, ManifoldMixupModel, ManifoldMixupLoss, MixupModule

# (Optional) Wrap the modules you want to use for mixup in MixupModule.
# LinearLayer and Flatten are helpers from the author's torched package
# (torched.customs.layers).
class Model(nn.Module):
    def __init__(self, in_dims, hid_dims, out_dims):
        super().__init__()
        self.m = nn.Sequential(Flatten(),
                               LinearLayer(in_dims, hid_dims, use_bn=True),
                               MixupModule(LinearLayer(hid_dims, hid_dims, use_bn=True)),
                               MixupModule(LinearLayer(hid_dims, hid_dims, use_bn=True)),
                               nn.Linear(hid_dims, out_dims))

    def forward(self, x):
        return self.m(x)

# Wrap your dataset, model and loss in the ManifoldMixup classes; that's it!
# trn_ds, model and criterion are your own training dataset, nn.Module and loss.
mixup_ds = ManifoldMixupDataset(trn_ds)
mixup_model = ManifoldMixupModel(model, alpha=0.2)
mixup_criterion = ManifoldMixupLoss(criterion)

# Now train as usual using the mixup dataset, model and loss.
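In Manifold Mixup, one layer is typically drawn at random for each batch from the eligible layers (in the example above, the two MixupModule-wrapped layers; with mixup_all=True, presumably every module), and the batch is mixed with a shuffled copy of itself right after that layer. The sketch below mimics one such forward pass on a batch of plain Python lists; `manifold_mixup_forward` and its arguments are hypothetical, and the package's actual layer selection may differ.

```python
import random


def manifold_mixup_forward(layers, eligible, batch, alpha=0.2, rng=random):
    """Run `batch` through `layers`, mixing hidden states after one random eligible layer.

    layers:   list of callables, each mapping one example (list of floats) to a new one
    eligible: indices of the layers whose outputs may be mixed
    Returns (outputs, perm, lam) so the loss can be weighted as
    lam * loss(out, y) + (1 - lam) * loss(out, y[perm]).
    """
    k = rng.choice(eligible)               # layer after which mixing happens
    lam = rng.betavariate(alpha, alpha)    # mixing coefficient
    perm = list(range(len(batch)))
    rng.shuffle(perm)                      # mixing partner for each example
    h = batch
    for i, layer in enumerate(layers):
        h = [layer(x) for x in h]
        if i == k:
            # The comprehension reads the unmixed h before it is reassigned.
            h = [[lam * a + (1 - lam) * b for a, b in zip(h[j], h[perm[j]])]
                 for j in range(len(h))]
    return h, perm, lam
```

Returning `perm` and `lam` is what lets a loss wrapper (the role ManifoldMixupLoss plays in the package) weight the two sets of targets consistently with the mix.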

Interpolated Adversarial training with ManifoldMixup

import torch.nn as nn
from manifold_mixup import ManifoldMixupDataset, ManifoldMixupModel
from adversarial_attacks import FGSM
from interpolated_adversarial import InterpolatedAdversarialLoss

mixup_ds = ManifoldMixupDataset(trn_ds)
mixup_model = ManifoldMixupModel(model, alpha=0.2, interpolated_adv=True)

# To define the loss for interpolated adversarial training, pass your
# original loss function and the attack.
adv_loss = nn.CrossEntropyLoss()    # loss used by the attack to compute gradients
model_loss = nn.CrossEntropyLoss()  # loss the model is trained on
fgsm = FGSM(adv_loss)
adv_criterion = InterpolatedAdversarialLoss(model_loss, fgsm)

# Now train as usual with the new model, dataset and loss.
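Interpolated Adversarial Training evaluates the mixup loss twice per step, once on the clean batch and once on an adversarially perturbed copy of it, and trains on both. The sketch below shows only that combination; `interpolated_adversarial_loss`, the `mixup_loss`/`attack` callables and the default 0.5/0.5 weighting are assumptions and do not describe the signature or weighting of InterpolatedAdversarialLoss.

```python
def interpolated_adversarial_loss(mixup_loss, attack, x, y, clean_weight=0.5):
    """Combine a mixup loss on clean inputs with the same loss on attacked inputs.

    mixup_loss(x, y) -> float : mixup loss for a batch (assumed callable)
    attack(x, y) -> x_adv     : adversarial version of the batch, e.g. from FGSM
    clean_weight              : share of the clean term; 0.5 averages the two
    """
    x_adv = attack(x, y)
    clean = mixup_loss(x, y)
    adversarial = mixup_loss(x_adv, y)
    return clean_weight * clean + (1.0 - clean_weight) * adversarial
```

Keeping a clean term in the objective is the point of the method: the model is trained on adversarial examples without giving up accuracy on unperturbed data entirely.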

Creating custom adversarial attack

from adversarial_attacks import BlackBoxAdversarialAttack

# BlackBoxAdversarialAttack automatically handles the requires_grad setting of
# your model's parameters and creates a callable compatible with
# InterpolatedAdversarialLoss.
class MyNewAttack(BlackBoxAdversarialAttack):
    def run(self, x, y, model):
        """Attack logic goes here: craft perturbed inputs from x, y and model."""
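As an example of what a `run` body might contain, the standalone sketch below implements a query-only attack: it estimates the input gradient of the loss with central finite differences and takes one FGSM-style sign step. It does not subclass BlackBoxAdversarialAttack, whose internals are not documented here; `FiniteDifferenceSignAttack`, `loss_fn`, `eps` and `h` are hypothetical, and `model` is assumed to map one input vector to a prediction.

```python
class FiniteDifferenceSignAttack:
    """Estimate dL/dx by central differences, then step eps along its sign."""

    def __init__(self, loss_fn, eps=0.3, h=1e-4, lo=0.0, hi=1.0):
        self.loss_fn = loss_fn  # loss_fn(prediction, y) -> float
        self.eps = eps          # perturbation size per feature
        self.h = h              # finite-difference step
        self.lo, self.hi = lo, hi

    def _loss(self, x, y, model):
        return self.loss_fn(model(x), y)

    def run(self, x, y, model):
        x_adv = []
        for i, xi in enumerate(x):
            up = list(x)
            up[i] = xi + self.h
            down = list(x)
            down[i] = xi - self.h
            # Central difference estimate of dL/dx_i, using only model queries.
            g = (self._loss(up, y, model) - self._loss(down, y, model)) / (2 * self.h)
            step = self.eps * ((g > 0) - (g < 0))
            x_adv.append(min(self.hi, max(self.lo, xi + step)))
        return x_adv
```

This needs two model queries per input feature, so it is far slower than a gradient-based attack, but it works even when gradients are unavailable.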

Fastai implementation by @nestordemeure

https://github.com/nestordemeure/ManifoldMixup


Contributors

nestordemeure, shivamsaboo17


Issues

I am not able to import the torched module

Hello, can you please help me find the torched module in your code? I am not able to import it:

from torched.customs.layers import LinearLayer, Flatten
from torched.trainer_utils import Train
