Comments (27)
I agree with @Borda that this should be an abstract function/class. The simplest approach, in my opinion, would be a class that the user can wrap their existing metric with: `masked_accuracy = MaskedMetric(Accuracy())`. This would add an additional argument to the call: `value = masked_accuracy(pred, target, mask)`. The alternative, rewriting each metric to include this feature, is not feasible at the moment.
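A minimal sketch of the wrapper idea above. `MaskedMetric` is a hypothetical name from this thread, not an existing torchmetrics class, and a hand-rolled accuracy stands in for `torchmetrics.Accuracy` so the sketch is self-contained; a real implementation would subclass `torchmetrics.Metric`.

```python
import torch

class SimpleAccuracy:
    """Minimal stand-in for a torchmetrics-style metric (update/compute/reset)."""
    def __init__(self):
        self.correct = 0
        self.total = 0
    def update(self, preds, target):
        self.correct += (preds == target).sum().item()
        self.total += target.numel()
    def compute(self):
        return self.correct / self.total
    def reset(self):
        self.correct = 0
        self.total = 0

class MaskedMetric:
    """Hypothetical wrapper: filters inputs by a boolean mask before
    delegating to the wrapped metric."""
    def __init__(self, base_metric):
        self.base_metric = base_metric
    def __call__(self, preds, target, mask):
        # Boolean indexing keeps only the unmasked elements.
        self.base_metric.update(preds[mask], target[mask])
        return self.base_metric.compute()
    def reset(self):
        self.base_metric.reset()

masked_accuracy = MaskedMetric(SimpleAccuracy())
preds = torch.tensor([0, 1, 2, 1])
target = torch.tensor([0, 1, 2, 2])
mask = torch.tensor([True, True, True, False])  # e.g. last position is padding
value = masked_accuracy(preds, target, mask)    # 3/3 = 1.0 over unmasked elements
```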
from torchmetrics.
I think I speak for both of us, saying that we'd for sure do that and really appreciate the PR :)
Yes just ping us in the PR when you are ready, and we will assist you.
It was closed due to no activity, so it is still not a part of lightning. @hadim please feel free to pick it up and send a PR :]
I have implemented a version of this in my own project, would anyone like to collaborate on making a PR for this?
Hello, perhaps I'm missing something, but I'm not sure there's a one-size-fits-all answer here that can simply be implemented as a wrapper.
I may be breaking the problem down incorrectly. For metrics with a simple internal sum, replacing masked values with 0 will suffice. For metrics that have an internal mean, the generic solution would be to sum over `dim=1`, replace `n_obs` with `mask.sum(axis=1)`, and divide only where the denominator (the number of unmasked elements in that row) is not 0. However, I'm not sure how to cover all metrics; I feel I could be missing scenarios. I'm also not sure whether there are metrics that divide `pred/target`, or whether we want to support custom metrics.
What are your thoughts?
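The per-row bookkeeping described above can be sketched for a mean-style metric; this uses a per-row mean absolute error as an illustrative example (the function name and shapes are assumptions, not torchmetrics API):

```python
import torch

def masked_row_mean_abs_error(preds, target, mask):
    """Mean absolute error per row, counting only unmasked elements.
    Rows whose mask is all-zero are excluded from the final average."""
    err = (preds - target).abs() * mask  # zero out masked positions
    row_sum = err.sum(dim=1)             # internal sum per row
    n_obs = mask.sum(dim=1)              # replaces the usual row length
    valid = n_obs > 0                    # divide only where denominator != 0
    row_mean = row_sum[valid] / n_obs[valid]
    return row_mean.mean()

preds = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
target = torch.tensor([[1.0, 0.0, 3.0], [5.0, 5.0, 5.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
# Row 0: error 2 over 2 observed elements -> 1.0; row 1 is fully masked and dropped.
print(masked_row_mean_abs_error(preds, target, mask))  # tensor(1.)
```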
Bumping this. I believe `ignore_index` could be fragile in certain circumstances, and a mask would be more reliable. For example, while it is a sensible default to set `ignore_index` to the padding token index of a tokenizer, some tokenizers, such as GPT-2's, do not have a padding token. Masking is also how the AllenNLP metrics (https://github.com/allenai/allennlp/tree/main/allennlp/training/metrics) dealt with this. I believe individually implementing a mask for each metric is probably a good idea. I definitely understand it's a non-trivial amount of work, but I believe it's worth it in the long run, and hopefully the AllenNLP implementations can be a useful reference for some common metrics.
Looks nice! @YuxianMeng want to send it as a PR?
Cc: @justusschock @SkafteNicki
> Looks nice! @YuxianMeng want to send it as a PR?
> Cc: @justusschock @SkafteNicki
My pleasure :) One small question: should this PR contain only masked precision metrics, or other metrics as well?
I would say all. In fact, it would be nice to have an abstract function/class that does this masking, so the new metrics would just be applications of it, for example:
- for functional metrics, make a wrapper that does the masking; all the new masked-like functions would call the existing functions through this wrapper
- for class-based metrics, make an abstract class; the masked-like metrics would then be created by inheriting from both the mask class and the metric class

Does that make sense? @justusschock @SkafteNicki
@YuxianMeng But with your implementation, you also calculate it for the values you set to -1, I think.
What you instead need to do is `accuracy(pred[mask], target[mask])`, which is why I wouldn't add extras for this, to be honest. We can't include every special case here, and masking tensors is not much overhead, which is why I'd prefer not to include this in the metrics package. Thoughts, @SkafteNicki?
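The one-liner suggested here is plain boolean indexing: it flattens the tensors down to the kept elements before the metric sees them. A small sketch with a hand-computed accuracy (in practice `accuracy` would be the torchmetrics functional):

```python
import torch

pred = torch.tensor([0, 1, 2, 2, 1])
target = torch.tensor([0, 1, 2, 1, 1])
mask = torch.tensor([True, True, True, True, False])  # e.g. last element is padding

# Boolean indexing keeps only the unmasked elements (flattened to 1-D),
# so the metric never sees the padded positions at all.
kept_pred, kept_target = pred[mask], target[mask]
acc = (kept_pred == kept_target).float().mean()
print(acc)  # 3 of the 4 kept elements match -> tensor(0.7500)
```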
> @YuxianMeng But with your implementation, you also calculate it for the values you set to -1, I think.
> What you instead need to do is `accuracy(pred[mask], target[mask])`, which is why I wouldn't add extras for this, to be honest. We can't include every special case here, and masking tensors is not much overhead, which is why I'd prefer not to include this in the metrics package. Thoughts, @SkafteNicki?
@justusschock As for accuracy, actually only the non-negative classes are counted. I thought about using `accuracy(pred[mask], target[mask])`, but it may cause speed problems when training on TPU.
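The TPU concern is that `pred[mask]` produces a tensor whose shape depends on the data, and data-dependent (dynamic) shapes typically trigger recompilation under XLA. A shape-preserving alternative, sketched here for accuracy (the function name is illustrative, not torchmetrics API), multiplies by the mask and adjusts the denominator instead of indexing:

```python
import torch

def masked_accuracy_static(pred, target, mask):
    """Accuracy without boolean indexing: every tensor keeps a fixed
    shape, avoiding the dynamic shapes that `pred[mask]` would create
    (relevant for XLA/TPU compilation)."""
    correct = ((pred == target) & mask).sum()  # count matches only where mask is True
    total = mask.sum()                         # number of unmasked elements
    return correct.float() / total.float()

pred = torch.tensor([0, 1, 2, 2, 1])
target = torch.tensor([0, 1, 2, 1, 1])
mask = torch.tensor([True, True, True, True, False])
print(masked_accuracy_static(pred, target, mask))  # tensor(0.7500)
```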
@YuxianMeng mind sending a PR? I guess @SkafteNicki or @justusschock could help/guide you through 🐰
Working on it, I will let you know when I'm ready :)
This issue has been closed. Has the masked-metrics feature landed, or has nobody worked on it yet?
@davzaman yes please, it would be a great addition :]
I didn't implement it as a class wrapper, but I have a few ideas on how to do it. It might take me a little while, as I have deadlines for other things, but I will be working on this!
@davzaman I definitely see the problem. My original idea for this feature was a simple wrapper that internally does `metric(pred[mask], target[mask])` when the user calls `metric(pred, target, mask)` (or something similar). However, that would not work for all metrics, I guess.
Yeah, I think each metric would need its own implementation. There isn't an insane number of metrics, but the overhead of including tests for all of them might be a lot. Should we just let users figure out masking on their own? Is there something we can at least include to make the process easier?
@davzaman @SkafteNicki how is it going here?
Hi @Borda, we ran into issues trying to follow a one-size-fits-all approach to adding masks to metrics, since the internal computations can differ greatly (which changes the logic required to properly compute a "masked" version of the metric). I wasn't sure how to proceed from there. From what I could tell, it would be best to have a masked version of each metric separately, even though it's more work. There's a chance there's a solution I didn't see.
I think this issue is related to #362.
@davzaman @yassersouri could you please open a draft PR so we have a more concrete discussion... and eventually we can help find a solution? 🐰
cc: @justusschock @SkafteNicki
@Borda Sorry, but I am quite busy right now. I don't think I will have time to allocate to this or #362.
@Borda I don't think I have the time to allocate to this at the moment but I am happy to help move things along.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
There was also an attempt in PyTorch itself to support masked tensors, but I think it has died out: https://github.com/pytorch/maskedtensor