
lowrollr commented on July 19, 2024

4e55c72

from turbozero.

Nightbringers commented on July 19, 2024

I also need to use multiple GPUs, but I don't see anything like pmap in the code. Could you provide an example of how to train with multiple GPUs?

lowrollr commented on July 19, 2024

Supporting multiple GPUs should be straightforward to add but is not currently included. I can work on this and let you know when it is supported. I've created a separate issue to track it.
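Until built-in support exists, one generic way to use several GPUs in JAX is data parallelism with jax.pmap: replicate the parameters on every device, give each device its own shard of the batch, and average gradients across devices with jax.lax.pmean. The sketch below is not turbozero code; it uses a toy linear model, and loss_fn, train_step and LR are hypothetical stand-ins. It runs on however many local devices are visible, including a single CPU.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate

def loss_fn(params, x, y):
    # Toy linear model standing in for the real network.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients over all devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# pmap maps over the leading axis of every argument: one slice per device.
p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
# Replicate the parameters: add a leading device axis of size n_dev.
replicated = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n_dev), params)
# Shard the data as [n_dev, per_device_batch, features].
x = jnp.ones((n_dev, 8, 3))
y = jnp.ones((n_dev, 8, 1))

new_params = p_train_step(replicated, x, y)
```

For self-play, one possible layout is the same: the pmapped function receives one device's shard and uses jax.vmap over the single-instance evaluator functions inside it.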

lowrollr commented on July 19, 2024

If I were you, I would implement a custom class extending Trainer and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

import chex
import jax
import jax.numpy as jnp
# Trainer, CollectionState and BaseExperience are turbozero classes; import them
# from your turbozero installation (Trainer is defined in core/training/train.py).

class GoTrainer(Trainer):
    def collect(self,
        state: CollectionState,
        params: chex.ArrayTree
    ) -> CollectionState:
        step_key, new_key = jax.random.split(state.key)
        # Run the evaluator and advance the environment by one step.
        eval_output, new_env_state, new_metadata, terminated, rewards = \
            self.step_train(
                key = step_key,
                env_state = state.env_state,
                env_state_metadata = state.metadata,
                eval_state = state.eval_state,
                params = params
            )
        # Store the pre-step state with its search policy; the reward is a
        # placeholder that assign_rewards fills in once the episode terminates.
        buffer_state = self.memory_buffer.add_experience(
            state = state.buffer_state,
            experience = BaseExperience(
                env_state=state.env_state,
                policy_mask=state.metadata.action_mask,
                policy_weights=eval_output.policy_weights,
                reward=jnp.empty_like(state.metadata.rewards)
            )
        )

        # Generate symmetries here and add them to replay memory just like above,
        # transforming env_state, policy_mask and policy_weights consistently.

        buffer_state = jax.lax.cond(
            terminated,
            lambda s: self.memory_buffer.assign_rewards(s, rewards),
            lambda s: s,
            buffer_state
        )

        return state.replace(
            key=new_key,
            eval_state=eval_output.eval_state,
            env_state=new_env_state,
            buffer_state=buffer_state,
            metadata=new_metadata
        )

Nightbringers commented on July 19, 2024

I think I've found why this error happened.

lowrollr commented on July 19, 2024

I am going to close this issue as the original problem has been resolved.

I hope to keep expanding this project's documentation so that it is clearer how to approach unique use cases.

Thank you for your questions and feedback! Please create another issue if you run into problems, and feel free to email me if you have more questions.

lowrollr commented on July 19, 2024

Thank you for pointing this out!

All examples currently use the AlphaZero class, which implements update_root with the correct parameter ordering.

The update_root parameter ordering for MCTS is indeed incorrect.

Nightbringers commented on July 19, 2024

Can you give an example of how to train with multiple GPUs?

I'm trying to integrate turbozero with my custom JAX environment and have run into a problem. This is the key code:

def step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

def init_fn(key):
    state = env.init(key)
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
    )
az_evaluator = AlphaZero(MCTS)(
    eval_fn = eval_fn,
    num_iterations = 100,
    max_nodes = 200,
    branching_factor=82,
    action_selector = PUCTSelector()
)
def env_step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

eval_key, rng_key = jax.random.split(rng_key)
eval_keys = jax.random.split(eval_key, batch_size)
env, env_state_metadata = jax.vmap(init_fn)(eval_keys)
evaluator_init = partial(az_evaluator.init, template_embedding=env)
eval_state = jax.vmap(evaluator_init)(eval_keys)
output = az_evaluator.evaluate(
    eval_state=eval_state,
    env_state=env,
    root_metadata=env_state_metadata,
    params=param,
    env_step_fn=env_step_fn
)

This is the error:

eval_state = self.update_root(eval_state, env_state, root_metadata, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/evaluators/alphazero.py", line 31, in update_root
key, tree = get_rng(tree)
^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 143, in get_rng
rng, new_rng = jax.random.split(tree.key, 2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/random.py", line 303, in split
return _return_prng_keys(wrapped, _split(typed_key, num))
^^^^^^^^^^^^^^^^^^^^^^
File "/python3.11/site-packages/jax/_src/random.py", line 286, in _split
raise TypeError("split accepts a single key, but was given a key array of"
TypeError: split accepts a single key, but was given a key array ofshape (100,) != (). Use jax.vmap for batching.

lowrollr commented on July 19, 2024

evaluate should be vmapped. Try something like this:

eval_key, env_key, rng_key = jax.random.split(rng_key, 3)
eval_keys = jax.random.split(eval_key, batch_size)
env_keys = jax.random.split(env_key, batch_size)

env_state, env_state_metadata = jax.vmap(init_fn)(env_keys)

# template embedding should not have a batch dimension
template_env_state, _ = init_fn(jax.random.PRNGKey(0))
evaluator_init = partial(az_evaluator.init, template_embedding=template_env_state)
eval_state = jax.vmap(evaluator_init)(eval_keys)

evaluate = partial(az_evaluator.evaluate, 
                   env_step_fn=env_step_fn,
                   params=param)

output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)
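The same pattern can be shown in isolation: bind the arguments shared by the whole batch with functools.partial, then jax.vmap over everything that carries a leading batch dimension, so the wrapped function only ever sees one key and one example. This is a minimal sketch; evaluate_one is a hypothetical stand-in for a single-instance function, not turbozero's evaluator.

```python
from functools import partial

import jax
import jax.numpy as jnp

def evaluate_one(key, obs, scale):
    """Hypothetical single-instance function: expects ONE key and ONE observation."""
    _, subkey = jax.random.split(key)  # fails if `key` carries a batch dimension
    return obs + scale * jax.random.normal(subkey, obs.shape)

batch_size = 4
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)  # one key per batch element
obs = jnp.zeros((batch_size, 3))                             # leading batch dimension

# Bind the argument shared by the whole batch, then vmap over the rest.
evaluate = partial(evaluate_one, scale=0.1)
out = jax.vmap(evaluate)(keys, obs)
```

Each batch element gets its own key, so the rows of out differ.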

I recommend using the Trainer class as described here, https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb

I haven't fully documented many of the underlying classes yet, and they do have their peculiarities; Trainer should be more straightforward to work with.

Nightbringers commented on July 19, 2024

Thanks! Now I have a new problem:

File "/turbozero/core/evaluators/mcts/mcts.py", line 81, in evaluate
eval_state = self.update_root(eval_state, env_state, root_metadata, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/evaluators/alphazero.py", line 56, in update_root
return set_root(tree, root_node)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 121, in set_root
data=jax.tree_util.tree_map(
^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 122, in <lambda>
lambda x, y: x.at[tree.ROOT_INDEX].set(y),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 497, in set
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 80, in _scatter_update
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 115, in _scatter_impl
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
return util._broadcast_to(array, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "//lib/python3.11/site-packages/jax/_src/numpy/util.py", line 425, in _broadcast_to
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: safe_zip() argument 2 is shorter than argument 1

lowrollr commented on July 19, 2024

Could you share your code?

Nightbringers commented on July 19, 2024

def one_step(prev, _):
    """Execute one self-play move using MCTS."""
    env_state, rng_key, step, eval_state = prev
    rng_key, rng_key_next = jax.random.split(rng_key, 2)
    env_state_metadata = StepMetadata(
        rewards=env_state.rewards,
        action_mask=env_state.legal_action_mask,
        terminated=env_state.terminated,
        cur_player_id=env_state.current_player,
    )
    terminated = env_state.terminated

    output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)

    env_state = step_fn_move(env_state, output.action)

    eval_state = output.eval_state
    eval_state = az_evaluator.step(eval_state, output.action)

    # lax.scan expects (new_carry, per-step output); new_carry must match the input carry.
    return (env_state, rng_key_next, step + 1, eval_state), None

Nightbringers commented on July 19, 2024

It uses jax.lax.scan:

output = jax.lax.scan(one_step, (env_state, rng_key, step,eval_state), None, length=400, unroll=1)

Maybe this part is incorrect?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
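For reference, jax.lax.scan calls its body as f(carry, x) and expects (new_carry, y) back, where new_carry has exactly the same structure, shapes and dtypes as the carry passed in; the per-step y values are stacked along a new leading axis. A toy sketch of a self-play-style loop, in which random actions stand in for search results (none of these names are turbozero APIs):

```python
import jax
import jax.numpy as jnp

def one_step(carry, _):
    # carry = (batched state, PRNG key, step counter); its structure must not change.
    state, key, step = carry
    key, subkey = jax.random.split(key)
    # Hypothetical stand-in for the actions chosen by search: one per batch element.
    action = jax.random.randint(subkey, (state.shape[0],), 0, 3, dtype=jnp.int32)
    state = state + action[:, None]
    return (state, key, step + 1), action  # (new_carry, per-step output)

init = (jnp.zeros((4, 2), jnp.int32), jax.random.PRNGKey(0), jnp.int32(0))
(final_state, final_key, n_steps), actions = jax.lax.scan(one_step, init, None, length=10)
```

Keeping the step counter as jnp.int32 (rather than a Python int) keeps the carry dtype identical on every iteration.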

lowrollr commented on July 19, 2024

From debugging similar errors before, the error suggests to me that some of the data passed to evaluate might be missing a batch dimension.
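One way to test that hypothesis is to check every leaf of each input pytree for the expected leading batch axis before calling the vmapped function. A minimal sketch; check_batch_dim is a hypothetical helper built on jax.tree_util, not part of turbozero:

```python
import jax
import jax.numpy as jnp

def check_batch_dim(tree, batch_size, name="input"):
    """Raise ValueError if any leaf of `tree` lacks a leading axis of size batch_size."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves:
        shape = jnp.shape(leaf)
        if len(shape) == 0 or shape[0] != batch_size:
            raise ValueError(
                f"{name}{jax.tree_util.keystr(path)} has shape {shape}, "
                f"expected a leading batch dimension of {batch_size}"
            )
```

For example, calling check_batch_dim(env_state, batch_size, "env_state"), and likewise for eval_state and root_metadata, right before jax.vmap(evaluate)(...) would name the offending field. Shapes are static in JAX, so the check also works while a function is being traced under jit or scan.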

Trainer already implements the functions needed to collect episodes and advance the env and evaluator to their next states, so I'm curious why you are implementing these yourself. Is there a feature you need that's missing, or something that's confusing or unclear? I'm hopeful that most users won't need to provide anything besides environment dynamics functions.

Nightbringers commented on July 19, 2024

Yes, there is a feature I need that's missing: Go self-play data processing needs symmetry augmentation.

lowrollr commented on July 19, 2024

Could you describe how the feature should work? I can help you find the right spot to integrate it.

Nightbringers commented on July 19, 2024

> Could you describe how the feature should work? I can help you find the right spot to integrate it.

In Go, self-play data is usually augmented with transforms such as np.rot90 and np.fliplr.
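A sketch of the 8 dihedral symmetries of a Go position (4 rotations, each optionally mirrored) using jnp.rot90 and jnp.flip, the JAX counterparts of the NumPy functions mentioned. It assumes observations are laid out as (N, N, C) and that the policy and legal-action mask are flat vectors of length N*N + 1: board points in row-major order followed by a pass move. That layout is an assumption, so adjust the reshapes to your environment. The essential point is that the policy and mask go through exactly the same transform as the board planes.

```python
import jax.numpy as jnp

def dihedral_symmetries(obs, policy, mask):
    """Return the 8 dihedral transforms (4 rotations x optional mirror) of one position.

    obs:    (N, N, C) board feature planes
    policy: (N*N + 1,) move probabilities: board points row-major, then pass (assumed layout)
    mask:   (N*N + 1,) legal-action mask with the same layout as policy
    """
    n = obs.shape[0]
    board_policy = policy[:-1].reshape(n, n)
    board_mask = mask[:-1].reshape(n, n)
    obs_out, policy_out, mask_out = [], [], []
    for mirror in (False, True):
        for k in range(4):
            def transform(x):
                # Rotate in the board plane (axes 0 and 1), then optionally mirror left-right.
                x = jnp.rot90(x, k, axes=(0, 1))
                return jnp.flip(x, axis=1) if mirror else x
            obs_out.append(transform(obs))
            # The pass move has no board position, so it stays at the end unchanged.
            policy_out.append(jnp.concatenate([transform(board_policy).reshape(-1), policy[-1:]]))
            mask_out.append(jnp.concatenate([transform(board_mask).reshape(-1), mask[-1:]]))
    return jnp.stack(obs_out), jnp.stack(policy_out), jnp.stack(mask_out)
```

Each of the 8 rows could then become one experience for the replay buffer; how to rebuild an environment state from transformed planes depends on the environment and is left out here.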

Nightbringers commented on July 19, 2024

About that error: could this part need vmap?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)

lowrollr commented on July 19, 2024

> About that error: could this part need vmap?
>
> eval_state = output.eval_state
> eval_state = az_evaluator.step(eval_state, output.action)

Yes. Every function that operates on an evaluator state assumes a single input rather than a batch, so it should be vmapped.
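To illustrate why such single-instance functions need jax.vmap, here is a toy stand-in for a per-tree update. ToyTree, set_root_value and step are hypothetical, not turbozero's classes. set_root_value writes a scalar into slot ROOT_INDEX of a per-node array, much like the x.at[tree.ROOT_INDEX].set(y) line in the traceback earlier in this thread. Called directly on a batch, the batched value no longer fits that slot and JAX raises a broadcasting error; under jax.vmap each tree is updated independently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

ROOT_INDEX = 0

class ToyTree(NamedTuple):
    values: jnp.ndarray  # (max_nodes,) per-node values for ONE tree
    root: jnp.ndarray    # () index of the current root node

def set_root_value(tree: ToyTree, value: jnp.ndarray) -> ToyTree:
    # Assumes a single tree: `value` is a scalar written into one slot.
    return tree._replace(values=tree.values.at[ROOT_INDEX].set(value))

def step(tree: ToyTree, action: jnp.ndarray) -> ToyTree:
    # Hypothetical: advance the root to the chosen action.
    return tree._replace(root=action)

batch_size, max_nodes = 4, 8
trees = ToyTree(values=jnp.zeros((batch_size, max_nodes)),
                root=jnp.zeros((batch_size,), dtype=jnp.int32))

# vmap maps over the leading batch axis of every leaf and of the second argument.
trees = jax.vmap(set_root_value)(trees, jnp.arange(batch_size, dtype=jnp.float32))
trees = jax.vmap(step)(trees, jnp.array([1, 2, 3, 4], dtype=jnp.int32))
```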

> Could you describe how the feature should work? I can help you find the right spot to integrate it.
>
> In Go, self-play data is usually augmented with transforms such as np.rot90 and np.fliplr.

If I were you, I would implement a custom class extending Trainer and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L103-L140

You should only need to extend the behavior of collect.

It might be a good idea for me to let users specify any number of transforms to apply to experiences before they are stored in replay memory; this is a fairly common use case, so ideally it should not require a custom class.

Nightbringers commented on July 19, 2024

Yes, I'm going to try this, but if I can't use multiple GPUs, training will be slow.

I changed it to this:

eval_state = output.eval_state
eval_state = jax.vmap(az_evaluator.step)(eval_state, output.action)

But the error occurs before the code reaches that point.
It is raised here:

output = jax.vmap(evaluate)(
            eval_state=eval_state,
            env_state=env_state,
            root_metadata=env_state_metadata)

Also, when I print(state.observation.shape) in eval_fn, it shows (19, 19, 17), without a batch dimension. Is this normal?

Nightbringers commented on July 19, 2024

The speed in my test was slow. I then tested the example at https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb, and judging by GPU utilization it also seems slow.

What does max_nodes mean? I found that max_nodes has a significant impact on speed: the larger max_nodes is, the slower it runs.
mctx-az behaves the same way, and in my tests mctx-az is faster than turbozero.

lowrollr commented on July 19, 2024

max_nodes reflects the maximum capacity of the tree. Trees cannot be resized dynamically, so the maximum number of nodes must be set before collection. It makes sense that performance gets worse as max_nodes increases, since operations then run on larger arrays.

I am aware that the backend is currently less performant than mctx. It is a priority for me to fix this.

Nightbringers commented on July 19, 2024

I'm still confused about max_nodes.
If max_nodes=100, does it mean every node has 100 child nodes? Or does it mean this evaluate call will search at most 100 nodes?
Why do you suggest setting it larger than num_simulations? That will be very, very slow.

Yes, I look forward to it running faster.

Nightbringers commented on July 19, 2024

I think the main factor limiting speed is computation on the CPU; the GPU is not fully utilized.

lowrollr commented on July 19, 2024

max_nodes refers to the maximum capacity of the tree, so yes, a tree with max_nodes=100 will evaluate at most 100 distinct game states.

I advise setting it higher than num_iterations because this implementation reuses subtrees from the previous search, so most of the time the tree will already be partially populated when a new search starts. Setting max_nodes higher than num_iterations means there will more often be room for num_iterations new nodes in the tree. I have yet to document the out-of-bounds behavior, but it works similarly to https://github.com/lowrollr/mctx-az.

Increasing max_nodes linearly increases the memory footprint of the search tree data structure. It's definitely a trade-off of speed vs. accuracy to set it higher/lower, and its value relative to 'num_iterations' should be problem-dependent (branching factor and # of iterations both matter). Setting max_nodes = num_iterations is fine if you're worried about speed.
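Both points can be put into rough numbers. The sketch below assumes, purely for illustration, that each search iteration adds at most one new node and that the tree keeps a per-edge float32 array of shape [batch_size, max_nodes, branching_factor]; turbozero's actual data layout may differ, so treat the figures as relative rather than absolute. Both helper functions are hypothetical.

```python
def new_node_capacity(max_nodes: int, num_iterations: int, reused_nodes: int) -> int:
    """New nodes a search can add when `reused_nodes` were carried over from the
    previous search (assumes one new node per iteration)."""
    return min(num_iterations, max(max_nodes - reused_nodes, 0))

def edge_array_bytes(batch_size: int, max_nodes: int, branching_factor: int,
                     bytes_per_elem: int = 4) -> int:
    """Size in bytes of one [batch_size, max_nodes, branching_factor] float32 array."""
    return batch_size * max_nodes * branching_factor * bytes_per_elem
```

For example, with 19x19 Go (361 points plus pass, so branching_factor = 362) and 50 games per device, edge_array_bytes(50, 32, 362) is 2,316,800 bytes while edge_array_bytes(50, 600, 362) is 43,440,000 bytes, an 18.75x increase for each such array. On the capacity side, new_node_capacity(128, 128, 40) is 88, whereas new_node_capacity(256, 128, 40) leaves room for all 128 iterations.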

lowrollr commented on July 19, 2024

> I think the main factor limiting speed is computation on the CPU; the GPU is not fully utilized.

Do you have evidence of this? I'm not aware of any CPU-bound portion of the training loop and in my experiments I've had no issues with GPU utilization.

Nightbringers commented on July 19, 2024

In my experiments, with num_simulations unchanged, max_nodes = 32 takes 460 seconds,
while max_nodes = 600 takes 2200 seconds, and GPU Pwr:Usage/Cap is lower than with max_nodes = 32. I tested this with mctx-az; turbozero shows the same behavior.
Since num_simulations is unchanged, the GPU's computational workload should also be unchanged, so is the problem on the CPU?

lowrollr commented on July 19, 2024

Some more details could be useful here.

What are you running, one call to search?
What is num_simulations set to?
What environment are you using?

I'm not sure this is entirely unexpected behavior.

Computational workload is higher when max_nodes is increased. Search operates on tensors of size [num_batches, max_nodes, ... ]

Nightbringers commented on July 19, 2024

environment: Go 19x19
model size: as in the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410

There is another strange phenomenon: as the number of steps increases, GPU utilization keeps getting lower.

lowrollr commented on July 19, 2024

Thank you for letting me know, I will look into what you are describing to see if I can replicate it and diagnose.

Are you running more than one call to search (MCTS.evaluate)? Some of the weird behavior you are describing could be down to JIT-compilation overhead on the first call.
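A generic way to separate one-off compilation cost from steady-state speed when benchmarking JAX code: time the first call of a jitted function separately, and call block_until_ready() before stopping the clock, because JAX dispatches work asynchronously. The function below is a toy stand-in for a search call, not turbozero code.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def work(x):
    # Toy stand-in for a jitted search / self-play step.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))

start = time.perf_counter()
work(x).block_until_ready()          # first call: trace + compile + execute
first_call_s = time.perf_counter() - start

n_runs = 10
start = time.perf_counter()
for _ in range(n_runs):
    work(x).block_until_ready()      # cached executable: execute only
steady_s = (time.perf_counter() - start) / n_runs

print(f"first call: {first_call_s:.4f}s, steady state: {steady_s:.6f}s per call")
```

Without block_until_ready(), the timer can stop before the device has finished, which makes the steady-state numbers look faster than they are.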

Nightbringers commented on July 19, 2024

It uses jax.lax.scan, like this: jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=410, unroll=1)

JIT compilation should only affect the first call.
environment: Go 19x19
model size: as in the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410
with max_nodes = 32, it takes 460 seconds;
with max_nodes = 600, it takes 2200 seconds,
and GPU Pwr:Usage/Cap is lower than with max_nodes = 32.
Tested with mctx-az; turbozero shows the same behavior.

lowrollr commented on July 19, 2024

Thanks for pointing this out, I will see if I can replicate.

Nightbringers commented on July 19, 2024

Are you familiar with the MuZero algorithm? I have some questions and hope you can help me.

lowrollr commented on July 19, 2024

I haven't worked with MuZero specifically as much but have read the paper. Feel free to send me an email with your questions and I'll see if I can answer.

I still plan on looking into the issues you mention but have been very busy and have not had a chance yet.
