
simple-sam's Introduction

simple-SAM

Sharpness-Aware Minimization for Efficiently Improving Generalization


This is an unofficial repository for Sharpness-Aware Minimization for Efficiently Improving Generalization.

Shortened abstract:
Optimizing only the training loss value, as is commonly done, can easily lead to suboptimal model quality. Motivated by the connection between the geometry of the loss landscape and generalization, Sharpness-Aware Minimization (SAM) is a novel, effective procedure that instead simultaneously minimizes loss value and loss sharpness. In particular, SAM seeks parameters that lie in neighborhoods having uniformly low loss, an optimization problem on which gradient descent can be performed efficiently.
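
In short, SAM approximates the inner maximization of the objective min_w max_{||e|| <= rho} L(w + e) with a single first-order step. For training loss L and weights w, one SAM update amounts to:

    e(w) = rho * grad L(w) / ||grad L(w)||                    (ascend towards locally higher loss)
    w    <- base optimizer step using grad L(w + e(w))        (descend with the gradient taken there)

This two-phase update is what the first_step / second_step pair documented below implements.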

The implementation uses TensorFlow 2 and is heavily inspired by davda54's PyTorch implementation.

Figure: a sharp minimum to which a ResNet trained with SGD converged, and a wide minimum to which the same ResNet trained with SAM converged.

Usage

Using SAM is easy in custom training loops:

...

from sam import SAM

model = YourModel()
base_optimizer = tf.keras.optimizers.SGD()  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(base_optimizer)

...

@tf.function
def train_step_SAM(images, labels):
    # first forward/backward pass: compute the gradient at the current weights;
    # first_step then moves the weights to the (locally) highest-loss point
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.first_step(gradients, model.trainable_variables)

    # second forward/backward pass: compute the gradient at the perturbed weights;
    # second_step restores the original weights and applies the base optimizer
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.second_step(gradients, model.trainable_variables)

...

for x, y in dataset:
    train_step_SAM(x, y)
  
...

If you want to use the Keras API:

...

from sam import sam_train_step

# override the train_step function of the keras model
class YourModel(tf.keras.Model):
    def train_step(self, data):
        return sam_train_step(self, data)

inputs = Input(...)
outputs = ...
model = YourModel(inputs, outputs)

model.compile(...)
model.fit(x_train, y_train, epochs=3)

...
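
Note that sam_train_step applies self.optimizer directly (see the example in the mixed-precision issue below), so with this approach you compile the model with a plain Keras optimizer rather than the SAM wrapper. A compile call might look like this (the optimizer, loss, and metrics here are placeholders):

    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )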

Documentation

SAM.__init__

Arguments:
base_optimizer (tf.keras.optimizers): underlying optimizer that performs the "sharpness-aware" update
rho (float, optional): size of the neighborhood used to compute the max loss (default: 0.05)
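
For example, assuming rho is passed as a keyword argument (learning rate and momentum are arbitrary placeholders):

    optimizer = SAM(tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9), rho=0.05)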


SAM.first_step

Performs the first optimization step that finds the weights with the highest loss in the local rho-neighborhood.

Arguments:
gradients: gradients computed by the first backward pass
trainable_parameters: model parameters to be trained
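
For intuition, here is a minimal sketch of what this step amounts to; self.rho and the stored e_ws list are illustrative names rather than the exact code in sam.py, and the 1e-12 epsilon matches the line quoted in the mixed-precision issue below:

    # sketch of first_step: move every variable from w to w + e(w), where
    # e(w) = rho * g / (||g|| + eps) points towards locally higher loss
    grad_norm = tf.linalg.global_norm(gradients)
    self.e_ws = []
    for grad, var in zip(gradients, trainable_parameters):
        e_w = grad * self.rho / (grad_norm + 1e-12)
        var.assign_add(e_w)
        self.e_ws.append(e_w)  # stored so that second_step can undo the perturbation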

SAM.second_step

Performs the second optimization step that updates the original weights with the gradient from the (locally) highest point in the loss landscape.

Arguments:
gradients: gradients computed by the second backward pass
trainable_parameters: model parameters to be trained
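
Again as a rough, illustrative sketch (the actual names in sam.py may differ):

    # sketch of second_step: undo the perturbation from first_step, then let the
    # base optimizer apply the gradient that was computed at the perturbed point
    for e_w, var in zip(self.e_ws, trainable_parameters):
        var.assign_sub(e_w)
    self.base_optimizer.apply_gradients(zip(gradients, trainable_parameters))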

simple-sam's People

Contributors

hululuzhu, jannoshh, simonbiggs


simple-sam's Issues

Replicating results

Hi,
I was just wondering, were you able to replicate the results in the original paper using this implementation?
Thanks

Too many values to unpack when including sample_weights

If sample_weights are included, I think that data has a len() of 3, so the unpacking should be x, y, sample_weight = data.
My question is: how do I properly include the sample_weight in the training step? I tested without the sample_weights and the initial loss comes out wrong.
Following the TensorFlow guide, I included the sample weight like this:
loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)
for the two calls to self.compiled_loss in the two separate with tf.GradientTape() as tape: blocks. Is this the correct implementation for both class_weights and sample_weights?
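
One possible way to handle this, sketched but not fully tested, mirrors the sam_train_step shape shown in the mixed-precision issue below and uses tf.keras.utils.unpack_x_y_sample_weight to cover both the 2-tuple and 3-tuple cases:

    def sam_train_step_with_sample_weight(self, data, rho=0.05, epsilon=1e-12):
        # data may be (x, y) or (x, y, sample_weight), depending on what is passed to fit()
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight,
                                      regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # first step: perturb the weights towards the locally highest loss
        e_ws = []
        grad_norm = tf.linalg.global_norm(gradients)
        for i in range(len(trainable_vars)):
            e_w = gradients[i] * rho / (grad_norm + epsilon)
            trainable_vars[i].assign_add(e_w)
            e_ws.append(e_w)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight,
                                      regularization_losses=self.losses)
        gradients = tape.gradient(loss, trainable_vars)

        # second step: restore the original weights and apply the SAM gradient
        for i in range(len(trainable_vars)):
            trainable_vars[i].assign_add(-e_ws[i])
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # forward the sample weights to the metrics as well
        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}

As far as I know, class_weight passed to fit() reaches train_step already converted into a sample_weight, so the same path should cover class_weights too.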

Customize fit method

Hi!
Thanks for the code, I was waiting for a tf.keras implementation of SAM. I wonder if you have thought about how to implement this by customizing the fit method (I prefer this over a "full" custom training loop so I can use callbacks easily).
I think that, in order for this to work, the SAM class would have to inherit from tf.keras.optimizers.Optimizer, but I am not sure how to make your code work in that case. Do you have any idea?

Epsilon value of 1e-12 too small for mixed precision

The epsilon value of 1e-12 used in the following line of the first_step and sam_train_step functions is too small and can cause NaN errors when training with mixed precision:
e_w = gradients[i] * self.rho / (grad_norm + 1e-12)

I recommend increasing the value to at least 1e-7 and also adding loss scaling to sam_train_step so that it supports loss scale optimizers. An example implementation is:

def sam_train_step(self, data, rho=0.05, epsilon=1e-7):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    x, y = data

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        scaled_loss = self.optimizer.get_scaled_loss(loss)

    # Compute gradients
    trainable_vars = self.trainable_variables
    scaled_gradients  = tape.gradient(scaled_loss, trainable_vars)
    gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

    # first step
    e_ws = []
    grad_norm = tf.linalg.global_norm(gradients)
    for i in range(len(trainable_vars)):
        e_w = gradients[i] * rho / (grad_norm + epsilon)
        trainable_vars[i].assign_add(e_w)
        e_ws.append(e_w)

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        scaled_loss = self.optimizer.get_scaled_loss(loss)
    trainable_vars = self.trainable_variables
    scaled_gradients  = tape.gradient(scaled_loss, trainable_vars)
    gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

    for i in range(len(trainable_vars)):
        trainable_vars[i].assign_add(-e_ws[i])
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}

I also recommend suggesting that users use a loss scale optimizer with a low initial scale, e.g. keras.mixed_precision.LossScaleOptimizer(optimizer, initial_scale=2 ** 2). I have not submitted this as a pull request because I have yet to fully experiment with various architectures and hyperparameters, but in my limited experimentation it has proven effective at preventing NaN errors for ResNet- and DenseNet-style architectures when using mixed precision.
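
For instance, something along these lines (a sketch; the learning rate, loss, and metrics are placeholders):

    from tensorflow import keras

    # run the model in mixed precision
    keras.mixed_precision.set_global_policy("mixed_float16")

    base_optimizer = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    # wrap it in a loss scale optimizer with a low initial scale so the first SAM
    # steps do not overflow; sam_train_step above scales/unscales through this optimizer
    optimizer = keras.mixed_precision.LossScaleOptimizer(base_optimizer, initial_scale=2 ** 2)

    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])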

Reproducing WRN-28-10 (SAM) for SVHN dataset

I am trying to reproduce the results for WRN-28-10 (SAM) trained on the 10-class SVHN classification dataset (percentage error 0.99) - https://paperswithcode.com/sota/image-classification-on-svhn

I'm able to train WRN-28-10 using https://github.com/hysts/pytorch_wrn (I modified the script to incorporate SAM).

I'm achieving a test accuracy of 93%. How can I replicate the SOTA percentage error of 0.99 for WRN-28-10 (SAM)? Which hyperparameters should I use?

Any help is appreciated!!

Verification

Have you conducted experiments to verify that this implementation can reproduce results similar to the original implementation?
Thanks
