evojax's Issues

[Question] Are the jits around pmaps intended?

self._train_rollout_fn = jax.jit(jax.pmap(

The docs seem to say that jits around pmaps are unnecessary: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#pmap-and-jit

While running experiments, I also often get this warning which seems to say that it might be problematic:
UserWarning: The jitted function <unnamed function> includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See [https://github.com/google/jax/issues/2926].

If this behaviour is intended, please disregard this issue.
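For reference, a minimal sketch of a pmap without the outer jit, as the quoted warning suggests. The function rollout_step and the array shapes are hypothetical stand-ins, not EvoJAX code; the only point is that jax.pmap compiles its function on its own.

```python
import jax
import jax.numpy as jnp


def rollout_step(params, obs):
    # Hypothetical per-device computation: one linear layer followed by tanh.
    return jnp.tanh(obs @ params)


n_dev = jax.local_device_count()

# pmap compiles rollout_step with XLA itself, so it is not wrapped in jax.jit.
parallel_rollout = jax.pmap(rollout_step)

# Every argument needs a leading axis whose size equals the device count.
params = jnp.ones((n_dev, 4, 2))
obs = jnp.ones((n_dev, 8, 4))

out = parallel_rollout(params, obs)
print(out.shape)  # (n_dev, 8, 2)
```

On a single-device machine this still runs, with a leading axis of size 1.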

Multi-Agent RL Environment for CrowdSim, Predator-Prey, and Army

Hello again,

I created crowd and predator-prey environments that may be useful for evojax. I am also working on a melee combat environment.

Here is the PyTorch implementation of the crowd and predator environments: https://github.com/kayuksel/multi-rl-crowd-sim

Here is a video where multi-agent predators learn to surround prey to maximize hunts: https://youtu.be/Ds9O9wPyF8g

(I will also create a competitive multi-agent environment for closed-market auctions, where agents will self-play by placing orders).

Have a nice week.

Sincerely,
Kamer

Evaluating brax environments other than brax-ant. Terminates with error.

Information

The issue occurs when running brax environments other than brax-ant. The included humanoid, half cheetah and fetch environments are affected.

Couldn't find any references to this issue in the repo. I could have missed something.

Expected Behavior

/home/<USER>/anaconda3/envs/evojax/bin/python /home/<USER>/evojax/scripts/benchmarks/train.py -config configs/PGPE/brax_halfcheetah.yaml
brax: 2022-06-16 20:41:01,954 [INFO] EvoJAX brax
brax: 2022-06-16 20:41:01,954 [INFO] ==============================
absl: 2022-06-16 20:41:02,137 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-06-16 20:41:02,221 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
MLPPolicy: 2022-06-16 20:41:03,747 [INFO] MLPPolicy.num_params = 3974
brax: 2022-06-16 20:41:03,787 [INFO] use_for_loop=False
brax: 2022-06-16 20:41:03,825 [INFO] Start to train for 1 iterations.
brax: 2022-06-16 20:41:56,024 [INFO] [TEST] Iter=1, #tests=1, max=-9.7476, avg=-9.7476, min=-9.7476, std=0.0000
brax: 2022-06-16 20:41:56,087 [INFO] Training done, best_score=-9.7476
brax: 2022-06-16 20:41:56,093 [INFO] Loaded model parameters from ./log/PGPE/brax/default.
brax: 2022-06-16 20:41:56,093 [INFO] Start to test the parameters.
brax: 2022-06-16 20:42:03,478 [INFO] [TEST] #tests=1, max=-9.9009, avg=-9.9009, min=-9.9009, std=0.0000

Current Behavior

brax: 2022-06-16 20:26:04,657 [INFO] EvoJAX brax
brax: 2022-06-16 20:26:04,657 [INFO] ==============================
absl: 2022-06-16 20:26:04,833 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-06-16 20:26:04,920 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
MLPPolicy: 2022-06-16 20:26:06,465 [INFO] MLPPolicy.num_params = 3974
brax: 2022-06-16 20:26:06,504 [INFO] use_for_loop=False
brax: 2022-06-16 20:26:06,541 [INFO] Start to train for 10 iterations.
Traceback (most recent call last):
  File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 88, in <module>
    main(config)
  File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 64, in main
    trainer.run(demo_mode=False)
  File "/home/<USER>/evojax/evojax/trainer.py", line 152, in run
    scores, bds = self.sim_mgr.eval_params(
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 258, in eval_params
    return self._scan_loop_eval(params, test)
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 355, in _scan_loop_eval
    scores, all_obs, masks, final_states = rollout_func(
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 202, in rollout
    (obs_set, obs_mask)) = jax.lax.scan(
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1630, in scan
    _check_tree_and_avals("scan carry output and input",
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2316, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: scan carry output and input must have identical types, got
(State(state=State(qp=QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), obs='ShapedArray(float32[16384,18])', reward='ShapedArray(float32[16384])', done='ShapedArray(float32[16384])', metrics={'reward_ctrl_cost': 'ShapedArray(float32[16384])', 'reward_forward': 'ShapedArray(float32[16384])'}, info={'first_obs': 'ShapedArray(float32[16384,18])', 'first_qp': QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), 'steps': 'ShapedArray(float32[16384])', 'truncation': 'ShapedArray(float32[16384])'}), obs='ShapedArray(float32[16384,18])', feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])'), PolicyState(keys='ShapedArray(uint32[16384,2])'), 'ShapedArray(float32[16384,3974])', 'ShapedArray(float32[37])', 'ShapedArray(float32[16384])', 'ShapedArray(float32[16384])').

Exact Error:

feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])')

Failure Information

Context

Based on the commit history, this appears to be due to the changes introduced in #33.
Manually altering the feet_contact variable in the reset_fn method in evojax/evojax/task/brax_task.py allows the other environments to run.
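The error class can be reproduced outside EvoJAX. The sketch below is a hypothetical minimal case, not the brax_task.py code: the step function returns a carry with 4 entries, so jax.lax.scan rejects an initial carry with 3 entries and accepts one with 4.

```python
import jax
import jax.numpy as jnp


def step(carry, _):
    # The outgoing carry always has 4 entries, like feet_contact int32[..., 4]
    # in the traceback above.
    return jnp.zeros(4, dtype=jnp.int32), None


def run(init_size):
    init = jnp.zeros(init_size, dtype=jnp.int32)
    final, _ = jax.lax.scan(step, init, xs=None, length=5)
    return final


try:
    run(3)  # initial carry has 3 entries: input and output types differ
    mismatch_error = None
except TypeError as err:
    mismatch_error = err

matched = run(4)  # works once the initial carry matches the step output
print(type(mismatch_error).__name__, matched.shape)
```

This suggests that the shape used for feet_contact at reset time has to match what the chosen environment produces on each step.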

Hardware setup details are irrelevant, since the error also occurs on the hosted Colab notebook.

brax                         0.0.13
evojax                       0.2.11               
flax                         0.4.0
jax                          0.3.1
jaxlib                       0.3.0+cuda11.cudnn82

Steps to Reproduce

Please provide detailed steps for reproducing the issue.

  1. Run evojax/scripts/benchmarks/train.py with a modified evojax/scripts/benchmarks/configs/<ES> file that uses a non-ant brax environment.
  2. Modify the feet_contact array size and test again.

AssertionError for OpenES

When I try to instantiate OpenES from open_es.py, I get the following error message:
[Screenshot, 2022-12-15 20:23:59]
I traced the problem back to line 110 in open_es.py, where both the centered_rank and z_score arguments are set to True:
[Screenshot, 2022-12-15 20:26:01]
But line 26 of the FitnessShaper class in evosax/utils/reshape_fitness.py says that
[Screenshot, 2022-12-15 20:26:49]
How can I get around this issue?

high dimensional parametric search

I'm trying to use evojax to evolve my model parameters. I found that the algorithm only accepts the parameter num_dims as the dimension. Can it only be an int? If I want to evolve multidimensional parameters, such as a [1000x1000] array, how can I do it? Thanks!
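One common workaround, sketched here under assumptions, is to let the solver work on a flat vector and reshape each candidate before use. The array flat_population stands in for whatever a solver would hand back for the whole population, and a 10x10 shape replaces 1000x1000 so the example runs quickly.

```python
import math

import jax
import jax.numpy as jnp

param_shape = (10, 10)              # small stand-in for (1000, 1000)
num_dims = math.prod(param_shape)   # the single integer passed as num_dims
pop_size = 4

# Hypothetical population: one flat parameter vector per member.
flat_population = jax.random.normal(jax.random.PRNGKey(0), (pop_size, num_dims))


def unflatten(flat_params):
    # Restore the matrix layout the model expects.
    return flat_params.reshape(param_shape)


# vmap applies the reshape to every member of the population at once.
matrices = jax.vmap(unflatten)(flat_population)
print(num_dims, matrices.shape)
```

For parameters that form a nested structure rather than one array, jax.flatten_util.ravel_pytree provides the same flatten and unflatten pair.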

OpenAI Gym Integration

Hi!

Is there any example of how to use Evojax with any gym-like environment that is not implemented in Jax? Is that even possible?

Thank you!

Save top n models per checkpoint

As I understand it, currently only the best model from the population is saved at the end of an iteration. This may lead to inconsistent train/test results (due to overfitting) in some setups. Blending the top n models could potentially reduce this effect.

Would you be interested in this feature for evojax? I can work on a PR. It seems that not all solvers can support this feature, though.
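A rough sketch of what selecting and blending the top n members could look like. The function blend_top_n is hypothetical and averaging parameters is only one possible reading of "blending"; averaging the models' outputs would be another.

```python
import jax.numpy as jnp


def blend_top_n(population, scores, n):
    # Indices of the n highest scores, best first.
    top_idx = jnp.argsort(-scores)[:n]
    # Average the selected parameter vectors element-wise.
    blended = population[top_idx].mean(axis=0)
    return blended, top_idx


# Toy population of 4 members with 2 parameters each.
population = jnp.array([[0.0, 0.0], [2.0, 2.0], [4.0, 4.0], [6.0, 6.0]])
scores = jnp.array([1.0, 4.0, 2.0, 3.0])

blended, top_idx = blend_top_n(population, scores, n=2)
print(top_idx, blended)
```

Saving population[top_idx] instead of the blended vector would keep the n models separate in the checkpoint.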

A GNN-based Meta-Learning Method for Sparse Portfolio Optimization

Hello,

Let me start by saying that I am a fan of your work here. I have recently open-sourced my GNN-based meta-learning method for optimization. I have applied it to a real-world sparse index-tracking problem (after an initial benchmark on the Schwefel function), and it seems to significantly outperform Fast CMA-ES, both in producing robust solutions on the blind test set and in time (total duration and iterations) and space complexity. I include the link to my repository here, in case you would consider adding the method or the benchmarking problem to your repository. Note: the GNN, which learns how to generate populations of solutions at each iteration, is trained using gradients retrieved from the loss function, as opposed to black-box ones.

Sincerely, K

Minor issue with GIF at the end of the Abstract Paintings notebook 1

The notebook (https://github.com/google/evojax/blob/main/examples/notebooks/AbstractPainting01.ipynb) does the following to turn the saved frames into a GIF:

import glob
import IPython

frames = []
imgs = glob.glob("AbstractPainting01_canvas_record.*.png")
for file in imgs:
  new_frame = Image.open(file)
  frames.append(new_frame)
frames[0].save('AbstractPainting01_final.gif', save_all=True, append_images=frames, optimize=True, duration=200, loop=0)

The resulting GIF doesn't show the frames in order, because glob returns results in an arbitrary order (probably whatever order files appear in the filesystem).

Fix:
imgs = sorted(glob.glob("AbstractPainting01_canvas_record.*.png"))

I'd submit a pull request but I don't have 8 A100s lying around so it'd take a while to re-run the example :)
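A note on the fix: plain sorted() orders names lexicographically, which is only correct if the frame numbers are zero-padded. The sketch below assumes they might not be and sorts by the last number in each file name; frame_index is a hypothetical helper and the file names are made up to match the glob pattern.

```python
import re


def frame_index(path):
    # Use the last run of digits in the name as the frame number.
    numbers = re.findall(r"\d+", path)
    return int(numbers[-1]) if numbers else -1


imgs = [
    "AbstractPainting01_canvas_record.10.png",
    "AbstractPainting01_canvas_record.2.png",
    "AbstractPainting01_canvas_record.1.png",
]

lexicographic = sorted(imgs)             # frame 10 lands before frame 2
ordered = sorted(imgs, key=frame_index)  # frames 1, 2, 10
print(ordered)
```

In the notebook this would become imgs = sorted(glob.glob("AbstractPainting01_canvas_record.*.png"), key=frame_index).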

How to design custom Seq2seq model by evojax?

Dear developer:
I am developing a large seq2seq model with evojax. However, I found it inflexible to build a custom model with a custom vocabulary on top of the example seq2seq: I need to revise the source code rather than import functions or change parameters. I would appreciate it if you could give me some guidance.
By the way, thanks for your awesome work.

[Discussion] Sequencing side-effects in JAX

Sequencing side effects is a known issue in JAX. I have tried to build a custom operational env/task in which the step update needs to be sequential.

The JAX docs say that one needs to thread tokens through the functions to force the ordering. However, I could not find any task in EvoJAX that uses tokens. Has anybody tried that? Or should I think more and redesign my problem to make it more JAX-compatible, since I use a bunch of jax.lax.cond calls (the constraints in my problem are quite complicated)?
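If each step only needs the state produced by the previous step, and has no external side effects, the data dependency may already fix the order. A sketch under that assumption, with a hypothetical one-dimensional task whose constraint is handled by jax.lax.cond inside a jax.lax.scan loop:

```python
import jax
import jax.numpy as jnp


def step(state, action):
    proposed = state + action
    in_bounds = (proposed >= 0.0) & (proposed <= 10.0)
    # Accept the move only if it keeps the state inside [0, 10].
    new_state = jax.lax.cond(
        in_bounds,
        lambda old, new: new,
        lambda old, new: old,
        state,
        proposed,
    )
    return new_state, new_state


init = jnp.zeros((), dtype=jnp.float32)
actions = jnp.array([3.0, 4.0, 5.0, -20.0, 1.0], dtype=jnp.float32)

# scan feeds each step the carry returned by the previous one.
final, trajectory = jax.lax.scan(step, init, actions)
print(final, trajectory)
```

Here the third and fourth actions are rejected, so the state stays at 7 until the last action moves it to 8.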

Can't execute Brax notebook

Hi all, I ran the notebook BraxTasks.ipynb as is, and the second cell crashes with the following error. I think it may be an issue with the Brax version.
[Screenshot, 2023-09-24 18:11:08]

Reinitialization

Hello,

I have a task with unknown global optima, and since optimizers can get stuck in local optima, I want to make sure that the optimum found is reached from various random starting points. I would therefore like to incorporate some kind of reinitialization of the whole search (basically calling trainer.run with multiple different seeds).
Is that even necessary? Does SimManager -> eval_params -> _for_loop_eval -> policy_reset_func reliably reinitialize the policy state?

Thanks in advance for your advice.

Reproducing benchmark scores

Hello everyone.

I am currently trying to reproduce scores from the benchmarks, specifically for ARS, because I am implementing my own JAX-native version and want to compare it with the wrapper already implemented.

For example, I cannot achieve the score posted in the benchmark table (902.107) for ARS on cartpole_easy.

Running python train.py -config configs/ARS/cartpole_easy.yaml yields the following training logs:

cartpole_easy: 2022-09-25 22:45:55,777 [INFO] EvoJAX cartpole_easy
cartpole_easy: 2022-09-25 22:45:55,777 [INFO] ==============================
absl: 2022-09-25 22:45:55,791 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
MLPPolicy: 2022-09-25 22:45:59,165 [INFO] MLPPolicy.num_params = 4609
cartpole_easy: 2022-09-25 22:45:59,429 [INFO] use_for_loop=False
cartpole_easy: 2022-09-25 22:45:59,496 [INFO] Start to train for 1000 iterations.
cartpole_easy: 2022-09-25 22:46:10,527 [INFO] Iter=50, size=100, max=399.5886, avg=207.9111, min=0.5843, std=99.0207
cartpole_easy: 2022-09-25 22:46:19,916 [INFO] Iter=100, size=100, max=543.8907, avg=364.9780, min=28.8478, std=141.8982
cartpole_easy: 2022-09-25 22:46:21,143 [INFO] [TEST] Iter=100, #tests=100, max=553.4018 avg=510.5583, min=462.4243, std=15.6930
cartpole_easy: 2022-09-25 22:46:30,627 [INFO] Iter=150, size=100, max=558.2020, avg=314.9279, min=89.8001, std=153.6488
cartpole_easy: 2022-09-25 22:46:40,068 [INFO] Iter=200, size=100, max=562.4118, avg=354.9529, min=47.0048, std=154.1567
cartpole_easy: 2022-09-25 22:46:40,114 [INFO] [TEST] Iter=200, #tests=100, max=570.1135 avg=547.5375, min=508.5795, std=10.0840
cartpole_easy: 2022-09-25 22:46:49,579 [INFO] Iter=250, size=100, max=562.1505, avg=325.3990, min=73.3733, std=161.9460
cartpole_easy: 2022-09-25 22:46:59,073 [INFO] Iter=300, size=100, max=569.5461, avg=370.2641, min=83.7473, std=166.8020
cartpole_easy: 2022-09-25 22:46:59,129 [INFO] [TEST] Iter=300, #tests=100, max=573.5941 avg=545.0388, min=505.8637, std=11.3853
cartpole_easy: 2022-09-25 22:47:08,623 [INFO] Iter=350, size=100, max=579.3894, avg=425.6462, min=82.4907, std=126.6614
cartpole_easy: 2022-09-25 22:47:18,109 [INFO] Iter=400, size=100, max=627.6509, avg=530.2781, min=156.4797, std=76.0956
cartpole_easy: 2022-09-25 22:47:18,160 [INFO] [TEST] Iter=400, #tests=100, max=639.7323 avg=600.9105, min=573.7767, std=10.7564
cartpole_easy: 2022-09-25 22:47:27,653 [INFO] Iter=450, size=100, max=668.2064, avg=546.0261, min=418.5385, std=60.5854
cartpole_easy: 2022-09-25 22:47:37,149 [INFO] Iter=500, size=100, max=684.4142, avg=574.4891, min=446.3126, std=62.5338
cartpole_easy: 2022-09-25 22:47:37,202 [INFO] [TEST] Iter=500, #tests=100, max=693.1522 avg=682.7945, min=638.0387, std=12.1575
cartpole_easy: 2022-09-25 22:47:46,708 [INFO] Iter=550, size=100, max=708.9561, avg=591.0547, min=295.5651, std=73.6026
cartpole_easy: 2022-09-25 22:47:56,212 [INFO] Iter=600, size=100, max=706.8138, avg=599.4783, min=348.7581, std=55.6310
cartpole_easy: 2022-09-25 22:47:56,263 [INFO] [TEST] Iter=600, #tests=100, max=691.0123 avg=680.4677, min=630.2983, std=6.1448
cartpole_easy: 2022-09-25 22:48:05,770 [INFO] Iter=650, size=100, max=707.0887, avg=581.3851, min=418.2251, std=75.9066
cartpole_easy: 2022-09-25 22:48:15,275 [INFO] Iter=700, size=100, max=712.7586, avg=586.4597, min=362.7628, std=71.5669
cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
cartpole_easy: 2022-09-25 22:48:24,849 [INFO] Iter=750, size=100, max=716.1056, avg=602.7747, min=458.0401, std=63.1697
cartpole_easy: 2022-09-25 22:48:34,365 [INFO] Iter=800, size=100, max=709.3475, avg=587.9896, min=393.0367, std=69.2385
cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
cartpole_easy: 2022-09-25 22:48:43,945 [INFO] Iter=850, size=100, max=706.8488, avg=598.3582, min=321.8640, std=75.2542
cartpole_easy: 2022-09-25 22:48:53,482 [INFO] Iter=900, size=100, max=720.0320, avg=596.1929, min=370.6555, std=77.2801
cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
cartpole_easy: 2022-09-25 22:49:03,068 [INFO] Iter=950, size=100, max=716.2341, avg=598.3802, min=422.7760, std=71.7756
cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952
cartpole_easy: 2022-09-25 22:49:12,458 [INFO] Loaded model parameters from ./log/ARS/cartpole_easy/default.
cartpole_easy: 2022-09-25 22:49:12,459 [INFO] Start to test the parameters.
cartpole_easy: 2022-09-25 22:49:12,509 [INFO] [TEST] #tests=100, max=728.9848, avg=720.6152, min=698.9832, std=5.0566

I am not entirely sure if the result on the benchmark table is intended to be 720.5952 from
cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952

or the max score from the final test. Regardless, neither of these matches the one posted in the benchmark table.

Am I doing something wrong to reproduce these scores?
This makes me unable to compare my own implementation of the algorithm.

Thank you

Evolving topology of NN

Hey guys,

Amazing work! Quick question: is it possible to also evolve the topology of networks using your framework, like NEAT does?

Issue with BatchNorm layer

When using a policy network with a BatchNorm layer, I get the following error:

ModifyScopeVariableError: Cannot update variable "mean" in "/bn_init" because collection "batch_stats" is immutable. 
 (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)

Bug of center_lr_decay_steps when use adam with PGPE

Bug

When using Adam with PGPE, this code

self._opt_state = self._opt_update(
            self._t // self._lr_decay_steps, -grad_center, self._opt_state
        )

means that Adam's step count t only increases once every self._lr_decay_steps iterations.
As a result, mhat and vhat do not behave as bias-corrected moving averages, because (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) stays very small. (Below is the Adam update code.)

def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First  moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v

Suggestion

I think it is better to change this code to

step_size=lambda x: self._center_lr * jnp.power(decay_coef, x // self._lr_decay_steps),

and to remove self._lr_decay_steps at

self._opt_state = self._opt_update(
            self._t, -grad_center, self._opt_state
        )
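To make the effect concrete, the sketch below compares the bias-correction denominators Adam sees with the current step index (t // lr_decay_steps) and with the proposed one (t). The values b1 = 0.9, b2 = 0.999 and lr_decay_steps = 1000 are assumptions chosen for illustration, not values read from the PGPE code.

```python
b1, b2 = 0.9, 0.999      # common Adam defaults, assumed here
lr_decay_steps = 1000    # hypothetical decay interval


def bias_correction(beta, i):
    # Denominator applied to m and v in the Adam update shown above.
    return 1.0 - beta ** (i + 1)


t = 500                        # solver iteration inside the first decay window
stale_i = t // lr_decay_steps  # step index Adam receives now: still 0
fresh_i = t                    # step index Adam would receive after the change

stale_m, stale_v = bias_correction(b1, stale_i), bias_correction(b2, stale_i)
fresh_m, fresh_v = bias_correction(b1, fresh_i), bias_correction(b2, fresh_i)

print(f"stale: mhat = m / {stale_m:.3f}, vhat = v / {stale_v:.3f}")
print(f"fresh: mhat = m / {fresh_m:.3f}, vhat = v / {fresh_v:.3f}")
```

Under these assumptions the stale index keeps dividing m by about 0.1 and v by about 0.001 for the whole first window, while the fresh index lets both corrections fade out as the averages warm up.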

AbstractPainting02.ipynb doesn't work on Colab

Hello, this is really great code.

I was able to run "Abstract Painting 01" without problems on Google Colab.
However, when I ran "AbstractPainting02", an error occurred.

Exception                                 Traceback (most recent call last)
[<ipython-input-20-b16203d22159>](https://localhost:8080/#) in <module>()
      2 devices = jax.local_devices()
      3 
----> 4 image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load('ViT-B/32', "cpu")
      5 
      6 target_text_ids = jnp.array(clip_jax.tokenize([prompt])) # already with batch dim

3 frames
[/content/CLIP_JAX/clip_jax/clip.py](https://localhost:8080/#) in process_node(value, name)
    117             new_tensor = jnp.array(pytorch_tensor)
    118         else:
--> 119             raise Exception("not implemented")
    120 
    121         assert new_tensor.shape == value.shape

Exception: not implemented

Which version of clip_jax did you use when you made this notebook?

Best

opencv-python dependency?

Hi everyone, and thanks for working on such a project and open-sourcing it!

I notice that the project has a dependency on opencv-python listed in the setup.py file. However, I did not find any import cv2 in the project, and the examples seem to work fine even in an environment without opencv-python/cv2. Is the dependency on opencv-python just a leftover or accidental dependency, or is there something that I am missing?

Thanks in advance.

can one specify parts of the model that are non differentiable?

I have a model in JAX (a convolutional neural network with some modifications) where most of it is fully differentiable, but some parts are not. Can I somehow mark which parts are not differentiable, so that gradient backpropagation is correct, or is that done automatically?

Thanks!
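Evolution strategies treat the model as a black box and only need its scores, so non-differentiable parts generally need no marking for the search itself. Where gradients are taken with JAX autodiff, jax.lax.stop_gradient is the usual way to exclude part of a computation. A minimal sketch with a made-up function:

```python
import jax


def model(x):
    smooth = x ** 2
    # This term contributes to the output value but is treated as a
    # constant when differentiating.
    blocked = jax.lax.stop_gradient(x ** 3)
    return smooth + blocked


value = model(2.0)            # 4 + 8
grad = jax.grad(model)(2.0)   # only d(x**2)/dx = 2 * x survives
print(value, grad)
```

Without stop_gradient the gradient at x = 2 would also include the 3 * x**2 term from the cubic part.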

Advanced loggers

It would be nice to use Tensorboard or any other advanced logging tool with evojax.
Looks like it should be straightforward to allow the user to implement a log_reward function that could be passed to the trainer.

I'm willing to implement this in a PR.

Some proposals about the `Trainer` logic

Currently I see two ways of using the Trainer.test_task:

  1. The test_task of the trainer is used for validation. The actual test set is held out and not seen during training or validation. In this case, how do I run the actual test? I can't pass just the test_task to the trainer, because the train_task is non-optional. It looks like there should be a way to do this with evojax.
  2. The test_task of the trainer is used for the actual test, no validation is used at all. In this case, why does the trainer.run return the best model score and not the last model score?

I propose the following (high level) logic:

best_val_reward = trainer.fit(train_task: VectorizedTask, val_task: Optional[VectorizedTask] = None)  # maybe the user doesn't want validation (e.g. train on latest data without early stopping)
test_reward = trainer.test(test_task: VectorizedTask, checkpoint="best|last|path")  # specify which checkpoint to use for testing

Probably early stopping would be pretty necessary for the trainer.fit method. Currently there is no way to determine when to do it and even which model iteration has the best result.

I'm willing to implement this logic in a PR.

SlimeVolley initial ball velocity is incorrect

Small issue with game state initialization:

def get_random_ball_v(key: jnp.ndarray):
    result = random.uniform(key, shape=(2,)) * 2 - 1
    ball_vx = result[1]*20
    ball_vy = result[2]*7.5+17.5
    return ball_vx, ball_vy

result[2] is out of bounds, which means it accesses the same value as result[1] -- and therefore ball_vx and ball_vy depend on the same value. This means the ball has higher y-vel when x-vel is positive, so the ball will be thrown higher to one player compared to the other.
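A possible correction, assuming the intent was to use the two independent samples at indices 0 and 1. The scale factors are copied from the snippet above and everything else is a sketch rather than the repository's code.

```python
import jax.numpy as jnp
from jax import random


def get_random_ball_v(key: jnp.ndarray):
    # Two independent samples in [-1, 1).
    result = random.uniform(key, shape=(2,)) * 2 - 1
    ball_vx = result[0] * 20          # horizontal velocity in [-20, 20)
    ball_vy = result[1] * 7.5 + 17.5  # vertical velocity in [10, 25)
    return ball_vx, ball_vy


key = random.PRNGKey(0)
vx, vy = get_random_ball_v(key)
print(vx, vy)
```

With separate indices the vertical velocity no longer depends on the sign of the horizontal one.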

How to use GPU for computing

Dear developer,
thanks for your awesome work. I have some questions.
When I run the seq2seq example, I get this warning:

Seq2seq: 2023-09-26 20:11:32,443 [INFO] ==============================
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695730292.463880 3694542 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-26 20:11:32.491777: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-09-26 20:11:32.491820: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: user-MD72-HB3-00
2023-09-26 20:11:32.491829: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: user-MD72-HB3-00
2023-09-26 20:11:32.491882: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 535.86.5
2023-09-26 20:11:32.491912: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 535.86.5
2023-09-26 20:11:32.491920: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 535.86.5
jax._src.xla_bridge: 2023-09-26 20:11:32,492 [INFO] Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
jax._src.xla_bridge: 2023-09-26 20:11:32,492 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
jax._src.xla_bridge: 2023-09-26 20:11:32,494 [INFO] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
jax._src.xla_bridge: 2023-09-26 20:11:32,494 [WARNING] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I have already successfully installed GPU-enabled JAX, and my setup is Linux with an RTX 4090 GPU and CUDA 12.2. How do I fix this problem?
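The cuInit line in the log (CUDA_ERROR_NO_DEVICE) indicates that the CUDA driver saw no GPU in that process, which happens before EvoJAX is involved. A quick check, run in the same environment and shell as the training script, is to ask JAX what it sees. The CUDA_VISIBLE_DEVICES lookup is included because an empty or wrong value is one possible cause, though not necessarily the one here.

```python
import os

import jax

backend = jax.default_backend()  # platform JAX selected, e.g. "cpu" or "gpu"
devices = jax.devices()          # devices available on that platform
visible = os.environ.get("CUDA_VISIBLE_DEVICES")

print("backend:", backend)
print("devices:", devices)
print("CUDA_VISIBLE_DEVICES:", visible)
```

If the backend is reported as cpu here as well, the problem lies in the driver, environment variables or JAX installation rather than in the seq2seq example.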
