Comments (6)

SkafteNicki avatar SkafteNicki commented on June 7, 2024 1

@viktor-ktorvi I had a bit of time to look at it, and it should be fixed in PR #2366.

from torchmetrics.

SkafteNicki avatar SkafteNicki commented on June 7, 2024 1

Hi @viktor-ktorvi, thanks for getting back. I was too quick on the trigger and did not correctly verify that everything was in order. Sorry about that. I have now created #2386, which should be correct. As you suggested, the test now uses a strict == comparison instead of torch.allclose. Tensors should now be kept in whatever dtype the metric is initialized with.

github-actions avatar github-actions commented on June 7, 2024

Hi! Thanks for your contribution! Great first issue!

Borda avatar Borda commented on June 7, 2024

So honoring the default dtype if it is set would help, right? Mind sending a fix PR? 🐰

viktor-ktorvi avatar viktor-ktorvi commented on June 7, 2024

Something like that.

I could try. It would be my first time, but I'll give it a go.

viktor-ktorvi avatar viktor-ktorvi commented on June 7, 2024

Hi,

I just found the time to check this. I've upgraded and the issue still isn't fixed.

What's wrong

The example I stated still doesn't function as expected.

Why the test passed

The implemented test passes because it uses torch.allclose. However, the values should not merely be close; they should be exactly equal. There's no reason for the two to differ if all the calculations are performed in the same dtype.
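To see why a closeness check can hide a float32 round trip, here is a minimal standard-library sketch (not part of the original report). It rounds a float64 value through float32 with struct and compares the two using math.isclose, which only stands in for torch.allclose here. The helper name round_to_float32 is made up for illustration.

```python
import math
import struct


def round_to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


# A float64 value that is not exactly representable in float32.
exact = 0.1 + 0.2
rounded = round_to_float32(exact)

# The closeness check accepts the rounded value...
is_close = math.isclose(rounded, exact, rel_tol=1e-5)
# ...while strict equality exposes the lost precision.
is_equal = rounded == exact

print(f"close: {is_close}, equal: {is_equal}")
```

A test built on a closeness check would therefore pass even if the metric silently downcasts to float32 somewhere along the way.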

Let me demonstrate why it's wrong:

import torch
from torchmetrics.aggregation import MeanMetric

torch.set_default_dtype(torch.float64)

metric = MeanMetric()

values = torch.randn(10000)
metric.update(values)
result = metric.compute()

actual_mean = values.mean()

print(f"{result} = Result\n{actual_mean} = Actual mean")

print(f"\nAll close = {torch.allclose(result, actual_mean, atol=1e-12)}")
print(f"Exactly equal = {result == actual_mean}")

Output:

-0.0041637580871582034 = Result
-0.004163758971599815 = Actual mean

All close = True
Exactly equal = False

Motivation

It might feel like I'm nitpicking, but these sorts of errors add up in complex problem formulations. For context, I'm working on approximating optimization problems with ML. In my particular case, when casting from float64 to float32 and recalculating the values, the equality and inequality constraints are no longer fulfilled and the objective function is off.

How to fix

I've narrowed it down to the _cast_and_nan_check_input function, which gets called in update (line 564). At the end of _cast_and_nan_check_input (line 104), x.float() gets called, explicitly casting to float32, so that would need changing.

Additionally, lots of these

if not isinstance(x, Tensor):
    x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
if weight is not None and not isinstance(weight, Tensor):
    weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)

statements, where the dtype is hard-coded, exist, e.g., at line 79 or line 559. Each of those would need to use dtype=torch.get_default_dtype() instead.
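As a rough sketch of the change described above, the hard-coded float32 casts could be replaced by something along these lines. The helper name as_default_dtype_tensor is hypothetical and this is not the actual torchmetrics implementation; it only assumes the standard torch.get_default_dtype, torch.as_tensor, and Tensor.to calls.

```python
import torch
from torch import Tensor


def as_default_dtype_tensor(x, device=None) -> Tensor:
    """Hypothetical helper: convert `x` to a tensor without hard-coding float32."""
    dtype = torch.get_default_dtype()
    if not isinstance(x, Tensor):
        # Instead of torch.as_tensor(x, dtype=torch.float32, ...)
        return torch.as_tensor(x, dtype=dtype, device=device)
    # Instead of an unconditional x.float(), which always yields float32
    return x.to(dtype=dtype, device=device)


previous = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    from_list = as_default_dtype_tensor([0.1, 0.2, 0.3])
    from_tensor = as_default_dtype_tensor(torch.ones(4, dtype=torch.float32))
finally:
    # Restore the global default so the sketch has no side effects.
    torch.set_default_dtype(previous)

print(from_list.dtype, from_tensor.dtype)
```

Whether tensor inputs should follow the global default or keep the dtype they arrive with is a design choice; this sketch follows the global default, as proposed in this comment.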

Finally, the test needs to check for exact equality, i.e., result == compare_function(values), instead of using torch.allclose.

Thanks for your time! @SkafteNicki @Borda
