
Comments (6)

TC01 commented on June 18, 2024

It seems this can be fixed by explicitly converting poi to a float before checking whether it's in the hypotest cache. I printed out the types, and at some point during the test run poi switched from being a numpy float64 to a jaxlib type; I can trace further to see exactly why that happens and then maybe submit a PR.
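A minimal sketch of that idea (the wrapper and cache here are hypothetical, not pyhf's actual internals):

import pyhf

_hypotest_cache = {}


def cached_hypotest(poi, data, model):
    # Hypothetical sketch: coerce poi to a plain Python float before the
    # cache lookup, so a jax array and a numpy float64 holding the same
    # value hash to (and hit) the same cache entry.
    poi = float(poi)
    if poi not in _hypotest_cache:
        _hypotest_cache[poi] = pyhf.infer.hypotest(
            poi, data, model, return_expected_set=True
        )
    return _hypotest_cache[poi]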


matthewfeickert commented on June 18, 2024

First thing we'll need to do is understand why there aren't test failures

pyhf/tests/test_infer.py

Lines 26 to 57 in 64ab264

def test_toms748_scan(tmp_path, hypotest_args):
    """
    Test the upper limit toms748 scan returns the correct structure and values
    """
    _, data, model = hypotest_args
    results = pyhf.infer.intervals.upper_limits.toms748_scan(
        data, model, 0, 5, rtol=1e-8
    )
    assert len(results) == 2
    observed_limit, expected_limits = results
    observed_cls = pyhf.infer.hypotest(
        observed_limit,
        data,
        model,
        model.config.suggested_init(),
        model.config.suggested_bounds(),
    )
    expected_cls = np.array(
        [
            pyhf.infer.hypotest(
                expected_limits[i],
                data,
                model,
                model.config.suggested_init(),
                model.config.suggested_bounds(),
                return_expected_set=True,
            )[1][i]
            for i in range(5)
        ]
    )
    assert observed_cls == pytest.approx(0.05)
    assert expected_cls == pytest.approx(0.05)

which will probably mean revisiting PR #1274. Writing a failing test would be a good start, so that a PR can then make it pass.
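As a starting point, a sketch of what such a failing test could look like (hypothetical test; it reuses the existing hypotest_args fixture and just switches to the jax backend, where the reported type mismatch shows up):

import pyhf


def test_toms748_scan_jax(hypotest_args):
    # Expected to fail until the poi type handling is fixed: with the jax
    # backend, poi becomes a jax array partway through the scan.
    pyhf.set_backend("jax")
    _, data, model = hypotest_args
    observed_limit, expected_limits = pyhf.infer.intervals.upper_limits.toms748_scan(
        data, model, 0, 5, rtol=1e-8
    )
    assert len(expected_limits) == 5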


TC01 commented on June 18, 2024

Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.


kratsg commented on June 18, 2024

Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.

Nope, which likely explains why it wasn't caught. (Adding the backend fixture to the test will have it run on all the backends.)
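Concretely, that is just a change to the test signature (a sketch, assuming the backend fixture defined in pyhf's tests/conftest.py):

# Before: runs only on the default numpy backend
def test_toms748_scan(tmp_path, hypotest_args):
    ...


# After: the backend fixture parametrizes the test over all backends
def test_toms748_scan(backend, tmp_path, hypotest_args):
    ...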


kratsg commented on June 18, 2024

@matthewfeickert I saw the PR, and I think we need to swap the way we're approaching this. Here's my suggestion instead of type-casting: we need to add shims across each lib and move some functions into our tensorlib to make them backend-dependent (or use a shim to swap them out as needed, like we do for scipy.optimize).

See this example:

from functools import lru_cache
import time
import timeit

import jax
import jax.numpy as jnp
import tensorflow as tf


def slow(n):
    # Stand-in for an expensive computation (e.g. a hypotest call)
    time.sleep(1)
    return n**2


# Three ways to avoid recomputation: memoization for plain Python values,
# and JIT compilation for jax/tensorflow (the sleep only runs while tracing).
fast = lru_cache(maxsize=None)(slow)
fast_jax = jax.jit(slow)
fast_tflow = tf.function(jit_compile=True)(slow)


def time_five_calls(func, value):
    # Call func five times with the same argument and report the total time
    return timeit.timeit(lambda: [func(value) for _ in range(5)], number=1)


value = 5
print('slow')
print(time_five_calls(slow, value))
print('fast')
print(time_five_calls(fast, value))

value = jnp.array(5)
print('slow, jax')
print(time_five_calls(slow, value))
print('fast, jax')
print(time_five_calls(fast_jax, value))

value = tf.convert_to_tensor(5)
print('slow, tensorflow')
print(time_five_calls(slow, value))
print('fast, tensorflow')
print(time_five_calls(fast_tflow, value))

which outputs

$ python cache.py
slow
5.012567336
fast
1.0029977690000003
slow, jax
5.043927394000001
fast, jax
1.0195144690000006
slow, tensorflow
5.017408181999997
fast, tensorflow
1.0631543910000012

so we can definitely cache those values by JIT-ing for the toms748 scan, and that's probably what we want to do. (The remaining ~1 s in each "fast" case above is the single tracing call that actually executes the sleep; the four repeat calls are effectively free.) My suggestion is that we support pyhf.tensor.jit with a signature similar to jax.jit across all backends (yes, even numpy, where it would just be an lru_cache).
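A rough sketch of what those backend shims could look like (the jit method and its placement are a proposal, not existing pyhf API):

from functools import lru_cache


class numpy_backend:
    def jit(self, func):
        # numpy has no compiler, so "jit" degrades to plain memoization
        return lru_cache(maxsize=None)(func)


class jax_backend:
    def jit(self, func):
        import jax

        return jax.jit(func)


class tensorflow_backend:
    def jit(self, func):
        import tensorflow as tf

        return tf.function(func, jit_compile=True)

The scan code would then wrap its hypotest calls once via the active backend's jit and get the appropriate caching behavior regardless of which backend is set.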


matthewfeickert commented on June 18, 2024

we need to add in shims across each lib and move some functions into our tensorlib instead to make them backend-dependent

Okay, sounds good. Let's start up a separate series of PRs to do this.

