Comments (6)
@viktor-ktorvi I had a bit of time to look at it, and it should be fixed in PR #2366.
from torchmetrics.
Hi @viktor-ktorvi, thanks for getting back. I was too quick on the trigger and did not correctly verify that everything was in order. Sorry about that. I have now created #2386, which should be correct. It uses a hard equality check (`==`) in the comparison instead of `torch.allclose`, as you wrote. Tensors should now be kept in whatever dtype the metric is initialized with.
from torchmetrics.
Hi! Thanks for your contribution! Great first issue!
from torchmetrics.
So honoring the default if it is set would help, right? Mind sending a fix PR? 🐰
from torchmetrics.
Something like that.
I could try. It would be my first time, but I'll give it a go.
from torchmetrics.
Hi,
I just found the time to check this. I've upgraded and the issue still isn't fixed.
What's wrong
The example I stated still doesn't function as expected.
Why the test passed
The implemented test passes because it uses `torch.allclose`. However, the values need to be exactly equal, not merely close. There's no reason for the two to differ if all the calculations are performed in the same dtype.
Let me demonstrate why it's wrong:
```python
import torch
from torchmetrics.aggregation import MeanMetric

# Make float64 the default so that torch.randn produces float64 values
torch.set_default_dtype(torch.float64)

metric = MeanMetric()
values = torch.randn(10000)
metric.update(values)

result = metric.compute()
actual_mean = values.mean()

print(f"{result} = Result\n{actual_mean} = Actual mean")
print(f"\nAll close = {torch.allclose(result, actual_mean, atol=1e-12)}")
print(f"Exactly equal = {result == actual_mean}")
```

```
-0.0041637580871582034 = Result
-0.004163758971599815 = Actual mean

All close = True
Exactly equal = False
```
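The gap between "close" and "exactly equal" can also be reproduced without torchmetrics. The following is only a minimal sketch using the Python standard library: `round_to_float32` is a hypothetical helper introduced for illustration, and `0.1` is an arbitrary value that is not exactly representable in float32. It shows that a value which has passed through float32 can satisfy a relative tolerance of `1e-5` (the default `rtol` of `torch.allclose`) while failing a strict `==` check.

```python
import math
import struct


def round_to_float32(x: float) -> float:
    """Hypothetical helper: round a Python float (float64) to float32 and widen it back."""
    return struct.unpack("f", struct.pack("f", x))[0]


value = 0.1                          # stored as float64 by Python
narrowed = round_to_float32(value)   # same value after a float32 round trip

# A tolerance-based comparison hides the precision loss
is_close = math.isclose(value, narrowed, rel_tol=1e-5)
# A strict comparison exposes it
is_equal = value == narrowed

print(f"close = {is_close}, equal = {is_equal}")
```

This is why a tolerance-based assertion cannot detect an unwanted cast to `float32` somewhere in the pipeline.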
Motivation
It might feel like I'm nitpicking, but these sorts of errors add up in complex problem formulations. For context, I'm working on approximating optimization problems with ML, and in my particular case, when casting from `float64` to `float32` and recalculating the values, the equality and inequality constraints are no longer fulfilled and the objective function is off.
How to fix
I've narrowed it down to the `_cast_and_nan_check_input` function, which gets called in `update` (line 564). At the end of `_cast_and_nan_check_input` (line 104), `x.float()` gets called, explicitly casting to `float32`, so that would need changing.
Additionally, there are lots of statements like these

```python
if not isinstance(x, Tensor):
    x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
if weight is not None and not isinstance(weight, Tensor):
    weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)
```

where the dtype is set explicitly, e.g., at line 79 or line 559. Each of those would need to use `dtype=torch.get_default_dtype()` instead.
Finally, the test needs to check for exact equality, i.e., `result == compare_function(values)`, instead of using `torch.allclose`.
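To make the proposed changes concrete, here is a rough sketch of what a dtype-preserving conversion could look like. `to_float_tensor` is a hypothetical helper, not the actual torchmetrics code, and leaving floating-point tensors in their own dtype is my assumption about the desired behavior. It relies only on `torch.as_tensor`, `torch.get_default_dtype`, `Tensor.is_floating_point`, and `Tensor.to`.

```python
from typing import Any, Optional

import torch
from torch import Tensor


def to_float_tensor(x: Any, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> Tensor:
    """Hypothetical helper: convert ``x`` to a floating-point tensor without forcing float32."""
    # Fall back to the global default dtype instead of a hard-coded torch.float32
    target = dtype if dtype is not None else torch.get_default_dtype()
    if not isinstance(x, Tensor):
        return torch.as_tensor(x, dtype=target, device=device)
    # Assumption: floating tensors keep their dtype; integer/bool tensors are promoted
    return x if x.is_floating_point() else x.to(target)


previous_default = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)

from_list = to_float_tensor([1.0, 2.0, 3.0])                       # follows the default dtype
from_float32 = to_float_tensor(torch.ones(3, dtype=torch.float32)) # left untouched
from_int = to_float_tensor(torch.ones(3, dtype=torch.int64))       # promoted to the default

print(from_list.dtype, from_float32.dtype, from_int.dtype)

# Restore the global default so other code is unaffected
torch.set_default_dtype(previous_default)
```

With something along these lines replacing the hard-coded `torch.float32` and the trailing `x.float()`, a test could then compare `result == compare_function(values)` directly.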
Thanks for your time! @SkafteNicki @Borda
from torchmetrics.