Comments (16)
Here are my two cents on this issue.
When working with BatchNorm, there are two sets of variables to monitor.
The first set contains gamma (aka weight) and beta (aka bias), while the second set contains running_mean and running_var.
When this issue was raised, it was about running_mean and running_var only, without regard to gamma and beta. Accordingly, the proposed solution added insult to injury, and SAM no longer converges.
The initially proposed -- not working -- solution was to use bn.eval(). This freezes running_mean and running_var, but it also freezes beta and gamma. These are two learnable parameters that should keep updating and remain aligned with the other learnable parameters. When beta and gamma are frozen via bn.eval(), they diverge from the rest of the parameters. This divergence is minimal, with minimal impact, during the first iterations. Yet as the number of iterations increases, the divergence grows and the loss eventually diverges to nan.
Accordingly, I propose the following solution. By temporarily setting the momentum to zero, running_mean and running_var are effectively frozen, yet the learnable parameters beta and gamma are still learnable. The 2-step SAM optimizer then updates beta and gamma in the same direction as the rest of the parameters, and SAM no longer diverges to nan.
```python
def _disable_running_stats(self, m):
    if isinstance(m, nn.BatchNorm2d):
        m.backup_momentum = m.momentum
        m.momentum = 0

def _enable_running_stats(self, m):
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = m.backup_momentum
```
These methods can be called with model.apply(self._disable_running_stats) and model.apply(self._enable_running_stats).
One final note for the astute: these two operations (Op1 and Op2) probably introduce some floating-point precision issues.
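To see why zeroing the momentum freezes the running statistics, here is a minimal plain-Python sketch of the exponential-moving-average rule PyTorch uses for the BatchNorm buffers (the constants are made up for illustration):

```python
# PyTorch updates each BatchNorm running statistic as:
#   running = (1 - momentum) * running + momentum * batch_stat
def update_running_stat(running, batch_stat, momentum):
    return (1.0 - momentum) * running + momentum * batch_stat

running_mean = 0.5
# Clean (first) pass: the default momentum of 0.1 moves the stat normally.
running_mean = update_running_stat(running_mean, 2.0, momentum=0.1)  # -> 0.65
# Perturbed (second) pass with momentum = 0: the stat does not move at all,
# while gamma and beta still receive gradients as usual.
frozen = update_running_stat(running_mean, 2.0, momentum=0.0)        # -> 0.65
```

(Note that momentum=None means cumulative averaging in PyTorch, so 0 rather than None is the right sentinel here.)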
from sam.
I hit the same problem. When I save the running mean and var at the first pass and restore them at the second pass, the training accuracy is as normal as vanilla SGD, but the validation accuracy is almost 0.
from sam.
Thank you very much for sharing these bugfixes with us! Using momentum to bypass the running statistics is very clever :) I've pushed a new commit that should correct both issues.
from sam.
@pengbohua I want to revise my previous comment. According to [1], "Empirically, the degree of improvement negatively correlates with the level of inductive biases built into the architecture." Indeed, when I evaluate SAM with a compact architecture, SAM brings a marginal improvement, if any. Yet when I evaluate SAM with a huge architecture, SAM delivers significant improvements!
[1] When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations
from sam.
Hi Ming-Hsuan-Tu (@twmht),
Sorry for the late reply. As mentioned previously, batchnorm has two sets of variables:
Set 1: {running_mean and running_var}
Set 2: {beta and gamma}
Set 1 is updated during every forward pass. In contrast, set 2 is updated during every backward pass. Think of {beta and gamma} as {weights and biases} of a typical layer. Basically, {beta and gamma} are used passively during forward passes.
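A toy illustration of which pass touches which set (plain Python, invented for this reply; this is not the real PyTorch internals):

```python
class ToyBatchNorm:
    """Minimal stand-in for a BatchNorm layer's two sets of variables."""

    def __init__(self, momentum=0.1):
        self.running_mean = 0.0           # set 1: updated on forward
        self.gamma, self.beta = 1.0, 0.0  # set 2: updated on backward
        self.momentum = momentum

    def forward(self, batch_mean):
        # Set 1 moves here, as a side effect of the forward pass;
        # no gradients are involved.
        self.running_mean = ((1 - self.momentum) * self.running_mean
                             + self.momentum * batch_mean)

    def sgd_step(self, grad_gamma, grad_beta, lr=0.1):
        # Set 2 moves here, exactly like the weights/biases of any layer.
        self.gamma -= lr * grad_gamma
        self.beta -= lr * grad_beta

bn = ToyBatchNorm()
bn.forward(batch_mean=1.0)  # only running_mean changes
bn.sgd_step(0.5, 0.5)       # only gamma/beta change
```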
When using bn.eval(), the batchnorm layer enters inference mode. Accordingly, neither set 1 nor set 2 is updated!
Of course, this is undesired behavior, and it has major negative implications with SAM. If you try bn.eval(), your training will diverge, i.e., loss = nan. At a high level, this happens because the {weights and biases} of almost every layer are updated normally during the backward pass, while the {beta and gamma} of the batchnorm layers are not!
With a typical optimizer (e.g., SGD or Adam), this scenario won't lead to loss=nan. But with SAM, things are more entangled :)
I hope this clarifies things a bit.
from sam.
When using bn.eval(), the batchnorm layer enters the inference mode. Accordingly, neither set 1 nor set 2 is updated!
To the best of my knowledge, set 1 won't be updated, but set 2 would still be updated if you do a backward pass.
For example, most object detection frameworks call batchnorm.eval() during training (https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py#L657), mainly because of small batch sizes, but the weight and bias of batchnorm are still updated when calling backward.
If you want to freeze weight and bias, you have to set requires_grad to False explicitly (https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py#L623).
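This is easy to check directly; a small sketch, assuming PyTorch is installed:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
bn.eval()  # freezes the running stats only

saved_mean = bn.running_mean.clone()
out = bn(torch.randn(2, 4, 3, 3)).sum()
out.backward()

# Set 1 is untouched in eval mode...
assert torch.equal(bn.running_mean, saved_mean)
# ...but set 2 (weight/bias) still receives gradients.
assert bn.weight.grad is not None and bn.bias.grad is not None

# To truly freeze weight and bias, disable their gradients explicitly:
for p in bn.parameters():
    p.requires_grad_(False)
```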
from sam.
I also had problems with this logic.
I suggest removing it.
Without it, everything works fine, and indeed SAM (and especially adaptive SAM) works better than Adam.
from sam.
Hi, that's a weird behavior IMO. It disables updates of the running statistics during the second pass, which probably shouldn't change the convergence drastically. Anyway, from my experience, the improvement of this "fix" is only minor, so it's okay to not use it; especially if it doesn't work for your task :)
from sam.
Yeah. That makes sense. Thank you for your reply.
from sam.
I think the problem is that using the running mean and var for the second pass during training is incorrect. To solve the BN issue, we should save the running mean and var at the first pass and restore them at the second pass.
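A minimal sketch of that save/restore idea (the helper names are hypothetical, and PyTorch is assumed):

```python
import torch.nn as nn

def save_bn_stats(model):
    # Snapshot the running stats of every BN layer before the first pass.
    return {m: (m.running_mean.clone(), m.running_var.clone())
            for m in model.modules() if isinstance(m, nn.BatchNorm2d)}

def restore_bn_stats(saved):
    # Roll the stats back before the second, perturbed pass.
    for m, (mean, var) in saved.items():
        m.running_mean.copy_(mean)
        m.running_var.copy_(var)
```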
from sam.
@ahmdtaha Sounds good. Did you see an improvement in test accuracy with your two modifications?
from sam.
@pengbohua I didn't observe any improvement from handling the BatchNorm layers. That said, I didn't play much with hyper-parameter tuning. The increased training time -- introduced by SAM -- is a turn-off in my experimental setting.
from sam.
The initially proposed -- not working -- solution is to use bn.eval(). This freezes both the running_mean and running_var, but it also freezes beta and gamma
Why does it also freeze beta and gamma? beta and gamma would still be updated in the second pass too.
The training flag in BN only indicates whether to use the global mean and variance (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L145).
Please correct me if I'm mistaken.
from sam.
@twmht
You are right, but I am missing something here. I remember that bn.eval() led to a grad_norm explosion and, accordingly, loss = nan. It has been a while since I debugged this, and I can't remember the details. If you think my solution introduces a problem, please let me know.
from sam.
Hi, I tried this modification but my model still didn't converge.
Could you have a look at it? Thanks!
@davda54 @ahmdtaha
from sam.
Does this code converge without SAM? @msra-jqxu
from sam.