Comments (5)

ludeksvoboda commented on September 21, 2024

I fixed the optimizer mistake and now it's working. Sorry for the inconvenience.

from sparseml.

ludeksvoboda commented on September 21, 2024

Thank you for your reply, here is my code:

train_batch_size = 64
val_batch_size = 1

shape = (224, 224)

train_loader = DataLoader(
    train_dset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=6)

lr = 0.0005

model = smp.Unet('efficientnet-b0', activation=None, encoder_weights=None)
model = model.cuda()

optim = torch.optim.Adam(model.parameters(), lr=lr)

epochs = 50
no_batches_val = math.ceil(val_size / val_batch_size)
loss = BCEDiceLoss()
no_batches_train = math.ceil(train_size / train_batch_size)

checkpoint = torch.load('efbn0_224_sparsify_test(14).tar')
model.load_state_dict(checkpoint['model_state_dict'])

manager = ScheduledModifierManager.from_yaml('sparsify_recipe.yaml')
optimizer = manager.modify(model, optim, steps_per_epoch=len(train_loader))
train_loss_history = []
train_iou_history = []
for epoch in tqdm(range(manager.max_epochs)):

    model = model.cuda().train()

    for (imgs, labels, _) in tqdm(train_loader):
        optim.zero_grad()

        imgs, labels = imgs.cuda(), labels.cuda()

        out = model(imgs)
        batch_train_loss = loss(out, labels)

        batch_train_loss.backward()
        optim.step()
    
manager.finalize(model)

torch.save(
    {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    },
    'efbn0_224_sparse_final(' + str(epoch) + ').tar',
)

jeanniefinks commented on September 21, 2024

Hi @Tramtadama
Could you share your model definition and show how it is integrated into the training code? Thanks!
Jeannie / Neural Magic

Satrat commented on September 21, 2024

Hi @Tramtadama, I'll take a look into this and see if I can reproduce. A few questions:

  • Could you clarify what dataset you are using? I see you're passing train_dset to a DataLoader but your code doesn't include the definition
  • Could you also include the definition for BCEDiceLoss()?
  • Is this example supposed to be for image segmentation or classification? The model you're loading seems to be for segmentation, but you're calculating loss on labels

ludeksvoboda commented on September 21, 2024

EDIT: I now see an obvious error in my code: I am calling .step() on "optim", the original optimizer, instead of on the modified optimizer returned by manager.modify(). I will run the corrected code and see if it works. If it does, sorry for wasting your time with my mistake.
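
To make the mistake described above concrete, here is a toy sketch in plain Python. ToyOptimizer and ToyScheduleWrapper are hypothetical stand-ins, not SparseML classes; the sketch only assumes that the object returned by manager.modify() wraps the original optimizer and advances the recipe schedule from its own step() call. Under that assumption, stepping the original optimizer still updates the weights but never advances the schedule.

```python
class ToyOptimizer:
    """Hypothetical stand-in for a plain optimizer; counts weight updates."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


class ToyScheduleWrapper:
    """Hypothetical wrapper: forwards step() and also advances a schedule."""

    def __init__(self, inner):
        self.inner = inner
        self.schedule_updates = 0

    def step(self):
        self.inner.step()           # the usual weight update
        self.schedule_updates += 1  # the extra bookkeeping that is otherwise skipped


def run(num_batches, use_wrapper):
    optim = ToyOptimizer()
    optimizer = ToyScheduleWrapper(optim)  # plays the role of manager.modify(...)
    stepper = optimizer if use_wrapper else optim
    for _ in range(num_batches):
        stepper.step()
    return optim.steps, optimizer.schedule_updates


buggy = run(5, use_wrapper=False)  # steps the original optimizer
fixed = run(5, use_wrapper=True)   # steps the wrapped optimizer
print(buggy)  # (5, 0): weights updated, schedule never advanced
print(fixed)  # (5, 5): weights updated and schedule advanced
```

In the training loop from this thread, the analogous change would be to call zero_grad() and step() on "optimizer" rather than on "optim".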

Hi @Satrat, thank you for your efforts. The task is binary segmentation. The "label" is a 224x224 image mask of 0s and 1s.
Here is the definition of the loss function:

import torch
from torch import nn

def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
    """
    Args:
        pr (torch.Tensor): A list of predicted elements
        gt (torch.Tensor):  A list of elements that are to be predicted
        beta (float): positive constant
        eps (float): epsilon to avoid zero division
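        threshold: threshold for outputs binarization
        activation (str or None): activation applied to pr before scoring:
            'sigmoid', 'softmax2d', or None / 'none' for no activation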
    Returns:
        float: F score
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = torch.nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = torch.nn.Softmax2d()
    else:
        raise NotImplementedError(
            "Activation implemented for sigmoid and softmax2d"
        )

    pr = activation_fn(pr)

    if threshold is not None:
        pr = (pr > threshold).float()


    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp

    score = ((1 + beta ** 2) * tp + eps) \
            / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps)

    return score

class DiceLoss(nn.Module):
    __name__ = 'dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid'):
        super().__init__()
        self.activation = activation
        self.eps = eps

    def forward(self, y_pr, y_gt):
        return 1 - f_score(y_pr, y_gt, beta=1., eps=self.eps, threshold=None, activation=self.activation)

class BCEDiceLoss(DiceLoss):
    __name__ = 'bce_dice_loss'

    def __init__(self, eps=1e-7, activation='none'):
        super().__init__(eps, activation)
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')

    def forward(self, y_pr, y_gt):
        dice = super().forward(y_pr, y_gt)
        bce = self.bce(y_pr, y_gt)
        return dice + bce

I hope this helps with the reproduction; if you need anything more, please tell me. The dataset is private, so I am not sure whether providing its definition would help. Each sample is an image of a car with a license plate and a corresponding binary mask in which the license plate area is 1 and everything else is 0. There are also images with no license plate, for which the mask is all 0s.
Thank you!
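
For readers who want to check the score arithmetic by hand, here is a small pure-Python sketch of the same F-beta formula used in f_score above, applied to flat lists instead of tensors. The helper name f_beta and the sample values are made up for illustration, and no activation or thresholding is applied. The second call shows what the formula gives for an all-zero mask with an all-zero prediction, the "no license plate" case mentioned above.

```python
def f_beta(pred, target, beta=1.0, eps=1e-7):
    """F-beta score on flat lists of values in [0, 1] (illustrative helper)."""
    tp = sum(p * t for p, t in zip(pred, target))  # true positives
    fp = sum(pred) - tp                            # false positives
    fn = sum(target) - tp                          # false negatives
    b2 = beta ** 2
    return ((1 + b2) * tp + eps) / ((1 + b2) * tp + b2 * fn + fp + eps)


# One true positive, one false positive, one false negative.
pred = [1.0, 1.0, 0.0, 0.0]
target = [1.0, 0.0, 1.0, 0.0]
score = f_beta(pred, target)
dice_loss = 1 - score
print(round(score, 4), round(dice_loss, 4))  # 0.5 0.5

# Empty mask and empty prediction: only eps remains in both terms.
empty = f_beta([0.0] * 4, [0.0] * 4)
print(empty)  # 1.0
```

With beta=1 this is the Dice coefficient, so DiceLoss corresponds to 1 minus this value.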
