Comments (11)
Thanks a lot for the report! All functions that involve SciPy wrappers currently can't be jitted or vmapped. We may be able to use host_callback on the JAXopt side so that they work transparently for the user.
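A minimal sketch of the callback idea, assuming a JAX version that provides `jax.pure_callback` (the newer API for calling host code from traced functions). The names `_scipy_root` and `solve` are hypothetical, and this is not how jaxopt implements its wrappers. As written it can be jitted, but it has no gradient rule, and vmapping it would also require choosing a batching strategy (the `vmap_method` argument in recent JAX releases).

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize


def _scipy_root(x0, params):
    # Runs on the host with array inputs, outside of JAX tracing.
    a, b = params
    sol = scipy.optimize.root(lambda x: (x - a) * (x - b), x0, method="hybr", tol=1e-10)
    # The result must match the declared shape and dtype exactly.
    return np.asarray(sol.x, dtype=x0.dtype)


@jax.jit
def solve(params):
    x0 = jnp.zeros(1, dtype=params.dtype)
    out_spec = jax.ShapeDtypeStruct(x0.shape, x0.dtype)
    # JAX treats the SciPy call as an opaque function with a known output spec.
    return jax.pure_callback(_scipy_root, out_spec, x0, params)
```

For example, `solve(jnp.array([2.0, 5.0]))` should return a root close to 2.0 or 5.0 inside a jitted function.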
from jaxopt.
@mblondel Thanks for the quick reply. I understood the issue. For now I will just try to avoid use of the SciPy wrappers.
Modifying the last line to use `map` instead of `vmap` works:
```python
import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit floats
import jax.numpy as np
from jax import random, vmap
from jaxopt import linear_solve
from jaxopt import ScipyRootFinding

def func(x, params):
    a, b = params
    return (x - a) * (x - b)

kwargs = {'implicit_diff_solve': linear_solve.solve_normal_cg, 'method': 'hybr', 'tol': 1e-10}
rootfinder = ScipyRootFinding(optimality_fun=func, **kwargs)
init = np.zeros(1)

def solve(params):
    root, info = rootfinder.run(init, params)
    return root

key = random.PRNGKey(1235711)
param_list = random.uniform(key, (100, 2), minval=-10, maxval=10)
root = solve(param_list[0])  # this is ok
root_list = list(map(solve, param_list))  # this works
```
`map` just computes each operation sequentially, while `vmap` compiles to vectorized operations (e.g., vector-vector products become matrix-vector products), so `map` should generally be slower. The issue is that SciPy solvers are black-box functions from JAX's point of view, so it's not possible to vmap them.
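To make the `map` vs `vmap` distinction concrete, here is a minimal sketch in plain JAX (not jaxopt; the arrays are made up for illustration). Vmapping a vector-vector `jnp.dot` over the rows of a matrix yields one batched operation equivalent to a matrix-vector product, while Python's `map` makes one separate call per row.

```python
import jax
import jax.numpy as jnp

M = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors of length 3
v = jnp.array([1.0, 2.0, 3.0])

# vmap: a single traced call; the per-row dot products are batched together.
batched = jax.vmap(jnp.dot, in_axes=(0, None))(M, v)

# Python map: one call per row, executed one after another.
sequential = jnp.stack(list(map(lambda row: jnp.dot(row, v), M)))

# Inspecting the vmapped program should show a dot_general over the whole (4, 3) matrix.
jaxpr = jax.make_jaxpr(jax.vmap(jnp.dot, in_axes=(0, None)))(M, v)
print(jaxpr)
```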
You're right, `map` is slower. That's why it's only a temporary fix.
Somehow `jax.scipy.optimize.minimize` (not JAXopt's) is vmappable, but it only supports the BFGS algorithm.
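A minimal sketch of what vmapping `jax.scipy.optimize.minimize` looks like, using a made-up convex objective (not the root-finding problem above) whose minimizer is `x = target`. Because BFGS there is implemented with JAX primitives, `jax.vmap` can trace through the whole solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def objective(x, target):
    # Simple quadratic whose unique minimizer is x = target.
    return jnp.sum((x - target) ** 2)


def solve(target):
    x0 = jnp.zeros_like(target)
    return minimize(objective, x0, args=(target,), method="BFGS").x


targets = jnp.array([[1.0, -2.0], [0.5, 3.0], [-4.0, 2.5]])
solutions = jax.vmap(solve)(targets)  # one vectorized BFGS run over the batch
```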
Well, it's written in JAX, unlike SciPy's solvers...
In principle, we could define a custom primitive with a batching rule for the SciPy solvers.
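Defining a real primitive touches JAX internals, so this sketch swaps in a lighter-weight public alternative: the `jax.custom_batching.custom_vmap` decorator wrapped around a `jax.pure_callback` to SciPy. It is not what jaxopt does, and `scipy_root` and `_host_solve` are hypothetical names. The custom batching rule sends the whole batch to the host in one callback and loops over it there; no gradient rule is provided.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize
from jax.custom_batching import custom_vmap


def _host_solve(x0, params):
    # Host-side loop: x0 has shape (batch, n), params has shape (batch, 2).
    x0, params = np.asarray(x0), np.asarray(params)
    out = np.empty_like(x0)
    for i, (a, b) in enumerate(params):
        sol = scipy.optimize.root(lambda x: (x - a) * (x - b), x0[i], method="hybr", tol=1e-10)
        out[i] = sol.x  # cast back to the input dtype on assignment
    return out


@custom_vmap
def scipy_root(x0, params):
    # Unbatched call: add a leading batch axis of size 1, then drop it again.
    spec = jax.ShapeDtypeStruct((1,) + x0.shape, x0.dtype)
    return jax.pure_callback(_host_solve, spec, x0[None], params[None])[0]


@scipy_root.def_vmap
def _scipy_root_batched(axis_size, in_batched, x0, params):
    x0_batched, params_batched = in_batched
    # Broadcast unbatched arguments so the host sees one problem per row.
    if not x0_batched:
        x0 = jnp.broadcast_to(x0, (axis_size,) + x0.shape)
    if not params_batched:
        params = jnp.broadcast_to(params, (axis_size,) + params.shape)
    spec = jax.ShapeDtypeStruct(x0.shape, x0.dtype)
    # One callback for the whole batch; the output is batched along axis 0.
    return jax.pure_callback(_host_solve, spec, x0, params), True
```

Under `jax.vmap` the rule replaces the unbatched body, so the callback itself is never vmapped. The host loop is still sequential; what this saves is per-example round trips, not SciPy's own cost.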
I'm running into a similar problem with `Bisection`. Is it also coming from SciPy, or is it implemented directly in JAX, meaning this is a different problem? Tweaking the example above to:
```python
import jax
import jax.numpy as np
from jax import random, vmap
from jaxopt import Bisection

def func(x, a):
    return x - a

rootfinder = Bisection(optimality_fun=func, lower=0., upper=1.)

def solve(a):
    return rootfinder.run(a=a[0]).params

key = random.PRNGKey(1235711)
param_list = random.uniform(key, (4, 1), minval=0, maxval=1)
root = solve(param_list[0])
root_list = vmap(solve)(param_list)  # this doesn't work
```
I get:
```
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<BatchTrace(level=1/1)> with
  val = DeviceArray([False, False, False, False], dtype=bool, weak_type=True)
  batch_dim = 0
The problem arose with the `bool` function.
```
It works if you set `check_bracket=False`, like this:
```python
rootfinder = Bisection(optimality_fun=func, lower=0., upper=1., check_bracket=False)
```
We check the bracketing interval by default because it's easy to make a mistake. Once your code is correct, you can disable it if you want to jit or vmap.
Ideally, we need better documentation and a clearer error message if possible.
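The error above comes from Python control flow on a traced value: a check like `if ...:` calls `bool()` on the comparison, which has no concrete value under `vmap`. A minimal sketch of the difference, assuming a bracket check of the form f(lower) * f(upper) <= 0 (the helper names are hypothetical and this is not jaxopt's actual check):

```python
import jax
import jax.numpy as jnp


def f(x, a):
    return x - a


def bracket_ok_python(a, lower=0.0, upper=1.0):
    # Python `if` calls bool() on a traced comparison: fails under jit/vmap.
    if f(lower, a) * f(upper, a) > 0:
        return False
    return True


def bracket_ok_traceable(a, lower=0.0, upper=1.0):
    # Same test kept as an array value, so jit/vmap can trace it.
    return f(lower, a) * f(upper, a) <= 0


a_batch = jnp.array([0.2, 0.5, 0.9, 1.5])

try:
    jax.vmap(bracket_ok_python)(a_batch)
    python_check_traced = True
except jax.errors.ConcretizationTypeError:
    python_check_traced = False

mask = jax.vmap(bracket_ok_traceable)(a_batch)  # per-example bracket validity
```

The traceable version yields a boolean mask, but by itself it cannot raise an error for the bad entry (here `a = 1.5`), which is why a default check that raises needs concrete values.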
Ah, that makes perfect sense. Thanks!
This is my first foray into JAX, so forgive me if this question is naive, but is it possible to hide this from the user using the `checkify` module? Or do I misunderstand how it works?
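A minimal sketch of how `jax.experimental.checkify` could express a bracket check that survives `jit` and `vmap`. The `bisect` function is a hypothetical stand-in written for this example, not jaxopt's `Bisection`. One trade-off: the caller still has to wrap the function with `checkify.checkify` and inspect the returned error, so the check is not fully invisible.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def bisect(a, lower=0.0, upper=1.0, num_iters=30):
    """Hypothetical bisection for the root of f(x) = x - a on [lower, upper]."""
    f = lambda x: x - a
    # Functional check: recorded in an Error value instead of calling bool() on a tracer.
    checkify.check(f(lower) * f(upper) <= 0, "root is not bracketed by [lower, upper]")

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Keep the half-interval whose endpoints still have opposite signs.
        same_sign_as_lo = jnp.sign(f(mid)) == jnp.sign(f(lo))
        return jnp.where(same_sign_as_lo, mid, lo), jnp.where(same_sign_as_lo, hi, mid)

    init = (jnp.full_like(a, lower), jnp.full_like(a, upper))
    lo, hi = jax.lax.fori_loop(0, num_iters, body, init)
    return 0.5 * (lo + hi)


# checkify the vmapped solver and jit the result; errors come back as values.
checked_solve = jax.jit(checkify.checkify(jax.vmap(bisect)))

err, roots = checked_solve(jnp.array([0.2, 0.5, 0.9]))
err.throw()  # no-op here: every root is bracketed by [0, 1]

bad_err, _ = checked_solve(jnp.array([0.5, 1.5]))
print(bad_err.get())  # message describing the failed check
```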