cvpr2023's Introduction

Experiments for CVPR 2023 Talk

Scaling PyTorch Model Training With Minimal Code Changes

In this short tutorial, I (Sebastian) will show you how to accelerate the training of LLMs and Vision Transformers with minimal code changes using open-source libraries.

cvpr2023's People

Contributors

rasbt

cvpr2023's Issues

When I resume training from a checkpoint, or load an existing weight file for model distillation, torch.load() raises an error.

The fabric parameters I use are as follows:
fabric = L.Fabric(accelerator='cuda', precision='bf16-mixed', devices=torch.cuda.device_count(), strategy='FSDP')
The model training and saving code is as follows:

for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
        epoch_start_time = time.time()
        epoch_loss = 0
        train_id = 1
        model_restoration.train()
        for i, data in enumerate(tqdm(train_loader), 0):

            iter_loss = 0
            # zero_grad
            for param in model_restoration.parameters():
                param.grad = None

            if opt.TRAINING.DISTILL:
                if opt.TRAINING.ACCELERATE:
                    input_ = data[1]
                    target_ori = data[0]
                else:
                    input_ = data[1].cuda()
                    target_ori = data[0].cuda()
                with torch.no_grad():
                    target = teacher_model(input_)
            else:
                if opt.TRAINING.ACCELERATE:
                    input_ = data[1]
                    target = data[0]
                else:
                    input_ = data[1].cuda()
                    target = data[0].cuda()
            out = model_restoration(input_)
            loss_psnr = (1 + criterion_psnr(out, target) / 100)
            loss = loss_psnr
            if opt.TRAINING.ACCELERATE:
                fabric.backward(loss)
            else:
                loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            iter_loss += loss_psnr.item()
            if i > 0 and i % print_freq == 0:
                print(f'\titer: {i}  loss: {epoch_loss:.4f}  iter_loss: {iter_loss:.4f}  '
                      f'iter_best psnr: {iter_best_psnr:.4f}  loss_psnr: {loss_psnr:.4f}  '
                      f'learning rate: {scheduler.get_last_lr()[0]:.6f}')
            if i > 0 and i % step_iter == 0: scheduler.step()
            if iter_loss <= iter_best_psnr: # and -iter_loss < 100:
                iter_best_psnr = iter_loss
                torch.save({'epoch': epoch, 
                            'state_dict': model_restoration.state_dict(),
                            'optimizer' : optimizer.state_dict()
                            }, os.path.join(model_dir, "model_iter_best.pth"))

When resuming training from a checkpoint, I found that the weight file saved during lightning.Fabric training was only 1/8 the size of the file saved by plain PyTorch training, and the file saved during lightning.Fabric training could not be loaded. The error message is as follows:

Traceback (most recent call last):
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/serialization.py", line 531, in _check_seekable
    f.seek(f.tell())
AttributeError: 'NoneType' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/project/Code/Skin/train_distill_accelerate.py", line 351, in <module>
    main()
  File "/project/Code/Skin/train_distill_accelerate.py", line 212, in main
    utils.load_checkpoint_multigpu(model_restoration, path_chk_rest)
  File "/project/Code/Skin/utils/model_utils.py", line 38, in load_checkpoint_multigpu
    checkpoint = torch.load(weights)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/serialization.py", line 986, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/serialization.py", line 440, in _open_file_like
    return _open_buffer_reader(name_or_buffer)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/serialization.py", line 425, in __init__
    _check_seekable(buffer)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/serialization.py", line 534, in _check_seekable
    raise_err_msg(["seek", "tell"], e)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/serialization.py", line 527, in raise_err_msg
    raise type(e)(msg)
AttributeError: 'NoneType' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision

And when I use a teacher model trained with lightning.Fabric to train the student model, loading the weights saved by lightning.Fabric fails as follows:

Traceback (most recent call last):
  File "/project/Code/Skin/train_distill_accelerate.py", line 351, in <module>
    main()
  File "/project/Code/Skin/train_distill_accelerate.py", line 129, in main
    utils.load_checkpoint_multigpu(teacher_model, weights)
  File "/project/Code/Skin/utils/model_utils.py", line 45, in load_checkpoint_multigpu
    model.load_state_dict(new_state_dict)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NAFNet:
	size mismatch for intro.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([32, 3, 3, 3]).
	size mismatch for intro.bias: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for ending.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([3, 32, 3, 3])......

The load_checkpoint_multigpu function used to load the weight file is as follows:

import torch
from collections import OrderedDict

def load_checkpoint_multigpu(model, weights):
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # drop the 'module.' prefix added by DataParallel/DistributedDataParallel
        name = k.replace('module.', '') if 'module' in k else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
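Independently of how the checkpoint was written, note that `k.replace('module.', '')` rewrites "module." anywhere in a key, so a layer whose name merely ends in "module" (for example a child called `submodule`) would also be mangled. A minimal sketch that only strips a leading prefix is shown below; `strip_key_prefix` is a hypothetical helper name. It does not fix the `torch.Size([0])` entries in the error above: those tensors are already empty inside the checkpoint file, so the cause lies in how the file was saved, not in how it is loaded.

```python
from collections import OrderedDict


def strip_key_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a *leading* prefix removed from each key.

    Unlike key.replace(prefix, ""), occurrences of the prefix in the middle of
    a key (e.g. "encoder.submodule.weight") are left untouched.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
```

In load_checkpoint_multigpu this would replace the loop: `model.load_state_dict(strip_key_prefix(checkpoint["state_dict"]))`.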

I would like to ask: how should model weights be saved when using lightning.Fabric() to accelerate training? How can weights trained with lightning.Fabric() be loaded when resuming or for distillation? What is the difference between the weight file saved by lightning.Fabric() and the one saved by the plain PyTorch code, and why is the former only 1/8 the size of the latter? It seems that only the model structure is saved, not the trained weight parameters. Looking forward to your reply, best wishes!

AssertionError: input must be floating point

Thanks for your great work! The project runs without errors at first, but a few minutes later I got an error:
Traceback (most recent call last):
  File "/data/ccm/Projects/LOCATE/train.py", line 139, in
    masks, logits, loss_proto, loss_con = model(exo, ego, aff_label, (epoch, args.warm_epoch))
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/lightning/fabric/wrappers.py", line 121, in forward
    output = self._forward_module(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs) # type: ignore[index]
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/ccm/Projects/LOCATE/models/locate.py", line 133, in forward
    kmeans.fit_predict(exo_aff_desc.contiguous())
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/fast_pytorch_kmeans/kmeans.py", line 163, in fit_predict
    assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point"
AssertionError: input must be floating point

Traceback (most recent call last):
  File "/data/ccm/Projects/LOCATE/train.py", line 139, in
    masks, logits, loss_proto, loss_con = model(exo, ego, aff_label, (epoch, args.warm_epoch))
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/lightning/fabric/wrappers.py", line 121, in forward
    output = self._forward_module(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs) # type: ignore[index]
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/ccm/Projects/LOCATE/models/locate.py", line 82, in forward
    _, exo_key, exo_attn = self.vit_model.get_last_key(exo)
  File "/data/ccm/Projects/LOCATE/models/dino/vision_transformer.py", line 284, in get_last_key
    x = blk(x)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/ccm/Projects/LOCATE/models/dino/vision_transformer.py", line 122, in forward
    y, attn = self.attn(self.norm1(x))
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wsc/anaconda3/envs/LOCATE/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/ccm/Projects/LOCATE/models/dino/vision_transformer.py", line 93, in forward
    attn = (q @ k.transpose(-2, -1)) * self.scale
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 22.00 MiB. GPU 2 has a total capacty of 23.70 GiB of which 5.62 MiB is free. Process 2009145 has 21.84 GiB memory in use. Including non-PyTorch memory, this process has 1.84 GiB memory in use. Of the allocated memory 414.50 MiB is allocated by PyTorch, and 27.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
I don't know what's wrong with this project.
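One possible cause, assuming the model runs under bf16 mixed precision like the 'bf16-mixed' Fabric configuration shown earlier on this page (the LOCATE training setup itself is not shown here): the assertion in fast_pytorch_kmeans accepts only `torch.half`, `torch.float` and `torch.double`, and `torch.bfloat16` is none of these. A hedged sketch that casts the input to float32 and disables autocast around the call follows; `kmeans_fit_predict_fp32` is a hypothetical wrapper for the `kmeans.fit_predict(exo_aff_desc.contiguous())` call in models/locate.py. The second traceback is a separate out-of-memory error on GPU 2, where another process (2009145) already holds 21.84 GiB.

```python
import torch

# dtypes accepted by the assertion shown in the fast_pytorch_kmeans traceback
KMEANS_DTYPES = (torch.half, torch.float, torch.double)


def kmeans_fit_predict_fp32(kmeans, x):
    """Run kmeans.fit_predict on a float32 copy of x, outside autocast."""
    # Disable autocast so ops inside fit_predict are not re-cast to bf16.
    with torch.autocast(device_type=x.device.type, enabled=False):
        if x.dtype not in KMEANS_DTYPES:
            x = x.float()  # e.g. bfloat16 -> float32
        return kmeans.fit_predict(x.contiguous())
```

The call site would then read `kmeans_fit_predict_fp32(kmeans, exo_aff_desc)`.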

I tried the modification method you mentioned, but the optimizer still reported an error:

> Thanks for the note and sorry about the hassle. I think that's because I had the current dev version installed, which has not been released yet. The new version allows you to combine the two lines
    model = fabric.setup_module(model)
    optimizer = fabric.setup_optimizers(optimizer)

into a single line

    model, optimizer = fabric.setup(model, optimizer)

But for the latest stable release, the 2 separate lines are necessary. I updated the code accordingly.

I tried the modification method you mentioned, but the optimizer still reported an error:

File "/project/Code/Skin/aa.py", line 21, in <module>
    main()
  File "/project/Code/Skin/aa.py", line 15, in main
    optimizer = fabric.setup_optimizers(optimizer)
  File "/opt/conda/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 289, in setup_optimizers
    optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
  File "/opt/conda/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 289, in <listcomp>
    optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
  File "/opt/conda/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/strategies/fsdp.py", line 213, in setup_optimizer
    raise ValueError(
ValueError: The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer after setting up the model.

My environment and related library versions are as follows:

python                   3.10
numpy                    1.26.0
scipy                    1.11.3
pandas                   2.1.1
watermark                2.4.3
lightning                2.0.9.post0
torch                    2.0.1
torchaudio               2.0.2
torchmetrics             1.2.0
torchvision              0.15.2
transformers             4.34.0
deepspeed                0.11.0

Originally posted by @Williamwsk in #3 (comment)

@Williamwsk I added an example and readme for loading the checkpoints with Fabric here: https://github.com/rasbt/cvpr2023/tree/main/08_saving-and-loading

I hope this helps!

Originally posted by @rasbt in #5 (comment)

Thank you very much for your reply. After making changes according to the code you gave me, I found that the code in the training function is executed multiple times (once per GPU): for example, the training accuracy is printed and the weights are saved as many times as there are GPUs. How should I change this? It also seems that the latest model is not automatically overwritten when calling fabric.save().

In addition, how can I convert a model weight file saved with fabric.save() into a weight file like one saved with torch.save()? Some subsequent steps need a .pth weight file for other work. Looking forward to your reply, thank you very much!
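If the file written by `fabric.save()` is a single regular checkpoint (whether it is a single file or a sharded directory depends on the strategy and the Lightning version), it can be read with `torch.load` and re-saved in the `{'epoch', 'state_dict', 'optimizer'}` layout used by the `torch.save` call in the training loop earlier on this page. This is a sketch assuming the model was stored under the key "model"; `fabric_ckpt_to_pth` and the keys are hypothetical and should match whatever was passed to `fabric.save`.

```python
import os
import tempfile

import torch


def fabric_ckpt_to_pth(src, dst, model_key="model"):
    """Re-save a single-file Fabric checkpoint in a plain torch.save layout."""
    ckpt = torch.load(src, map_location="cpu")
    converted = {"state_dict": ckpt[model_key]}
    # Carry over optional entries if they were saved alongside the model.
    for key in ("epoch", "optimizer"):
        if key in ckpt:
            converted[key] = ckpt[key]
    torch.save(converted, dst)
    return converted


def demo_conversion():
    src_model = torch.nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "fabric.ckpt")
        dst = os.path.join(tmp, "model.pth")
        # Rough stand-in for what fabric.save(src, {"model": model, "epoch": 5})
        # writes with a non-sharded strategy: a dict holding the model state dict.
        torch.save({"model": src_model.state_dict(), "epoch": 5}, src)
        fabric_ckpt_to_pth(src, dst)
        loaded = torch.load(dst, map_location="cpu")
    dst_model = torch.nn.Linear(4, 2)
    dst_model.load_state_dict(loaded["state_dict"])
    return loaded["epoch"], torch.equal(src_model.weight, dst_model.weight)
```

The resulting file has the same structure as model_iter_best.pth from the plain PyTorch loop, so load_checkpoint_multigpu can read its "state_dict" entry.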

Regarding the issue of setup(model, optimizer) in fabric FSDP

I followed your code, with an environment configuration identical to yours, but fabric.setup(model, optimizer) raises an error. I see that your code is written the same way. Have you ever encountered this problem? Looking forward to your reply!

The source code is as follows:
import torch
import torchvision as tv
import lightning as L

def main():
    fabric = L.Fabric(accelerator='cuda', precision='bf16-mixed', devices=torch.cuda.device_count(), strategy='fsdp')
    fabric.launch()

    model = tv.models.resnet18()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    model, optimizer = fabric.setup(model, optimizer)
    # model = fabric.setup_module(model)
    # optimizer = fabric.setup_optimizers(optimizer)
    dataset = tv.datasets.CIFAR10("data", download=True, transform=tv.transforms.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
    dataloader = fabric.setup_dataloaders(dataloader)

if __name__ == '__main__':
    main()

The error is reported as follows:
Traceback (most recent call last):
  File "/project/Code/Skin/aa.py", line 23, in <module>
    main()
  File "/project/Code/Skin/aa.py", line 14, in main
    model, optimizer = fabric.setup(model, optimizer)
  File "/opt/conda/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 198, in setup
    self._validate_setup(module, optimizers)
  File "/opt/conda/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 856, in _validate_setup
    raise RuntimeError(
RuntimeError: The `Fabric` requires the model and optimizer(s) to be set up separately. Create and set up the model first through `model = self.setup_model(model)`. Then create the optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`.
