It seems that this can be fixed by explicitly converting `poi` to a float before checking whether it's in the hypotest cache here. I printed out the types, and at some point in the test run `poi` switched from being a NumPy `float64` to being a jaxlib type; I can trace further to see exactly why that happens and then maybe submit a PR.
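A minimal sketch of the cast-before-lookup fix, assuming the hypotest cache is a plain dict keyed on `poi`. `expensive_hypotest` and `make_cached` are hypothetical names, not pyhf's API. A 0-d NumPy array stands in for the jaxlib scalar here because, like a JAX array, it is unhashable and so cannot be used as a dict key directly.

```python
import numpy as np


def expensive_hypotest(poi):
    """Hypothetical stand-in for a costly hypotest evaluation (not pyhf's API)."""
    return 1.0 / (1.0 + poi)


def make_cached(fn):
    """Wrap fn with a dict cache whose keys are normalised to Python floats."""
    cache = {}
    calls = []  # records each real evaluation, for inspection

    def cached(poi):
        # float() maps Python floats, np.float64 and 0-d arrays to the same key.
        key = float(poi)
        if key not in cache:
            calls.append(key)
            cache[key] = fn(key)
        return cache[key]

    return cached, calls


cached_hypotest, calls = make_cached(expensive_hypotest)

# Without the cast, an unhashable scalar cannot even be looked up in a dict.
try:
    np.array(1.0) in {}
except TypeError as err:
    print("uncast lookup fails:", err)

# With the cast, all three scalar flavours hit the same cache entry.
for poi in (1.0, np.float64(1.0), np.array(1.0)):
    cached_hypotest(poi)
print(len(calls))  # prints 1
```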
First thing we'll need to do is understand why the existing test (lines 26 to 57 in 64ab264) isn't failing, which will probably mean revisiting PR #1274. Writing a failing test would be a good start, so that a PR can then make it pass.
Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.
Nope, which likely explains why it wasn't caught. (Adding the `backend` fixture to the test will make it run on all the backends.)
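Why a NumPy-only test can pass while the bug remains can be sketched without pyhf: `np.float64` is hashable, so an uncast dict lookup works for it, while an unhashable scalar type breaks it. The `SCALAR_FACTORIES` dict below is a hypothetical stand-in for the `backend` fixture, with a 0-d NumPy array playing the role of a JAX scalar; in the real suite the fixture would supply each tensorlib instead.

```python
import numpy as np

# Hypothetical per-"backend" scalar constructors (NumPy only, for illustration).
SCALAR_FACTORIES = {
    "python_float": float,
    "numpy_float64": np.float64,
    "array_0d": np.array,  # stand-in for an unhashable backend scalar
}


def lookup_without_cast(cache, poi):
    """Pre-fix pattern: use poi itself as the dict key."""
    if poi not in cache:
        cache[poi] = poi**2  # placeholder for an expensive hypotest call
    return cache[poi]


def lookup_with_cast(cache, poi):
    """Proposed fix: normalise poi to a Python float before the lookup."""
    key = float(poi)
    if key not in cache:
        cache[key] = key**2
    return cache[key]


def failing_scalar_types(lookup):
    """Return the scalar types for which a repeated lookup is not a cache hit."""
    failures = []
    for name, make_scalar in SCALAR_FACTORIES.items():
        cache = {}
        try:
            lookup(cache, make_scalar(1.5))
            lookup(cache, make_scalar(1.5))
            ok = len(cache) == 1
        except TypeError:  # raised when the key is unhashable
            ok = False
        if not ok:
            failures.append(name)
    return failures


print(failing_scalar_types(lookup_without_cast))  # ['array_0d']
print(failing_scalar_types(lookup_with_cast))  # []
```

Parametrizing the real test over every backend (for example with `pytest.mark.parametrize` or a parametrized fixture) is what turns the last case into a visible failure.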
@matthewfeickert I saw the PR, and I think we need to change the way we're approaching this. Here's my suggestion instead of type-casting: we need to add shims across each library and move some functions into our tensorlib instead, to make them backend-dependent (or use a shim to swap them out as needed, like we do for `scipy.optimize`).
See this example:
```python
from functools import lru_cache
import time
import timeit

import jax
import jax.numpy as jnp
import tensorflow as tf


def slow(n):
    # Simulate an expensive computation.
    time.sleep(1)
    return n**2


# Pure Python / NumPy: memoise the result for each argument value.
fast = lru_cache(maxsize=None)(slow)
# JAX and TensorFlow: the Python body (including the sleep) only runs while
# tracing on the first call; later calls reuse the compiled function.
fast_jax = jax.jit(slow)
fast_tflow = tf.function(jit_compile=True)(slow)

value = 5
print('slow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast')
print(timeit.timeit(lambda: [fast(value), fast(value), fast(value), fast(value), fast(value)], number=1))

value = jnp.array(5)
print('slow, jax')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, jax')
print(timeit.timeit(lambda: [fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value)], number=1))

value = tf.convert_to_tensor(5)
print('slow, tensorflow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, tensorflow')
print(timeit.timeit(lambda: [fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value)], number=1))
```
which outputs
```
$ python cache.py
slow
5.012567336
fast
1.0029977690000003
slow, jax
5.043927394000001
fast, jax
1.0195144690000006
slow, tensorflow
5.017408181999997
fast, tensorflow
1.0631543910000012
```
so we can definitely get this caching behaviour by JIT-ing for the toms748 scan here, and that's probably what we want to do. My suggestion might be that we support `pyhf.tensor.jit` with something similar to the signature of `jax.jit` across all backends (yes, even NumPy, but for NumPy it would be an `lru_cache`).
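One possible shape for a backend-dispatched `jit`, sketched under assumptions: `make_jit` and `numpy_jit` are hypothetical names (not an existing pyhf API), only the NumPy and JAX branches are shown, and the NumPy branch follows the `lru_cache` idea above. Because `lru_cache` needs hashable arguments, NumPy scalars are converted with `.item()` first, which is the same hashability issue that started this thread.

```python
from functools import lru_cache
import time

import numpy as np


def numpy_jit(fn):
    """Hypothetical NumPy-backend 'jit': memoise results with lru_cache.

    lru_cache needs hashable arguments, so NumPy scalars and 0-d arrays are
    turned into Python scalars first. Non-scalar arrays are not handled here.
    """
    cached = lru_cache(maxsize=None)(fn)

    def wrapper(*args):
        # .item() converts np.float64 and 0-d ndarrays to Python scalars;
        # plain Python numbers have no .item() and pass through unchanged.
        keys = tuple(a.item() if hasattr(a, "item") else a for a in args)
        return cached(*keys)

    wrapper.cache_info = cached.cache_info  # expose hit/miss statistics
    return wrapper


def make_jit(backend_name):
    """Hypothetical dispatcher returning a jit-like decorator for a backend."""
    if backend_name == "numpy":
        return numpy_jit
    if backend_name == "jax":
        import jax  # imported lazily so the NumPy path has no JAX dependency

        return jax.jit
    raise NotImplementedError(f"no jit shim sketched for {backend_name!r}")


def slow(n):
    time.sleep(0.1)  # simulate an expensive computation
    return n**2


fast = make_jit("numpy")(slow)
results = [fast(np.float64(5.0)), fast(np.array(5.0)), fast(5.0)]
print(results, fast.cache_info())
```

The two branches cache different things: `lru_cache` memoises the result for each argument value, whereas `jax.jit` caches the traced and compiled function per input shape and dtype and still executes it on every call. A shared `pyhf.tensor.jit` signature would need to decide which of these semantics it promises.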
Okay, sounds good. Let's start up a separate series of PRs to do this.