
Comments (16)

ahmdtaha commented on May 29, 2024

Here are my two cents on this issue.
When working with BatchNorm, there are two sets of variables to monitor.
The first set contains gamma (aka weight) and beta (aka bias), while the second set contains running_mean and running_var.
When this issue was raised, it was only about running_mean and running_var, without regard to gamma and beta. Accordingly, the proposed solution added insult to injury and SAM no longer converges.

The initially proposed -- not working -- solution was to use bn.eval(). This freezes running_mean and running_var, but it also freezes beta and gamma. These are two learnable params that should keep updating and remain aligned with the other learnable params. When beta and gamma are frozen via bn.eval(), they diverge from the rest of the params. This divergence is minimal, with minimal impact, during the first iterations. Yet, as the number of iterations increases, the divergence grows and the loss eventually diverges to nan.

Accordingly, I propose the following solution. By temporarily setting momentum to zero, running_mean and running_var are effectively frozen, yet beta and gamma remain learnable. The two-step SAM optimizer then updates beta and gamma along the same direction as the rest of the params, and SAM no longer diverges to nan.

import torch.nn as nn

# Methods of the training class; they freeze only the BN running statistics.
def _disable_running_stats(self, m):
    if isinstance(m, nn.BatchNorm2d):
        m.backup_momentum = m.momentum
        m.momentum = 0  # with momentum = 0, running_mean/running_var stop updating

def _enable_running_stats(self, m):
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = m.backup_momentum  # restore the original momentum

These methods can be applied to every submodule via model.apply(self._disable_running_stats) and model.apply(self._enable_running_stats); a sketch of one full SAM iteration follows below.
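
For concreteness, here is a minimal sketch of one SAM iteration using those calls. It assumes the optimizer's two-step first_step/second_step interface; model, criterion, inputs, targets, and optimizer are placeholder names, and self refers to the training class holding the two methods above.

# Minimal sketch of one SAM iteration (placeholder names, not an exact training loop).

# First forward-backward pass: BN running statistics update as usual.
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.first_step(zero_grad=True)

# Second forward-backward pass at the perturbed weights:
# running_mean/running_var keep their first-pass values (momentum = 0),
# while beta/gamma still receive gradients and are updated by second_step.
model.apply(self._disable_running_stats)
criterion(model(inputs), targets).backward()
optimizer.second_step(zero_grad=True)
model.apply(self._enable_running_stats)  # restore the original BN momentum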

One final note for the astute: these two operations (Op1 and Op2) are probably introducing some precision issues.


shuo-ouyang commented on May 29, 2024

I hit the same problem. When I save the running mean and var during the first pass and restore them before the second pass, the training accuracy is as normal as with vanilla SGD, but the validation accuracy is almost 0.


davda54 commented on May 29, 2024

[quotes @ahmdtaha's comment above in full]

Thank you very much for sharing these bugfixes with us! Using momentum to bypass the running statistics is very clever :) I've pushed a new commit that should correct both issues.
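
For later readers, here is a hedged, module-level variant of the snippet above -- a sketch in the same spirit, not necessarily the exact committed code. It covers every BatchNorm flavor via _BatchNorm and guards against enabling before disabling.

from torch.nn.modules.batchnorm import _BatchNorm

def disable_running_stats(model):
    def _disable(module):
        if isinstance(module, _BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0  # freeze running_mean/running_var
    model.apply(_disable)

def enable_running_stats(model):
    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum  # restore the saved momentum
    model.apply(_enable)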


ahmdtaha commented on May 29, 2024

@pengbohua I want to revise my previous comment. According to [1], "Empirically, the degree of improvement negatively correlates with the level of inductive biases built into the architecture." Indeed, when I evaluate SAM with a compact architecture, SAM brings marginal improvement, if any. Yet, when I evaluate SAM with a huge architecture, SAM delivers significant improvements!

[1] When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations


ahmdtaha commented on May 29, 2024

Hi Ming-Hsuan-Tu (@twmht),

Sorry for the late reply. As mentioned previously, BatchNorm has two sets of variables:
Set 1: {running_mean, running_var}
Set 2: {beta, gamma}

Set 1 is updated during every forward pass. In contrast, set 2 is updated during every backward pass. Think of {beta, gamma} as the {weights, biases} of a typical layer; basically, {beta, gamma} are used passively during the forward pass.

When using bn.eval(), the BatchNorm layer enters inference mode. Accordingly, neither set 1 nor set 2 is updated!
Of course, this is undesired behavior, and it has major negative implications with SAM. If you try bn.eval(), your training will diverge, i.e., loss = nan. On a high level, this happens because the {weights, biases} of almost every layer are updated normally during the backward pass, while the {beta, gamma} of the BatchNorm layers are not!

With a typical optimizer (e.g., SGD or Adam), this scenario won't lead to loss = nan. But with SAM, things are more entangled :)
I hope this clarifies things a bit.


twmht commented on May 29, 2024

@ahmdtaha

When using bn.eval(), the BatchNorm layer enters inference mode. Accordingly, neither set 1 nor set 2 is updated!

To the best of my knowledge, set 1 won't be updated, but set 2 will still be updated if you do a backward pass.

For example, most object detection frameworks call batchnorm.eval() during training (https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py#L657), mainly because of the small batch size, but the weight and bias of BatchNorm are still updated when backward is called.

If you want to freeze weight and bias, you have to set requires_grad to False explicitly (https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py#L623).
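
A quick, self-contained check of this behavior (a sketch; not taken from any of the repositories discussed here):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.eval()                                    # inference mode: running statistics are frozen
before = bn.running_mean.clone()

out = bn(torch.randn(4, 3, 8, 8))            # normalizes with the stored running statistics
out.sum().backward()                         # gradients still flow to weight and bias

print(torch.equal(before, bn.running_mean))  # True: set 1 untouched
print(bn.weight.grad is not None)            # True: set 2 still trainable
# To freeze set 2 as well, set bn.weight.requires_grad = False (and likewise for bn.bias).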


mrT23 commented on May 29, 2024

I also had problems with this logic.
I suggest removing it.

Without it everything works fine, and indeed SAM (and especially adaptive-SAM) works better than Adam.


davda54 commented on May 29, 2024

Hi, that's weird behavior IMO. The fix only disables updates of the running statistics during the second pass, which probably shouldn't change the convergence drastically. Anyway, from my experience, the improvement from this "fix" is only minor, so it's okay not to use it, especially if it doesn't work for your task :)


pengbohua commented on May 29, 2024

Yeah. That makes sense. Thank you for your reply.


liujingcs commented on May 29, 2024

I think the problem is that using the running mean and var for the second pass during training is incorrect. To solve the BN issue, we should save the running mean and var during the first pass and restore them for the second pass, e.g. with helpers like the ones sketched below.
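
A minimal sketch of that save-and-restore idea, assuming standard PyTorch BN buffers (the names save_bn_stats/restore_bn_stats are illustrative; note that shuo-ouyang reports poor validation accuracy with this approach above):

from torch.nn.modules.batchnorm import _BatchNorm

def save_bn_stats(model):
    # Snapshot running_mean/running_var of every BN layer before the first pass.
    return {name: (m.running_mean.clone(), m.running_var.clone())
            for name, m in model.named_modules() if isinstance(m, _BatchNorm)}

def restore_bn_stats(model, stats):
    # Copy the snapshot back so the second pass reuses the first-pass statistics.
    for name, m in model.named_modules():
        if name in stats:
            m.running_mean.copy_(stats[name][0])
            m.running_var.copy_(stats[name][1])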


pengbohua commented on May 29, 2024

@ahmdtaha Sounds good. Did you see an improvement in test accuracy with your two modifications?


ahmdtaha commented on May 29, 2024

@pengbohua I didn't observe any improvement from handling the BatchNorm layers. Yet, I didn't play much with hyper-parameter tuning. The increased training time -- introduced by SAM -- is a turn-off in my experiments' setting.


twmht commented on May 29, 2024

@ahmdtaha

The initially proposed -- not working -- solution was to use bn.eval(). This freezes running_mean and running_var, but it also freezes beta and gamma.

Why does it also freeze beta and gamma? beta and gamma would still be updated in the second pass too.

The training flag in BN only indicates whether to use the global mean and variance (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L145).

Please correct me if I am mistaken.


ahmdtaha commented on May 29, 2024

@twmht
You are right, but I am missing something here. I remember that bn.eval() led to a grad_norm explosion and, accordingly, loss = nan. It has been a while since I debugged this, and I can't remember the details. If you think my solution introduces a problem, please let me know.


msra-jqxu commented on May 29, 2024

[quotes @ahmdtaha's fix and @davda54's reply above in full]

Hi, I tried this modification but my model still didn't converge, as shown in the screenshot below:
(screenshot omitted)

Could you have a look at it? Thanks!
@davda54 @ahmdtaha


ahmdtaha commented on May 29, 2024

Does this code converge without SAM? @msra-jqxu

