Comments (34)
from turbozero.
I also need to use multiple GPUs, but I don't see anything like pmap in the code. Could you provide an example of how to train using multiple GPUs?
from turbozero.
Supporting multiple GPUs should be straightforward to add but is not currently included. I can work on this and let you know when it is supported. Created a separate issue to track.
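Multi-GPU support isn't built into TurboZero yet. The sketch below shows the common JAX pattern for it: split the global batch across local devices with jax.pmap, then use jax.vmap for the per-device batch. self_play_step is a hypothetical stand-in and is not part of TurboZero. A real version would call the vmapped evaluator and environment step functions.

```python
import jax
import jax.numpy as jnp


def self_play_step(obs, key):
    # Hypothetical per-game step (NOT a TurboZero API): stands in for evaluator + env step.
    noise = jax.random.normal(key, obs.shape)
    return obs + noise


def sharded_step(obs, keys):
    # Runs on one device: obs is [per_device_batch, ...], keys is [per_device_batch, ...].
    return jax.vmap(self_play_step)(obs, keys)


# pmap maps over the leading (device) axis; vmap inside handles the per-device batch.
p_step = jax.pmap(sharded_step)


def run_on_all_devices(obs, key):
    n_dev = jax.local_device_count()
    batch = obs.shape[0]
    assert batch % n_dev == 0, "batch size must divide evenly across devices"
    per_dev = batch // n_dev
    keys = jax.random.split(key, batch)
    # Reshape [batch, ...] -> [n_dev, per_dev, ...] so each device gets one slice.
    obs = obs.reshape((n_dev, per_dev) + obs.shape[1:])
    keys = keys.reshape((n_dev, per_dev) + keys.shape[1:])
    out = p_step(obs, keys)
    # Merge the device axis back into the batch axis.
    return out.reshape((batch,) + out.shape[2:])
```

For training you would also need one copy of the model parameters per device (for example with jax.device_put_replicated), and gradients averaged across devices with jax.lax.pmean inside the pmapped update step.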
from turbozero.
If I were you, I would implement a custom class extending Trainer and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.
class GoTrainer(Trainer):
    def collect(self,
        state: CollectionState,
        params: chex.ArrayTree
    ) -> CollectionState:
        step_key, new_key = jax.random.split(state.key)
        eval_output, new_env_state, new_metadata, terminated, rewards = \
            self.step_train(
                key = step_key,
                env_state = state.env_state,
                env_state_metadata = state.metadata,
                eval_state = state.eval_state,
                params = params
            )
        buffer_state = self.memory_buffer.add_experience(
            state = state.buffer_state,
            experience = BaseExperience(
                env_state=state.env_state,
                policy_mask=state.metadata.action_mask,
                policy_weights=eval_output.policy_weights,
                reward=jnp.empty_like(state.metadata.rewards)
            )
        )
        # generate symmetries here and add to replay memory just like above
        buffer_state = jax.lax.cond(
            terminated,
            lambda s: self.memory_buffer.assign_rewards(s, rewards),
            lambda s: s,
            buffer_state
        )
        return state.replace(
            key=new_key,
            eval_state=eval_output.eval_state,
            env_state=new_env_state,
            buffer_state=buffer_state,
            metadata=new_metadata
        )
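Below is one possible sketch of the "# generate symmetries here" step. It assumes the board observation has shape (19, 19, C) and that the policy has 19 * 19 + 1 entries, laid out row-major with the pass move last. Both the layout and the helper names are assumptions, so adapt them to your environment. The point to get right is that the observation, policy_weights and policy_mask all receive the same transform. If your replay buffer stores the full env_state and not just an observation, apply the transform to whichever env_state fields your network reads.

```python
import jax.numpy as jnp

BOARD = 19  # assumption: 19x19 board, policy = BOARD * BOARD board moves + 1 pass move


def dihedral(x, k, flip):
    # Rotate the first two (board) axes by k * 90 degrees, then optionally flip left-right.
    x = jnp.rot90(x, k=k, axes=(0, 1))
    return jnp.flip(x, axis=1) if flip else x


def go_symmetries(obs, policy, mask):
    """Return the 8 dihedral variants of (obs, policy, mask).

    obs: [BOARD, BOARD, C]; policy, mask: [BOARD * BOARD + 1] with the pass move last.
    The pass entry is not a board point, so it is carried over unchanged.
    """
    board_pol, pass_pol = policy[:-1].reshape(BOARD, BOARD), policy[-1:]
    board_mask, pass_mask = mask[:-1].reshape(BOARD, BOARD), mask[-1:]
    out = []
    for k in range(4):
        for flip in (False, True):
            out.append((
                dihedral(obs, k, flip),
                jnp.concatenate([dihedral(board_pol, k, flip).reshape(-1), pass_pol]),
                jnp.concatenate([dihedral(board_mask, k, flip).reshape(-1), pass_mask]),
            ))
    return out
```

Each variant could then go into its own add_experience call. Alternatively, the variants could be stacked and written as a batch, depending on what the buffer accepts.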
from turbozero.
I think I've found why this error happened.
from turbozero.
I am going to close this issue as the original problem has been resolved.
I hope to continue to expand the documentation of this project so that it is more clear as to how to approach unique use-cases.
Thank you for your questions and feedback! Please create another issue if you run into problems, and feel free to email me if you have more questions.
from turbozero.
Thank you for pointing this out!
All examples currently use the AlphaZero class, which implements update_root with the correct parameter ordering.
For MCTS, the parameter ordering is indeed incorrect.
from turbozero.
Can you give an example of how to train using multiple GPUs?
I'm trying to integrate my custom JAX environment and have run into a problem. This is the key code:
def step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )
def init_fn(key):
    state = env.init(key)
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
    )
az_evaluator = AlphaZero(MCTS)(
    eval_fn = eval_fn,
    num_iterations = 100,
    max_nodes = 200,
    branching_factor=82,
    action_selector = PUCTSelector()
)
def env_step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )
eval_key, rng_key = jax.random.split(rng_key)
eval_keys = jax.random.split(eval_key, batch_size)
env, env_state_metadata = jax.vmap(init_fn)(eval_keys)
evaluator_init = partial(az_evaluator.init, template_embedding=env)
eval_state = jax.vmap(evaluator_init)(eval_keys)
output = az_evaluator.evaluate(
    eval_state=eval_state,
    env_state=env,
    root_metadata=env_state_metadata,
    params=param,
    env_step_fn=env_step_fn
)
This is the error:
eval_state = self.update_root(eval_state, env_state, root_metadata, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/evaluators/alphazero.py", line 31, in update_root
key, tree = get_rng(tree)
^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 143, in get_rng
rng, new_rng = jax.random.split(tree.key, 2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/random.py", line 303, in split
return _return_prng_keys(wrapped, _split(typed_key, num))
^^^^^^^^^^^^^^^^^^^^^^
File "/python3.11/site-packages/jax/_src/random.py", line 286, in _split
raise TypeError("split accepts a single key, but was given a key array of"
TypeError: split accepts a single key, but was given a key array of shape (100,) != (). Use jax.vmap for batching.
from turbozero.
evaluate should be vmapped; try something like this:
eval_key, env_key, rng_key = jax.random.split(rng_key, 3)
eval_keys = jax.random.split(eval_key, batch_size)
env_keys = jax.random.split(env_key, batch_size)
env_state, env_state_metadata = jax.vmap(init_fn)(env_keys)
# template embedding should not have a batch dimension
template_env_state, _ = init_fn(jax.random.PRNGKey(0))
evaluator_init = partial(az_evaluator.init, template_embedding=template_env_state)
eval_state = jax.vmap(evaluator_init)(eval_keys)
evaluate = partial(az_evaluator.evaluate,
                   env_step_fn=env_step_fn,
                   params=param)
output = jax.vmap(evaluate)(
    eval_state=eval_state,
    env_state=env_state,
    root_metadata=env_state_metadata)
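The TypeError above comes from calling jax.random.split on a whole batch of keys outside of vmap. Here is a minimal illustration of the fix that doesn't depend on TurboZero. init_tree and its template argument are hypothetical stand-ins for az_evaluator.init and template_embedding. Passing in_axes=(0, None) maps over the keys and hands the same unbatched template to every call, which has the same effect as the functools.partial approach above.

```python
import jax
import jax.numpy as jnp


def init_tree(key, template):
    # Hypothetical single-instance init: split expects ONE key, not a batch of keys.
    key, subkey = jax.random.split(key)
    noise = jax.random.uniform(subkey, template.shape)
    return key, template + noise


batch_size = 4
template = jnp.zeros((19, 19, 17))  # template has no batch dimension
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

# In recent JAX versions, init_tree(keys, template) fails with the
# "split accepts a single key" TypeError. Instead, map over the keys (axis 0)
# and broadcast the template (None) to every call.
new_keys, trees = jax.vmap(init_tree, in_axes=(0, None))(keys, template)
```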
I recommend using the Trainer class as described here: https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb
I haven't fully documented many of the underlying classes yet, and they do have their peculiarities; Trainer should be more straightforward to work with.
from turbozero.
Thanks, but now I have a new problem.
File "/turbozero/core/evaluators/mcts/mcts.py", line 81, in evaluate
eval_state = self.update_root(eval_state, env_state, root_metadata, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/evaluators/alphazero.py", line 56, in update_root
return set_root(tree, root_node)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 121, in set_root
data=jax.tree_util.tree_map(
^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 122, in <lambda>
lambda x, y: x.at[tree.ROOT_INDEX].set(y),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 497, in set
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 80, in _scatter_update
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 115, in _scatter_impl
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
return util._broadcast_to(array, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "//lib/python3.11/site-packages/jax/_src/numpy/util.py", line 425, in _broadcast_to
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: safe_zip() argument 2 is shorter than argument 1
from turbozero.
Could you share your code?
from turbozero.
def one_step(prev):
    """Execute one self-play move using MCTS.
    """
    env_state, rng_key, step, eval_state = prev
    rng_key, rng_key_next = jax.random.split(rng_key, 2)
    env_state_metadata = StepMetadata(
        rewards=env_state.rewards,
        action_mask=env_state.legal_action_mask,
        terminated=env_state.terminated,
        cur_player_id=env_state.current_player,
    )
    terminated = env_state.terminated
    output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)
    env_state = step_fn_move(env_state, output.action)
    eval_state = output.eval_state
    eval_state = az_evaluator.step(eval_state, output.action)
    return (env_state, rng_key_next, step + 1, env_new3, eval_state)
from turbozero.
It uses jax.lax.scan, like this:
output = jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=400, unroll=1)
Maybe this part is incorrect?
eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
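Separately from the error above, jax.lax.scan expects its body to return a (new_carry, per_step_output) pair, and the new carry must have exactly the same structure as the input carry. The one_step above takes a four-element carry but returns one five-element tuple, so scan would reject it. Below is a minimal sketch of the expected shape. game_step and the random actions are hypothetical toy stand-ins for the environment step and the evaluator's output.action. As discussed in this thread, the per-game functions are vmapped over the batch.

```python
import jax
import jax.numpy as jnp


def game_step(board, action):
    # Hypothetical single-game transition: mark the chosen point on a flat board.
    return board.at[action].set(1.0)


def one_step(carry, _):
    boards, key, step = carry
    key, subkey = jax.random.split(key)
    batch = boards.shape[0]
    # One random action per game (stand-in for the vmapped evaluator's output.action).
    actions = jax.random.randint(subkey, (batch,), 0, boards.shape[1])
    # Per-game functions are written for ONE game and vmapped over the batch.
    boards = jax.vmap(game_step)(boards, actions)
    # Return (new_carry, per_step_output); the carry keeps the same structure and dtypes.
    return (boards, key, step + 1), actions


init = (jnp.zeros((4, 361)), jax.random.PRNGKey(0), jnp.int32(0))
(final_boards, _, final_step), actions = jax.lax.scan(one_step, init, None, length=10)
```

scan returns the final carry and the per-step outputs stacked along a new leading axis, so actions here has shape [length, batch].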
from turbozero.
From debugging similar errors before, this error suggests to me that some of the data passed to evaluate might not have a batch dimension.
Trainer already implements the necessary functions to collect episodes and advance the environment and evaluator to the next state, so I'm curious why you are implementing these yourself. Is there a feature you need that's missing, or something that's confusing or unclear? I'm hopeful that most users won't need to provide anything besides environment dynamics functions.
from turbozero.
Yes, there is a feature I need that's missing: Go self-play data processing needs symmetry augmentation.
from turbozero.
Could you describe how the feature should work? I can help you find the right spot to integrate it.
from turbozero.
Could you describe how the feature should work? I can help you find the right spot to integrate it.
In Go, self-play data is usually augmented with transforms like np.rot90 and np.fliplr.
from turbozero.
About that error: could this part need vmap?
eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
from turbozero.
About that error: could this part need vmap?
eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
Yes, every function that operates on an evaluator state assumes a single input rather than a batch, and should be vmapped.
Could you describe how the feature should work? I can help you find the right spot to integrate it.
In Go, self-play data is usually augmented with transforms like np.rot90 and np.fliplr.
If I were you, I would implement a custom class extending Trainer and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.
The relevant code is here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L103-L140
You should only need to extend the behavior of collect.
It might be a good idea for me to allow users to specify any number of transforms to apply to augment experiences prior to storing them in replay memory. This is a fairly common use case, so it ideally should not require a custom class.
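User-specified transforms might look roughly like the sketch below. This is not an existing TurboZero API. augment_experience, the Experience fields and the transform signature are all hypothetical. Each transform maps one experience to another. The results are then stacked along a new leading axis so they can be written to the buffer as a batch.

```python
from typing import Callable, NamedTuple, Sequence

import jax
import jax.numpy as jnp


class Experience(NamedTuple):
    # Hypothetical experience layout: 19x19 board plus a pass move in the policy.
    board: jnp.ndarray           # [19, 19]
    policy_weights: jnp.ndarray  # [362], pass move last
    policy_mask: jnp.ndarray     # [362], pass move last


Transform = Callable[[Experience], Experience]


def augment_experience(exp: Experience, transforms: Sequence[Transform]) -> Experience:
    # Apply each transform, then stack every field along a new leading axis.
    variants = [t(exp) for t in transforms]
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *variants)


def identity(exp: Experience) -> Experience:
    return exp


def flip_lr(exp: Experience) -> Experience:
    # Flip the board and the board part of the policy/mask together; keep the pass entry.
    def flip_policy(p):
        return jnp.concatenate([jnp.flip(p[:-1].reshape(19, 19), axis=1).reshape(-1), p[-1:]])
    return Experience(jnp.flip(exp.board, axis=1),
                      flip_policy(exp.policy_weights),
                      flip_policy(exp.policy_mask))
```

For example, augment_experience(exp, [identity, flip_lr]) would return one Experience whose fields each have a leading axis of size 2.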
from turbozero.
Yes, I'm going to try this, but if I can't use multiple GPUs, training will be slow.
I changed it to this:
eval_state = output.eval_state
eval_state = jax.vmap(az_evaluator.step)(eval_state, output.action)
But the error occurred before the code reached that point.
The error is here:
output = jax.vmap(evaluate)(
    eval_state=eval_state,
    env_state=env_state,
    root_metadata=env_state_metadata)
And when I print(state.observation.shape) in eval_fn, it shows (19, 19, 17), without a batch dimension. Is this normal?
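This sketch shows the vmap behavior discussed above. Inside a vmapped function, arrays appear without their batch dimension. So if eval_fn is called from inside a vmapped evaluate, as in the snippet above, an observation shape of (19, 19, 17) is what you would expect. EvalState and step are hypothetical stand-ins for an evaluator state and az_evaluator.step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EvalState(NamedTuple):
    # Hypothetical stand-in for an evaluator state (one game, no batch axis).
    visits: jnp.ndarray       # [num_actions]
    last_action: jnp.ndarray  # scalar


def step(state, action):
    # Written for ONE state: under vmap, visits has shape [num_actions], not [batch, num_actions].
    assert state.visits.ndim == 1
    return EvalState(state.visits.at[action].add(1), action)


def obs_rank(obs):
    # obs.ndim is a static Python int; under vmap it excludes the batch axis.
    return jnp.asarray(obs.ndim)


batch = EvalState(jnp.zeros((8, 362)), jnp.zeros((8,), dtype=jnp.int32))
actions = jnp.arange(8, dtype=jnp.int32)
new_batch = jax.vmap(step)(batch, actions)              # every leaf mapped along axis 0
ranks = jax.vmap(obs_rank)(jnp.zeros((8, 19, 19, 17)))  # each call sees a (19, 19, 17) array
```

Each entry of ranks is 3, not 4, because each call sees one (19, 19, 17) observation.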
from turbozero.
My speed tests were slow. Then I tested the example at https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb, and judging from GPU utilization, it also seems slow.
What does max_nodes mean? I found that max_nodes has a significant impact on speed: the larger max_nodes is, the slower it runs.
mctx-az behaves the same way, and in my tests mctx-az is faster than TurboZero.
from turbozero.
max_nodes reflects the maximum capacity of the tree. Trees cannot be sized dynamically, so a maximum number of nodes must be set prior to collection. It makes sense that performance gets worse as max_nodes increases, since operations are then performed on larger matrices.
I am aware that the backend is currently less performant than mctx. It is a priority for me to fix this.
from turbozero.
I'm still confused about max_nodes.
If max_nodes=100, does that mean every node has 100 child nodes, or that this evaluate call will search at most 100 nodes?
Why do you suggest setting it larger than num_simulations? That will be very, very slow.
Yes, I look forward to it running faster.
from turbozero.
I think the main factor limiting speed is CPU-side computation; the GPU is not fully utilized.
from turbozero.
max_nodes refers to the maximum capacity of the tree, so yes, a tree with max_nodes=100 will evaluate at most 100 distinct game states.
I advise setting it higher than num_iterations because this implementation re-uses subtrees from the previous search, so most of the time the tree will already be partially populated when a new search starts. Setting max_nodes higher than num_iterations means there will more often be room for num_iterations nodes in the tree. I have yet to document the out-of-bounds behavior, but it works similarly to https://github.com/lowrollr/mctx-az.
Increasing max_nodes linearly increases the memory footprint of the search tree data structure. Setting it higher or lower is definitely a trade-off of speed vs. accuracy, and its value relative to num_iterations is problem-dependent (branching factor and number of iterations both matter). Setting max_nodes = num_iterations is fine if you're worried about speed.
from turbozero.
I think the main factor limiting speed is CPU-side computation; the GPU is not fully utilized.
Do you have evidence of this? I'm not aware of any CPU-bound portion of the training loop and in my experiments I've had no issues with GPU utilization.
from turbozero.
In my experiments, with num_simulations unchanged, max_nodes = 32 took 460 seconds and max_nodes = 600 took 2200 seconds, and the GPU Pwr:Usage/Cap at max_nodes = 600 was much lower than at max_nodes = 32. I tested this with mctx-az; TurboZero shows the same behavior.
Since num_simulations is unchanged, the GPU's computational workload should stay the same, so is the problem on the CPU side?
from turbozero.
Some more details could be useful here.
What are you running, a single call to search?
What is num_simulations set to?
What environment are you using?
I'm not sure this is entirely unexpected behavior.
Computational workload is higher when max_nodes is increased. Search operates on tensors of size [num_batches, max_nodes, ...].
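Some rough arithmetic shows how these tensor sizes grow. Purely for illustration, the sketch assumes one per-edge float32 array of shape [batch_size, max_nodes, branching_factor], similar in spirit to the child statistics an MCTS tree keeps. TurboZero's actual tree layout may differ. It uses the numbers reported in this thread: batch size 50 and 19x19 Go with 362 actions.

```python
def edge_array_bytes(batch_size: int, max_nodes: int, branching_factor: int,
                     bytes_per_elem: int = 4) -> int:
    # Size of ONE hypothetical [batch_size, max_nodes, branching_factor] float32 array.
    return batch_size * max_nodes * branching_factor * bytes_per_elem


small = edge_array_bytes(50, 32, 362)   # 2,316,800 bytes (about 2.3 MB)
large = edge_array_bytes(50, 600, 362)  # 43,440,000 bytes (about 43.4 MB)
ratio = large / small                   # 18.75
```

Under this assumption, any operation that reads or writes whole tree arrays does roughly 19 times more work at max_nodes = 600 than at 32, even with num_simulations fixed. That may help explain why runtime grows with max_nodes.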
from turbozero.
Environment: Go 19x19
Model size: same as the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
Batch size: 50 per device
Steps: 410
There is another strange phenomenon: as the number of steps increases, GPU utilization keeps dropping.
from turbozero.
Thank you for letting me know, I will look into what you are describing to see if I can replicate it and diagnose.
Are you running more than one call to search (MCTS.evaluate)? Some of the odd behavior you are describing could be due to JIT-compilation overhead on the first call.
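To separate JIT-compilation overhead from steady-state runtime, you can time the first and second calls separately. Call block_until_ready on the result before stopping the clock, since JAX dispatches work asynchronously. run is a hypothetical workload standing in for the jitted scan over self-play steps. Each distinct max_nodes or batch size value produces different array shapes and therefore triggers its own compilation. So the first call at each setting should be excluded when comparing timings.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def run(x):
    # Hypothetical workload standing in for the jitted self-play scan.
    def body(carry, _):
        return jnp.tanh(carry @ carry.T) @ carry, None
    out, _ = jax.lax.scan(body, x, None, length=5)
    return out


def timed(fn, *args):
    start = time.perf_counter()
    result = fn(*args)
    result.block_until_ready()  # wait for the asynchronously dispatched work to finish
    return result, time.perf_counter() - start


x = jnp.ones((128, 128)) * 1e-3
out, first = timed(run, x)   # includes tracing and compilation
_, second = timed(run, x)    # reuses the cached executable (same shapes and dtypes)
```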
from turbozero.
It uses jax.lax.scan, like this: jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=410, unroll=1)
JIT compilation should only be a one-time cost.
Environment: Go 19x19
Model size: same as the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
Batch size: 50 per device
Steps: 410
With max_nodes = 32, it took 460 seconds.
With max_nodes = 600, it took 2200 seconds,
and GPU Pwr:Usage/Cap was much lower than with max_nodes = 32.
Tested with mctx-az; TurboZero shows the same behavior.
from turbozero.
Thanks for pointing this out, I will see if I can replicate.
from turbozero.
Are you familiar with the MuZero algorithm? I have some questions and hope you can help me.
from turbozero.
I haven't worked much with MuZero specifically, but I have read the paper. Feel free to send me an email with your questions and I'll see if I can answer.
I still plan on looking into the issues you mention but have been very busy and have not had a chance yet.
from turbozero.
Related Issues (9)
- LazyZero-based training sample commands fail with "invalid multinomial distribution"
- Dirilecht instead of dirichlet in mcts
- AlphaZero+MCTS: Visit probabilities for invalid actions can be non-zero
- speed issue
- allow for running w/ multiple gpus and provide an example
- allow for user-specified data augmentation in Trainer
- Batch MCTS is needed !!!
- The key differences between this work and the implementation of alphazero in PGX