Comments (4)
When FLAGS.manual_loop=True, we should also call jax.jit(solver.update) instead of solver.update; this implies making the solver class hashable.
from jaxopt.
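To make the suggestion above concrete, here is a minimal sketch of jitting an `update` method and calling it from a manual loop. `ToySolver` is a hypothetical stand-in, not a jaxopt class; a frozen dataclass is used here only as one simple way to get a hashable solver object, and the quadratic objective is made up for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)  # frozen + eq generates __hash__ from the fields
class ToySolver:
    """Hypothetical solver, used only to illustrate the pattern."""
    stepsize: float = 0.1

    def update(self, params, state):
        # One gradient-descent step on a toy objective: sum((p - 3)^2).
        grad = jax.grad(lambda p: jnp.sum((p - 3.0) ** 2))(params)
        return params - self.stepsize * grad, state + 1


solver = ToySolver()
# Wrap the bound method once, outside the loop, and reuse the jitted callable.
jitted_update = jax.jit(solver.update)

params = jnp.zeros(2)
state = jnp.asarray(0)  # iteration counter
for _ in range(100):
    params, state = jitted_update(params, state)
```

After the first call the compiled step is reused, so the Python loop only dispatches the compiled computation on each iteration.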
Thanks @Algue-Rythme!
That almost works. Unfortunately, to make it work in flax_image_classif.py I also need to remove the pre_update=print_accuracy argument of the solver; otherwise it crashes with this exception:
Traceback (most recent call last):
  File "examples/deep_learning/flax_image_classif.py", line 199, in <module>
    app.run(main)
  File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "examples/deep_learning/flax_image_classif.py", line 188, in main
    params, state = jax.jit(solver.update)(params=params, state=state,
  File "/home/pedregosa/dev/jaxopt/jaxopt/_src/optax_wrapper.py", line 120, in update
    params, state = self.pre_update(params, state, *args, **kwargs)
  File "examples/deep_learning/flax_image_classif.py", line 145, in print_accuracy
    if state.iter_num % 10 == 0:
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function update at /home/pedregosa/dev/jaxopt/jaxopt/_src/optax_wrapper.py:104 for jit, this concrete value was not available in Python because it depends on the value of the argument 'state'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
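The traceback points at a Python `if` on `state.iter_num`, which is a tracer once `update` is jitted. One possible way to keep a periodic report inside a jitted step is to move the branch into `jax.lax.cond` and print with `jax.debug.print`. The sketch below assumes a JAX version that provides `jax.debug.print`; `maybe_report` and `step` are hypothetical names and this is not the `print_accuracy` function from the example, nor necessarily how the issue was eventually resolved.

```python
import jax
import jax.numpy as jnp


def maybe_report(iter_num, accuracy):
    """Print the accuracy every 10 iterations without a Python-level `if`."""

    def report(acc):
        # jax.debug.print is staged into the compiled computation.
        jax.debug.print("accuracy: {a}", a=acc)

    def skip(acc):
        pass  # same (empty) output structure as `report`

    # The predicate may be a traced value; lax.cond selects the branch at run time.
    jax.lax.cond(iter_num % 10 == 0, report, skip, accuracy)


@jax.jit
def step(iter_num, accuracy):
    # Hypothetical stand-in for a solver update that calls a pre-update hook.
    maybe_report(iter_num, accuracy)
    return iter_num + 1


next_iter = step(jnp.asarray(0), jnp.asarray(0.5))
```

Both branches must return the same structure, which is why `skip` mirrors `report` and returns nothing.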
I tried to protect the problematic lines with a `with jax.disable_jit():` statement, but it still failed.
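A likely explanation, offered as an assumption rather than a confirmed diagnosis: `jax.disable_jit()` changes how `jax.jit` calls made inside the context are dispatched, but it does not turn values that are already tracers back into concrete arrays. The small sketch below (all function names are hypothetical) contrasts placing the context manager inside a function that is being traced with wrapping the outer jitted call.

```python
import jax
import jax.numpy as jnp


def is_positive(x):
    # bool() needs a concrete value, so this fails on a tracer.
    return bool(x > 0)


def guarded(x):
    # The context manager is entered while `x` is already a tracer.
    with jax.disable_jit():
        return is_positive(x)


def raises_under_jit(fn, x):
    """Return True if jitting `fn` hits a concretization error."""
    try:
        jax.jit(fn)(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


# Guarding the inner lines does not help once tracing has started.
inner_fails = raises_under_jit(guarded, jnp.asarray(1.0))

# Wrapping the outer call runs the function eagerly on concrete values.
with jax.disable_jit():
    outer_result = jax.jit(is_positive)(jnp.asarray(1.0))
```

Disabling jit around the outer call is mainly useful for debugging, since it gives up the compiled speedup that motivated jitting the update in the first place.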
Solved for the flax_resnet.py example in #119. Leaving this issue open since there are other examples where the GPU utilization is poor.