Comments (5)
Fixed the optimizer mistake; now it's working. Sorry for the inconvenience.
from sparseml.
Thank you for your reply. Here is my code:
```python
import math

import segmentation_models_pytorch as smp
import torch
from sparseml.pytorch.optim import ScheduledModifierManager
from torch.utils.data import DataLoader
from tqdm import tqdm

# train_dset, train_size and val_size come from the (private) dataset code;
# BCEDiceLoss is the loss posted further down in this thread.
train_batch_size = 64
val_batch_size = 1
shape = (224, 224)

train_loader = DataLoader(
    train_dset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=6)

lr = 0.0005
model = smp.Unet('efficientnet-b0', activation=None, encoder_weights=None)
model = model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=lr)
epochs = 50
no_batches_val = math.ceil(val_size / val_batch_size)
loss = BCEDiceLoss()
no_batches_train = math.ceil(train_size / train_batch_size)

# Resume from a dense checkpoint, then apply the sparsification recipe.
checkpoint = torch.load('efbn0_224_sparsify_test(14).tar')
model.load_state_dict(checkpoint['model_state_dict'])
manager = ScheduledModifierManager.from_yaml('sparsify_recipe.yaml')
optimizer = manager.modify(model, optim, steps_per_epoch=len(train_loader))

train_loss_history = []
train_iou_history = []
for epoch in tqdm(range(manager.max_epochs)):
    model = model.cuda().train()
    for (imgs, labels, _) in tqdm(train_loader):
        # NOTE: zeroes and steps `optim` (the original Adam), not the
        # `optimizer` returned by manager.modify(); see the EDIT below.
        optim.zero_grad()
        imgs, labels = imgs.cuda(), labels.cuda()
        out = model(imgs)
        batch_train_loss = loss(out, labels)
        batch_train_loss.backward()
        optim.step()

manager.finalize(model)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
}, 'efbn0_224_sparse_final(' + str(epoch) + ').tar')
```
Hi @Tramtadama
Could you share your model definition and show what the integration into your training code looks like? Thanks!
Jeannie / Neural Magic
Hi @Tramtadama, I'll take a look into this and see if I can reproduce it. A few questions:
- Could you clarify what dataset you are using? I see you're passing `train_dset` to a `DataLoader`, but your code doesn't include its definition.
- Could you also include the definition of `BCEDiceLoss()`?
- Is this example meant for image segmentation or classification? The model you're loading seems to be for segmentation, but you're calculating the loss on `labels`.
EDIT: I now see an obvious error in my code. I am calling `.step()` on `optim`, the original optimizer, instead of on the modified optimizer returned by `manager.modify()`. I will run the corrected code and see if it works. If it does, sorry for wasting your time with my mistake.
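The EDIT above comes down to where the recipe's logic is hooked in. As a toy model (not SparseML's actual implementation), assume the object returned by `manager.modify()` wraps the original optimizer and only advances the recipe when its own `step()` is called. `ToyOptimizer`, `ToyScheduledOptimizer` and `run` are hypothetical names used only for this sketch. Stepping the inner optimizer directly still updates the weights, but the schedule never moves:

```python
class ToyOptimizer:
    """Stand-in for torch.optim.Adam: only counts parameter updates."""

    def __init__(self):
        self.updates = 0

    def zero_grad(self):
        pass

    def step(self):
        self.updates += 1


class ToyScheduledOptimizer:
    """Hypothetical stand-in for the object manager.modify() returns.

    It wraps the original optimizer. Every step() first runs the scheduled
    work (here just a counter) and then delegates to the wrapped optimizer.
    """

    def __init__(self, inner, steps_per_epoch):
        self.inner = inner
        self.steps_per_epoch = steps_per_epoch
        self.modifier_updates = 0  # times the schedule got a chance to act

    def zero_grad(self):
        self.inner.zero_grad()

    def step(self):
        self.modifier_updates += 1  # e.g. advance/apply a pruning schedule
        self.inner.step()

    @property
    def epoch(self):
        # Progress as seen by the schedule, derived only from its own steps.
        return self.modifier_updates / self.steps_per_epoch


def run(stepped, n_steps):
    """Mimic the inner training loop, stepping whichever object is passed."""
    for _ in range(n_steps):
        stepped.zero_grad()
        # ... forward pass and loss.backward() would go here ...
        stepped.step()


# Buggy version: the original optimizer is stepped directly.
optim = ToyOptimizer()
optimizer = ToyScheduledOptimizer(optim, steps_per_epoch=10)
run(optim, 20)
print(optim.updates, optimizer.modifier_updates, optimizer.epoch)  # 20 0 0.0

# Fixed version: the wrapper is stepped, so the schedule advances too.
optim = ToyOptimizer()
optimizer = ToyScheduledOptimizer(optim, steps_per_epoch=10)
run(optimizer, 20)
print(optim.updates, optimizer.modifier_updates, optimizer.epoch)  # 20 20 2.0
```

In the posted loop, the corresponding fix is to call `optimizer.zero_grad()` and `optimizer.step()` in place of the `optim` versions.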
Hi @Satrat, thank you for your efforts. The task is binary segmentation. The "label" is a 224x224 image mask of 0s and 1s.
Here is the definition of the loss function:
```python
import torch
from torch import nn


def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
    """
    Args:
        pr (torch.Tensor): A list of predicted elements
        gt (torch.Tensor): A list of elements that are to be predicted
        beta (float): positive constant
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
        activation (str or None): 'sigmoid', 'softmax2d', or None / 'none'
    Returns:
        torch.Tensor: scalar F score
    """
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = torch.nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = torch.nn.Softmax2d()
    else:
        raise NotImplementedError(
            "Activation implemented for sigmoid and softmax2d"
        )

    pr = activation_fn(pr)

    if threshold is not None:
        pr = (pr > threshold).float()

    # Soft counts, summed over every element of the tensor (whole batch).
    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp

    score = ((1 + beta ** 2) * tp + eps) \
        / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps)
    return score


class DiceLoss(nn.Module):
    __name__ = 'dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid'):
        super().__init__()
        self.activation = activation
        self.eps = eps

    def forward(self, y_pr, y_gt):
        return 1 - f_score(y_pr, y_gt, beta=1., eps=self.eps,
                           threshold=None, activation=self.activation)


class BCEDiceLoss(DiceLoss):
    __name__ = 'bce_dice_loss'

    def __init__(self, eps=1e-7, activation='none'):
        super().__init__(eps, activation)
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')

    def forward(self, y_pr, y_gt):
        # Dice term on y_pr after self.activation, plus BCE on the raw logits.
        dice = super().forward(y_pr, y_gt)
        bce = self.bce(y_pr, y_gt)
        return dice + bce
```
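For reference, the quantity that `f_score` computes can be written out as follows. Here $p_i$ is the prediction after the chosen activation and $g_i$ is the mask value for element $i$. The sums run over every element of the tensor (batch, channel and spatial dimensions), because `torch.sum` is called without a `dim`:

```latex
TP = \sum_i g_i\,p_i, \qquad FP = \sum_i p_i - TP, \qquad FN = \sum_i g_i - TP

F_\beta = \frac{(1+\beta^2)\,TP + \varepsilon}{(1+\beta^2)\,TP + \beta^2\,FN + FP + \varepsilon}

\mathcal{L}_{\text{BCEDice}} = \bigl(1 - F_1\bigr) + \mathcal{L}_{\text{BCE}}
```

With $\beta = 1$, this is a soft Dice coefficient, $2\,TP / (2\,TP + FN + FP)$ up to $\varepsilon$. Because `BCEDiceLoss` defaults to `activation='none'`, its Dice term uses the raw logits as $p_i$, while `nn.BCEWithLogitsLoss` applies the sigmoid internally.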
I hope this helps with the reproduction; if you need anything else, please let me know. The dataset is private, so I am not sure whether providing its definition would help. Each sample is an image of a car with a license plate and a corresponding binary mask in which the license-plate area is 1s and everything else is 0s. There are also images without a license plate, whose masks are all 0s.
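To see how the `eps` term in the posted `f_score` behaves on all-zero masks like these, here is a minimal pure-Python mirror of it for flat lists of pixel values. `f_score_flat` is a hypothetical helper, not part of the thread's code, and it skips thresholding and `softmax2d`. When both the mask and the prediction are all zeros, `eps / eps` gives a perfect score. Any nonzero predicted mass on an empty mask, however, drives the score toward 0, which means a Dice loss near 1:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def f_score_flat(pr, gt, beta=1.0, eps=1e-7, activation="sigmoid"):
    """Hypothetical pure-Python mirror of f_score for flat lists of pixels."""
    if activation == "sigmoid":
        pr = [sigmoid(p) for p in pr]
    elif activation not in (None, "none"):
        raise NotImplementedError(activation)
    tp = sum(g * p for g, p in zip(gt, pr))  # soft true positives
    fp = sum(pr) - tp                        # predicted mass outside the mask
    fn = sum(gt) - tp                        # mask mass that was missed
    b2 = beta ** 2
    return ((1 + b2) * tp + eps) / ((1 + b2) * tp + b2 * fn + fp + eps)


empty = [0, 0, 0, 0]

# Perfect match on a non-empty mask (probabilities passed directly).
print(f_score_flat([1, 0, 1, 0], [1, 0, 1, 0], activation="none"))  # 1.0
# Empty mask, empty prediction: eps / eps keeps the score at 1.
print(f_score_flat([0, 0, 0, 0], empty, activation="none"))  # 1.0
# Empty mask, a single 0.1 false positive: the score collapses toward 0.
print(f_score_flat([0.1, 0, 0, 0], empty, activation="none"))
# Empty mask, confident negative logits through a sigmoid: still close to 0,
# because sigmoid(-10) is small but never exactly 0.
print(f_score_flat([-10.0] * 4, empty, activation="sigmoid"))
```

Because the posted `f_score` sums over the whole batch tensor, this collapse only occurs when every mask in a batch is empty.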
Thank you!