
Comments (17)

alexriedel1 commented on May 30, 2024

Hey, I just implemented AMP in this way and it seems to be working:

# first forward-backward pass
with torch.cuda.amp.autocast():
    preds_first = model(images)
    loss = criterion(preds_first, labels)  # use this loss for any training statistics

loss.mean().backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass
with torch.cuda.amp.autocast():
    preds_second = model(images)
    loss_second = criterion(preds_second, labels)

loss_second.mean().backward()
optimizer.second_step(zero_grad=True)


ahmdtaha commented on May 30, 2024

Here are my two cents on this issue.

TL;DR: use the following code and be ready to momentarily revert to regular single-step optimization.
I made the following changes to SAM inside sam.py:

    @torch.no_grad()
    def first_step(self, zero_grad=False, mixed_precision=False):
        with autocast() if mixed_precision else do_nothing_context_mgr():
            grad_norm = self._grad_norm()
            for group in self.param_groups:
                scale = group["rho"] / (grad_norm + 1e-12)

                for p in group["params"]:
                    if p.grad is None:
                        continue
                    self.state[p]["old_p"] = p.data.clone()
                    e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                    p.add_(e_w)  # climb to the local maximum "w + e(w)"

            if zero_grad:
                self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False, mixed_precision=False):
        with autocast() if mixed_precision else do_nothing_context_mgr():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

    @torch.no_grad()
    def step(self, closure=None):
        self.base_optimizer.step(closure)
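
For reference, do_nothing_context_mgr is not defined in the snippet above; a minimal definition (essentially contextlib.nullcontext, an assumption about the missing helper) could look like this:

import contextlib

def do_nothing_context_mgr():
    # No-op context manager used when mixed_precision=False;
    # equivalent to contextlib.nullcontext().
    return contextlib.nullcontext()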

Using this PyTorch tutorial, the proposed solution goes as follows:

def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        enable_running_stats(model)
        # First forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first step adds the gradient
        # to the weights directly, so the gradient must be unscaled; (2) unscale_ checks if any
        # gradient is inf and updates optimizer_state["found_inf_per_device"] accordingly.
        # We use optimizer_state["found_inf_per_device"] to decide whether to apply
        # SAM's first step or not.
        first_step_scaler.unscale_(optimizer)

        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check if any gradients are inf/nan
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # if the gradient is valid, apply sam_first_step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
            sam_first_step_applied = True
        else:
            # if the gradient is invalid, skip sam and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradients). This update step
        # resets optimizer_state["found_inf_per_device"], so it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on the update() operation,
        # because update() leverages optimizer_state["found_inf_per_device"].
        first_step_scaler.update()

        disable_running_stats(model)
        # Second forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If sam_first_step was applied, apply the 2nd step
            optimizer.second_step(mixed_precision=True)

        second_step_scaler.step(optimizer)
        second_step_scaler.update()

where

base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
first_step_scaler = GradScaler(init_scale=2 ** 8)  # a small init_scale acts as a warmup
second_step_scaler = GradScaler(init_scale=2 ** 8)  # a small init_scale acts as a warmup
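
The enable_running_stats / disable_running_stats helpers used in train() come from the SAM example code; a sketch of what they typically do (temporarily zero BatchNorm momentum so running statistics are only updated during the first forward pass) is shown below for completeness:

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

def disable_running_stats(model: nn.Module):
    # Freeze BatchNorm running statistics before the second forward pass.
    def _disable(module):
        if isinstance(module, _BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0
    model.apply(_disable)

def enable_running_stats(model: nn.Module):
    # Restore BatchNorm momentum before the first forward pass.
    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
    model.apply(_enable)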

How is this tested?

(1) A lot of debugging to make sure the code does what it is supposed to do; (2) training my model twice, in full and mixed precision, and verifying that both loss curves are similar -- of course, not identical.

What is the main catch?

I found that the network produces NaN predictions during inference while not crashing during training (forward and backward). While the network $f_\theta$ has finite parameters $\theta$, it produces NaN for some -- not all -- inputs. When a network reaches this state (shown in the next figure), SAM's first step (gradient-ascent) always generates a NaN/inf gradient, which signals instability. Then, of course, SAM's second step also generates a NaN/inf gradient. This instability is never observed explicitly during training because PyTorch's GradScaler skips the gradient-descent step whenever the gradient is NaN. Accordingly, the network's parameters $\theta$ remain intact despite multiple backpropagation steps.

[Figure: Mixed-Precision-SAM]

To get out of this unstable state, the proposed solution momentarily reverts to regular single-step stochastic gradient-descent. A gradient-descent step is more likely to have a valid -- non-NaN -- gradient than a gradient-ascent step. This pushes the network's parameters outside the unstable state. It is worth noting that loss surfaces are high-dimensional, i.e., my 2D drawing is for illustration purposes only.
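
A quick diagnostic along these lines (a sketch, assuming model, images, torch, and autocast are in scope; not part of the proposed solution):

with torch.no_grad():
    # Parameters can all be finite while the forward pass still produces NaN.
    params_finite = all(torch.isfinite(p).all() for p in model.parameters())
    with autocast():
        preds = model(images)
    nan_preds = torch.isnan(preds).any().item()
print(f"parameters finite: {params_finite}, NaN predictions: {nan_preds}")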

Why does the network enter this state in the first place?

I don't know. Yet, the network enters this state at an early training stage, which suggests poor initialization.

One thing I don't like about the proposed solution is that it is verbose. I wish someone would propose a more concise one.

I found the following resources helpful while investigating this issue.
[1] https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
[2] https://pytorch.org/docs/stable/notes/amp_examples.html


alexriedel1 commented on May 30, 2024

@milliema Yes, that's absolutely explainable: SAM needs two backward passes through the network instead of one with plain SGD, so it should take roughly double the time to train.


rohitsingh02 commented on May 30, 2024

The same thing is happening to me; I'm unable to use it with automatic mixed precision (PyTorch).


milliema commented on May 30, 2024

Hey, I just implemented AMP in this way and it seems to be working:

# first forward-backward pass
with torch.cuda.amp.autocast():
    preds_first = model(images)
    loss = criterion(preds_first, labels)  # use this loss for any training statistics

loss.mean().backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass
with torch.cuda.amp.autocast():
    preds_second = model(images)
    loss_second = criterion(preds_second, labels)

loss_second.mean().backward()
optimizer.second_step(zero_grad=True)

Thanks for the reply!
I'm using apex fp16 and it's a little different from torch.amp.
In my case, I only included the base optimizer in the apex initialization.

model, optimizer.base_optimizer = amp.initialize(model, optimizer.base_optimizer, opt_level="O1")

As for the backward passes, I keep the 1st step the same as before and only use scaled_loss for the 2nd backward.

loss = cal_loss(xx)
loss.backward()
optimizer.first_step(zero_grad=True)
loss = cal_loss(xx)
with amp.scale_loss(loss, optimizer.base_optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.second_step(zero_grad=True)

It works, but I'm not sure whether it's the best solution. If I use the scaled loss for the 1st backward, a NaN loss always occurs.


alexriedel1 commented on May 30, 2024

Is it also possible to initialize the full optimizer?
Are your model forward outputs fp16 now?


milliema commented on May 30, 2024

@alexriedel1 I've tried to initialize amp with the full optimizer, but it doesn't work.
Amp should affect the step function within the optimizer. However, SAM's optimizer doesn't use step but uses first_step/second_step instead, so I guess it's better to initialize the base optimizer.
I didn't check the forward output. The training speed is almost half that of regular training, e.g. 1000 img/s for regular training and 500 img/s for SAM. Is it the same for you?


stale commented on May 30, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.


jeongHwarr commented on May 30, 2024

@alexriedel1 Hello. I found your comment while looking for a way to apply AMP to the SAM optimizer. In the usual AMP method, I know that the loss is scaled before the backward pass and the gradients are unscaled later, for example scaler.scale(loss).backward().

Why did you do the backward pass using the mean of the loss? Is there any problem with this? Is it okay not to use a scaler?


alexriedel1 commented on May 30, 2024

@alexriedel1 Hello. I found your comment while looking for a way to apply AMP to the SAM optimizer. In the usual AMP method, I know that the loss is scaled before the backward pass and the gradients are unscaled later, for example scaler.scale(loss).backward().

Why did you do the backward pass using the mean of the loss? Is there any problem with this? Is it okay not to use a scaler?

I didn't fully implement the AMP method as proposed. I think using the scaler will be no problem.

Reducing the loss to its mean just depends on your loss function. For example, PyTorch's BCELoss is already implemented with mean reduction by default.
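
A small illustration of the reduction point (a standalone sketch, not code from this thread): with the default reduction="mean" the loss is already a scalar, so an extra .mean() is a no-op, while reduction="none" requires it before backward.

import torch
import torch.nn as nn

logits = torch.randn(4, 1)
targets = torch.randint(0, 2, (4, 1)).float()

loss_mean = nn.BCEWithLogitsLoss()(logits, targets)                  # scalar (mean reduction by default)
loss_none = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)  # per-element losses, shape (4, 1)

assert loss_mean.ndim == 0
assert torch.allclose(loss_none.mean(), loss_mean)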


jeongHwarr commented on May 30, 2024

@alexriedel1 Ok, I got it. I think the problem when applying AMP to SAM is that I cannot use both scaler.step(optimizer) and optimizer.first_step(zero_grad=True). The gradient needs to be unscaled when using AMP.
This is done through scaler.step(optimizer), but when I use it, I cannot call optimizer.first_step(zero_grad=True).


maxmatical commented on May 30, 2024

Yes, the original solution does not unscale the gradients, which would lead to the scaling factor interfering with the learning rate.

If you take a look at scaler.step(optimizer), under the hood it is doing

  1. calling unscale_(optimizer), if it hasn't been called already
  2. calling optimizer.step() only if no inf/NaN gradients were found

In theory you should be able to run something similar to the example here by doing the following during training:

# first pass
with torch.cuda.amp.autocast():
    out = model(input)
    loss = criterion(out, labels)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
optimizer.first_step(zero_grad=True)
scaler.update()

# 2nd pass
with torch.cuda.amp.autocast():
    out_2 = model(input)
    loss_2 = criterion(out_2, labels)

scaler.scale(loss_2).backward()
scaler.unscale_(optimizer)
optimizer.second_step(zero_grad=True)
scaler.update()

However, since you're not calling scaler.step(optimizer), you run the risk of applying updates with inf/NaN gradients (unless the SAM steps already take care of this).
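
One hedged way to add such a guard without the two-scaler setup from earlier in this thread (a sketch, not the optimizer's API): after unscale_, check the gradients for non-finite values and skip the SAM perturbation for that batch.

scaler.scale(loss).backward()
scaler.unscale_(optimizer)

# Skip SAM's first step if any gradient is non-finite after unscaling.
grads_finite = all(
    torch.isfinite(p.grad).all()
    for group in optimizer.param_groups
    for p in group["params"]
    if p.grad is not None
)
if grads_finite:
    optimizer.first_step(zero_grad=True)
else:
    optimizer.zero_grad()  # fall back: drop this batch's perturbation
scaler.update()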


stale commented on May 30, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.


twmht commented on May 30, 2024

@maxmatical

Is model.no_sync() important in the first backward pass? It seems that you don't have that.
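
For context, under DistributedDataParallel the first backward pass is often wrapped in no_sync() so gradients are only all-reduced once per iteration. A minimal sketch following the pattern above (assuming a DDP-wrapped model, a criterion, and a single GradScaler; not code from this thread):

# First forward-backward pass: skip DDP gradient synchronization.
with model.no_sync():
    with torch.cuda.amp.autocast():
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()
scaler.unscale_(optimizer)
optimizer.first_step(zero_grad=True)
scaler.update()  # resets the scaler state before the next unscale_

# Second forward-backward pass: gradients are synchronized here.
with torch.cuda.amp.autocast():
    loss_2 = criterion(model(images), labels)
scaler.scale(loss_2).backward()
scaler.unscale_(optimizer)
optimizer.second_step(zero_grad=True)
scaler.update()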


mengjingyouling commented on May 30, 2024

Hey, I just implemented AMP in this way and it seems to be working:

# first forward-backward pass
with torch.cuda.amp.autocast():
    preds_first = model(images)
    loss = criterion(preds_first, labels)  # use this loss for any training statistics

loss.mean().backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass
with torch.cuda.amp.autocast():
    preds_second = model(images)
    loss_second = criterion(preds_second, labels)

loss_second.mean().backward()
optimizer.second_step(zero_grad=True)

Can it work well?


alibalapour commented on May 30, 2024

@ahmdtaha
Do you have any idea about how to implement gradient accumulation in your code?


rtxbae commented on May 30, 2024

@alibalapour @ahmdtaha Have you found out how to implement gradient accumulation in the code?

