Mctx: MCTS-in-JAX

Mctx is a library with a JAX-native implementation of Monte Carlo tree search (MCTS) algorithms such as AlphaZero, MuZero, and Gumbel MuZero. For computational speed-ups, the implementation fully supports JIT-compilation. Search algorithms in Mctx are defined for and operate on batches of inputs, in parallel. This makes the most of the available accelerators and enables the algorithms to work with large learned environment models parameterized by deep neural networks.

Installation

You can install the latest released version of Mctx from PyPI via:

pip install mctx

or you can install the latest development version from GitHub:

pip install git+https://github.com/google-deepmind/mctx.git

Motivation

Learning and search have been important topics since the early days of AI research. In the words of Rich Sutton:

One thing that should be learned [...] is the great power of general purpose methods, of methods that continue to scale with increased computation even as the available computation becomes very great. The two methods that seem to scale arbitrarily in this way are search and learning.

Recently, search algorithms have been successfully combined with learned models parameterized by deep neural networks, resulting in some of the most powerful and general reinforcement learning algorithms to date (e.g. MuZero). However, using search algorithms in combination with deep neural networks requires efficient implementations, typically written in fast compiled languages; this can come at the expense of usability and hackability, especially for researchers that are not familiar with C++. In turn, this limits adoption and further research on this critical topic.

Through this library, we hope to help researchers everywhere to contribute to such an exciting area of research. We provide JAX-native implementations of core search algorithms such as MCTS that we believe strike a good balance between performance and usability for researchers who want to investigate search-based algorithms in Python. The search methods provided by Mctx are heavily configurable, allowing researchers to explore a variety of ideas in this space and contribute to the next generation of search-based agents.

Search in Reinforcement Learning

In reinforcement learning, the agent must learn to interact with the environment in order to maximize a scalar reward signal. On each step, the agent selects an action and receives an observation and a reward in exchange. We may call whatever mechanism the agent uses to select the action the agent's policy.

Classically, policies are parameterized directly by a function approximator (as in REINFORCE), or policies are inferred by inspecting a set of learned estimates of the value of each action (as in Q-learning). Alternatively, search allows actions to be selected by constructing, on the fly in each state, a policy or value function local to the current state, using a learned model of the environment.

Exhaustive search over all possible future courses of action is computationally prohibitive in any non-trivial environment, hence we need search algorithms that make the best use of a finite computational budget. Typically, priors are needed to guide which nodes in the search tree to expand (to reduce the breadth of the tree that we construct), and value functions are used to estimate the value of incomplete paths in the tree that don't reach an episode termination (to reduce the depth of the search tree).

Quickstart

Mctx provides a low-level generic search function and high-level concrete policies: muzero_policy and gumbel_muzero_policy.

The user needs to provide several learned components to specify the representation, dynamics and prediction used by MuZero. In the context of the Mctx library, the representation of the root state is specified by a RootFnOutput. The RootFnOutput contains the prior_logits from a policy network, the estimated value of the root state, and any embedding suitable to represent the root state for the environment model.
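
For concreteness, a minimal sketch of constructing a RootFnOutput from dummy values (the batch size, action count, and embedding shape here are illustrative, not requirements of Mctx):

import jax.numpy as jnp
import mctx

batch_size = 1
num_actions = 4

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros([batch_size, num_actions]),  # from a policy network
    value=jnp.zeros([batch_size]),                       # from a value network
    embedding=jnp.zeros([batch_size, 8]),                # any pytree describing the root state
)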

The dynamics environment model needs to be specified by a recurrent_fn. A recurrent_fn(params, rng_key, action, embedding) call takes an action and a state embedding. The call should return a tuple (recurrent_fn_output, new_embedding) with a RecurrentFnOutput and the embedding of the next state. The RecurrentFnOutput contains the reward and discount for the transition, and prior_logits and value for the new state.
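
A minimal matching recurrent_fn sketch follows; this dummy model returns zeros and reuses the embedding, where a real implementation would call learned dynamics and prediction networks:

def recurrent_fn(params, rng_key, action, embedding):
    batch_size = action.shape[0]
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=jnp.zeros([batch_size]),
        discount=jnp.ones([batch_size]),
        prior_logits=jnp.zeros([batch_size, num_actions]),
        value=jnp.zeros([batch_size]),
    )
    return recurrent_fn_output, embedding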

In examples/visualization_demo.py, you can see calls to a policy:

policy_output = mctx.gumbel_muzero_policy(params, rng_key, root, recurrent_fn,
                                          num_simulations=32)

policy_output.action contains the action proposed by the search; that action can be passed to the environment. To improve the policy, policy_output.action_weights contains targets that can be used to train the policy probabilities.
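
For example, a hedged sketch of a policy loss against these targets (optax and the policy_logits variable are assumptions, not part of Mctx):

import optax

def policy_loss(policy_logits, policy_output):
    # Cross-entropy between the network's policy logits [B, num_actions]
    # and the search's action_weights [B, num_actions].
    return optax.softmax_cross_entropy(
        logits=policy_logits, labels=policy_output.action_weights).mean()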

We recommend using gumbel_muzero_policy. Gumbel MuZero guarantees a policy improvement if the action values are correctly evaluated. The policy improvement is demonstrated in examples/policy_improvement_demo.py.

Example projects

The following projects demonstrate Mctx usage:

  • Pgx — A collection of 20+ vectorized JAX environments, including backgammon, chess, shogi, Go, and an AlphaZero example.
  • Basic Learning Demo with Mctx — AlphaZero on random mazes.
  • a0-jax — AlphaZero on Connect Four, Gomoku, and Go.
  • muax — MuZero on gym-style environments (CartPole, LunarLander).
  • Classic MCTS — A simple example on Connect Four.
  • mctx-az — Mctx with AlphaZero subtree persistence.

Tell us about your project.

Citing Mctx

This repository is part of the DeepMind JAX Ecosystem. To cite Mctx, please use this citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}

mctx's Issues

Basic MCTS example

Hi, I really appreciate your library.
I will be using it for my thesis project and need to understand how it works.

For this purpose, I used it to implement classic MCTS with random rollouts in a jupyter notebook.
https://github.com/Carbon225/mctx-classic
I want it to be as informative as possible and to explain why every line is the way it is so my teammates can understand it as well. Feel free to add/ignore this example in your readme.

Below I will describe the last aspect I don't feel like I understand.

Consider this definition of the recurrent_fn:

def recurrent_fn(params, rng_key, action, embedding):
    env = embedding
    env, reward, done = env_step(env, action)
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.where(done, 0, -1).astype(jnp.float32),
        prior_logits=my_policy_function(env),
        value=jnp.where(done, 0, my_value_function(env, rng_key)).astype(jnp.float32),
    )
    return recurrent_fn_output, env

I have read that the terminal node is considered absorbing. From my understanding this means that below this node all rewards and values should be 0. This should be guaranteed by setting discount to 0 in the RecurrentFnOutput of the terminal node. In other examples of mctx I have seen people setting the value to 0 at the terminal node as well.
Which is correct? When should the reward/value/discount be set to 0?

I also believe the reward field is never actually used by the search. It's only used for training outside the mctx library, correct?
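
For reference, this is the standard one-step backup I am assuming (a sketch, not necessarily mctx's internals); with discount == 0 at the terminal transition, whatever value is reported below the terminal node is multiplied away:

import jax.numpy as jnp

reward = jnp.float32(1.0)
discount = jnp.float32(0.0)       # terminal transition
child_value = jnp.float32(123.0)  # value reported below the terminal node
backed_up_value = reward + discount * child_value  # == 1.0, child_value is ignored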

git protocol error for pip, following readme

I receive an error when using pip install git+git://github.com/deepmind/mctx.git from the readme: the unauthenticated git protocol is no longer supported. Please see https://github.blog/2021-09-01-improving-git-protocol-security-github/ for more information.

pip install git+git://github.com/deepmind/mctx.git

Collecting git+git://github.com/deepmind/mctx.git
  Cloning git://github.com/deepmind/mctx.git to /private/var/folders/s1/njpm22952xg9f89hvj4rd60c0000gn/T/pip-req-build-6009_u2_
  Running command git clone -q git://github.com/deepmind/mctx.git /private/var/folders/s1/njpm22952xg9f89hvj4rd60c0000gn/T/pip-req-build-6009_u2_
  fatal: remote error:

  **The unauthenticated git protocol on port 9418 is no longer supported.**
  Please see https://github.blog/2021-09-01-improving-git-protocol-security-github/ for more information.

WARNING: Discarding git+git://github.com/deepmind/mctx.git. Command errored out with exit status 128: git clone -q git://github.com/deepmind/mctx.git /private/var/folders/s1/njpm22952xg9f89hvj4rd60c0000gn/T/pip-req-build-6009_u2_ Check the logs for full command output.
ERROR: Command errored out with exit status 128: git clone -q git://github.com/deepmind/mctx.git /private/var/folders/s1/njpm22952xg9f89hvj4rd60c0000gn/T/pip-req-build-6009_u2_ Check the logs for full command output.

An end-to-end training example project on gym environment

Thanks for open sourcing the great library!
I believe there are people interested in MuZero and its capacity on Atari games who want to try it on gym environments. Also, instead of using env.step() as the dynamics inside recurrent_fn, some people may be interested in using a neural network to learn the dynamics.
I am one of them, and I have written some code to support using the mctx library. I also shared an example of end-to-end training on the gym CartPole env. Please check my project muax and the CartPole example.

Question about `simulate` function in `search.py`

Thanks for building this wonderful tool! I am currently going through the source code of search.py to try and understand how this is implemented in practice. I found myself wondering about the usage of the simulate() function.

As far as I understood, MCTS has four phases: selection, expansion, simulation and backpropagation. In the main search function, there is no call to a select() type function and the first function that is actually called is simulate(). When going into the code and comments for simulate(), however, I get the feeling that this is actually implementing the select phase as it is traversing the search tree until it encounters a node that it did not yet visit after which it is expanded by the expand() function.

My question is then: am I correct in understanding that the simulate() function is more akin to the selection phase? Moreover, given that this would imply that there is no real simulation phase, am I also correct in understanding that there is no simulation but rather a network is queried to obtain the value of a node without doing rollouts? I think this is what is happening in the code below.

step, embedding = recurrent_fn(params, rng_key, action, embedding)

Any help to clarify this issue would be much appreciated!

Can we use a neural network defined with TensorFlow as the recurrent_fn in MuZero?

We are using MuZero, with the network and loss defined in TensorFlow and the MCTS written in plain Python.
But the MCTS is very slow here, so we want to use mctx as a black box to replace our MCTS while keeping the neural network in TensorFlow.
We are not familiar with JAX; the issue is whether we can just use a TensorFlow module as the recurrent_fn parameter of muzero_policy.

Question about muzero network in go game

In the paper, a bigger 32-layer network is used in the large-scale 19x19 Go experiments. The networks used 256 hidden planes, 128 bottleneck planes, and a broadcasting block in every 8th layer.

MuZero has representation, dynamics, and prediction networks. Do representation, dynamics, and prediction all use this 32-layer network? Should representation, dynamics, and prediction use networks of the same size or of different sizes?

The paper says MuZero learns faster than AlphaZero in Go. But if they have networks of the same size, then MuZero is 2 times slower than AlphaZero, because every move MuZero needs two network calls and AlphaZero only needs one (if the representation, dynamics, and prediction networks are the same size as AlphaZero's network). So shouldn't AlphaZero learn faster than MuZero?

If AlphaZero uses a bigger network, so that each of its inference calls takes as long as MuZero's two network calls, can MuZero outperform AlphaZero in that case?

MuZero training for Go uses K = 5 unroll steps. I found that the larger the value of K, the more GPU memory it consumes. Is there a big difference between training with K=3 and training with K=5?

Understanding RootFnOutput

Hello!

I'm trying to use the mctx library to train a MuZero agent and am having some trouble understanding RootFnOutput.

I have two methods .root and .recurrent:

def root(
    self, env_state: Float[Array, "b w1 h1 c1"], train: bool = True
) -> mctx.RootFnOutput:
    initial_state = self.representation(env_state, train)
    policy_logits, value = self.prediction(initial_state, train)
    return mctx.RootFnOutput(
        prior_logits=policy_logits,
        value=value,
        embedding=initial_state,
    )

def recurrent(
    self,
    rng_key,
    action: Int[Array, "b"],
    state: Float[Array, "b w h c"],
    train: bool = True,
) -> Tuple[mctx.RecurrentFnOutput, Float[Array, "b w h c"]]:
    policy_logits, value = self.prediction(state, train)
    new_state, reward = self.dynamics(action, state, train)

    return (
        mctx.RecurrentFnOutput(
            reward=reward,
            discount=jnp.full(reward.shape, self.config.discount),
            prior_logits=policy_logits,
            value=value,
        ),
        new_state,
    )

which I want to use to make a call to .muzero_policy:

initial_state = jax.random.normal(provider(), (1, 84, 84, 6))
root = model.apply(params, initial_state, train=False, method=model.root)

res = mctx.muzero_policy(
    params,
    provider(),
    root=root,
    recurrent_fn=partial(model.apply, method=model.recurrent, train=False),
    num_simulations=4,
)

print(res.search_tree.raw_values)
print(res.search_tree.node_values)

"""
[[0.80960083 0.80960083 2.0920973  2.2743711  2.7579222 ]]
[[3.5309532 3.0515175 3.0072536 2.7471337 2.7579222]]
"""

As a result, for the raw values I obtain an array whose first two elements are equal.
Could you please help me understand whether this is intended behaviour, or whether I'm missing something? It feels weird that mctx.RootFnOutput should contain value and prior_logits fields, while the documentation says that it contains the output of a representation network:

@chex.dataclass(frozen=True)
class RootFnOutput:
  """The output of a representation network.

  prior_logits: `[B, num_actions]` the logits produced by a policy network.
  value: `[B]` an approximate value of the current state.
  embedding: `[B, ...]` the inputs to the next `recurrent_fn` call.
  """
  prior_logits: chex.Array
  value: chex.Array
  embedding: RecurrentState

I would like to ask a question and hope you can help me answer it.

In the MuZero paper, there is a transform applied to the value and reward. This is how the paper describes it:
For value and reward prediction in Atari we follow [30] in scaling targets using an invertible transform h(x) = sign(x)(√(|x| + 1) − 1 + εx), where ε = 0.001 in all our experiments. We then apply a transformation φ to the scalar reward and value targets in order to obtain equivalent categorical representations. We use a discrete support set of size 601 with one support for every integer between −300 and 300. Under this transformation, each scalar is represented as the linear combination of its two adjacent supports, such that the original value can be recovered by x = x_low · p_low + x_high · p_high. As an example, a target of 3.7 would be represented as a weight of 0.3 on the support for 3 and a weight of 0.7 on the support for 4. The value and reward outputs of the network are also modeled using a softmax output of size 601. During inference the actual value and rewards are obtained by first computing their expected value under their respective softmax distribution and subsequently by inverting the scaling transformation. Scaling and transformation of the value and reward happens transparently on the network side and is not visible to the rest of the algorithm.
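
For reference, a hedged sketch of the scalar transform and two-hot projection described in this passage (my own reading; names and shapes are illustrative):

import jax.numpy as jnp

def h(x, eps=0.001):
    # Invertible scaling: sign(x) * (sqrt(|x| + 1) - 1 + eps * x).
    return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1.0) - 1.0 + eps * x)

def two_hot(x, support_min=-300, support_max=300):
    # Represent a scalar as weights on its two adjacent integer supports,
    # e.g. 3.7 -> 0.3 on the support for 3 and 0.7 on the support for 4.
    x = jnp.clip(x, support_min, support_max)
    low = jnp.floor(x).astype(jnp.int32)
    p_high = x - low
    size = support_max - support_min + 1
    probs = jnp.zeros(size).at[low - support_min].add(1.0 - p_high)
    probs = probs.at[jnp.clip(low + 1 - support_min, 0, size - 1)].add(p_high)
    return probs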

Question 1: that is described for Atari. I want to know whether, in the game of Go, this transform is still used, or whether it is just like AlphaZero, where the value is a single number between -1 and 1?

Here is another passage from the paper:
Note that, in board games without intermediate rewards, we omit the reward prediction loss. For board games, we bootstrap directly to the end of the game, equivalent to predicting the final outcome

Question 2: does this mean the reward is useless in the Go game? Could we just set it to zero all the time?

PMAP w/ mctx policy

How would one go about pmap'ing the mctx policy fn of choice when the root (an argument to the policy) contains an embedding (an argument to the root constructor) that is a dataclass? I would like to map across the 0-axis of all attributes of the dataclass, but when specifying 0 in the in_axes argument of pmap (e.g. pmap(mctx.gumbel_muzero_policy, in_axes=(0, ...))(root, ...)) it throws

"ValueError: pmap in_axes specification must be a tree prefix of the corresponding value"

My guess is that specifying an axis to map over in a dataclass where each attribute contains arrays (which are the ultimate targets to be mapped over) is not supported. Any suggested workarounds?

root = mctx.RootFnOutput(
    prior_logits=policy_logits,
    value=value_scalar,
    embedding=env_state) # env_state is a dataclass 

recurrent_fn = get_recurrent_fn(env, env_state, env_params, net_func)

key, subkey = jx.random.split(key)

policy_output_maml = pmap(mctx.gumbel_muzero_policy, in_axes=(None, None, None, 0, None, None, None))(
    params=net_params,
    max_depth=config.max_search_depth,
    rng_key=subkey,
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=config.num_simulations,
    qtransform=partial(
        mctx.qtransform_completed_by_mix_value,
        use_mixed_value=config.use_mixed_value,
        value_scale=config.value_scale,
        rescale_values=True))

Thanks in advance.

Request to add link to Pgx repository in README

Hello,

I hope this message finds you well. I am reaching out to kindly request the addition of a link to my project, Pgx, in the README file of the Mctx repository.

Pgx is a collection of vectorized board game environments written in JAX, featuring over 20 games including backgammon, chess, shogi, and Go. The repository also includes an AlphaZero example using Mctx, showcasing reasonable learning in Go 9x9 and other medium-sized games.

I believe this link would be a valuable addition for Mctx users, providing them with easy access to a collection of JAX-native game environments as a testbed.

Thank you for considering my request. I look forward to your positive response.

Best,

Question about the improved policy in Gumbel MuZero

Hi, thanks for open sourcing the great library!

I'm using it to experiment with MCTS on a project, and I have a question regarding the function \sigma used in constructing the improved policy: \pi'=softmax(logits+\sigma completedQ).

I noticed that the scale of the function determines the weight between logits and completedQ, which in turn affects the child selection in MCTS. In my experiments with tictactoe, I discovered that a larger weight of completedQ may lead to a higher reward in MCTS.

The paper only states that \sigma is a monotonically increasing function, but I was wondering if there are any other limitations or discussions on the format of \sigma?
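
For context, my reading of the particular $\sigma$ implemented by Mctx's default qtransform_completed_by_mix_value, with maxvisit_init and value_scale as its scaling constants (a sketch of the idea, not an official statement):

$\sigma(\hat{q}(a)) = (\text{maxvisit\_init} + \max_b N(b)) \cdot \text{value\_scale} \cdot \hat{q}(a)$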

Thank you!

A possible memory leak when using non-jitted mode in policy_improvement_demo.py

I am using gumbel_muzero_policy without jit and found a memory leak, so I tried using tracemalloc to trace the memory in policy_improvement_demo.py.
When I remove jax.jit in the main function and print the memory allocation, it keeps increasing.
[screenshot: memory allocation without jit, growing over time]
When I make it jit:
[screenshot: memory allocation with jit]

I don't know whether this is a feature of JAX, or how I can clean up memory myself in non-jitted mode.
Here is the code change:

import tracemalloc

def main(_):
    rng_key = jax.random.PRNGKey(FLAGS.seed)
    # The only difference is the next two lines: jit or not.
    run_demo = jax.jit(_run_demo)
    # run_demo = _run_demo
    tracemalloc.start(10)
    snapshot1 = tracemalloc.take_snapshot()
    for _ in range(FLAGS.num_runs):
        rng_key, output = run_demo(rng_key)
        snapshot2 = tracemalloc.take_snapshot()
        top_stats = snapshot2.compare_to(snapshot1, 'lineno')
        total = sum(stat.size for stat in top_stats)
        print("Total allocated size: %.1f MiB" % (total / 1024 / 1024))

if __name__ == "__main__":
    app.run(main)

Question regarding `qtransform_by_parent_and_siblings` in `muzero_policy`

Hello,

I have a question about the qtransform_by_parent_and_siblings function used in muzero_policy as default QTranform. It appears to implement normalization differently from the method described in the original MuZero paper and its pseudocode. While the original MuZero normalizes using the min/max values of the entire tree, qtransform_by_parent_and_siblings seems to use the min/max values of sibling nodes for normalization. Could you please clarify the origin of this function? Is there any specific reference or source for this approach?

Description in the original MuZero paper: [screenshot of the min-max normalization description]
Pseudo-code in the original MuZero paper:
class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Optional[KnownBounds]):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value
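
By contrast, a hedged sketch of the per-parent idea (my paraphrase of qtransform_by_parent_and_siblings, not the library's exact code):

import jax.numpy as jnp

def normalize_by_parent_and_siblings(parent_value, child_qvalues, child_visits):
    # Use only the parent's value and the Q-values of visited children
    # to compute the normalization bounds, instead of the whole tree.
    safe_q = jnp.where(child_visits > 0, child_qvalues, parent_value)
    lo = jnp.minimum(parent_value, safe_q.min())
    hi = jnp.maximum(parent_value, safe_q.max())
    return (child_qvalues - lo) / jnp.maximum(hi - lo, 1e-8)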

As an additional note, in my preliminary experiments with AlphaZero-style training in 9x9 Go, I observed that qtransform_by_parent_and_siblings seems to perform better than the original tree-wide normalization (and qtransform_by_min_max).

Thank you for your time and assistance.

Two player games support?

Will there be two player games support from mctx?

If I understand it correctly, it could technically be possible to use the current mctx for two players. We may pass a flag indicating the current player in the root inference and flip the values based on the flag in the recurrent inference. Am I right?

Stochastic MuZero issues invalid actions and outcomes

Example:

import jax
import mctx
from jax import numpy as jnp, random


def main():
    n_actions = 7
    n_outcomes = 3
    batch_size = 1

    root = mctx.RootFnOutput(  # type: ignore
        prior_logits=jnp.zeros([batch_size, n_actions]),
        value=jnp.zeros(batch_size),
        embedding=jnp.zeros(batch_size),
    )

    def decision_recurrent_fn(params, key, action, state):
        jax.debug.print("action: {}", action)
        afterstate = state
        output = mctx.DecisionRecurrentFnOutput(  # type: ignore
            chance_logits=jnp.zeros([batch_size, n_outcomes]),
            afterstate_value=jnp.zeros(batch_size),
        )
        return output, afterstate

    def chance_recurrent_fn(params, key, outcome, afterstate):
        jax.debug.print("outcome: {}", outcome)
        state = afterstate
        output = mctx.ChanceRecurrentFnOutput(  # type: ignore
            action_logits=jnp.zeros([batch_size, n_actions]),
            value=jnp.zeros(batch_size),
            reward=jnp.zeros(batch_size),
            discount=jnp.ones(batch_size),
        )
        return output, state

    mctx.stochastic_muzero_policy(
        params={},
        rng_key=random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_recurrent_fn,
        chance_recurrent_fn=chance_recurrent_fn,
        num_simulations=20,
    )


if __name__ == "__main__":
    main()

Output:

action: [0]
action: [1]
outcome: [-6]
action: [7]
outcome: [0]
action: [4]
outcome: [-3]
action: [8]
outcome: [1]
action: [5]
outcome: [-2]
action: [6]
outcome: [-1]
action: [0]
outcome: [-7]
action: [3]
outcome: [-4]
action: [2]
outcome: [-5]
action: [9]
outcome: [2]
action: [5]
outcome: [-2]
action: [7]
outcome: [0]
action: [5]
outcome: [-2]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [2]
outcome: [-5]
action: [2]
outcome: [-5]

The actions range from 0 to 9 (7+3=10 in total), even though there are only 7 actions.
The outcomes range from -7 (a negative integer!) to 2 (7+3=10 in total), even though there are only 3 outcomes.

This may have something to do with the math inside stochastic_recurrent_fn.

Version information:

$ python3 --version
Python 3.11.6
$ python3 -c "import mctx; print(mctx.__version__)"
0.0.5
$ python3 -c "import jax; print(jax.__version__)"
0.4.23
$ python3 -c "import jaxlib; print(jaxlib.__version__)"
0.4.23

Automatically determine num_actions and num_chance_outcomes in stochastic_muzero_policy

It is possible to automatically determine the num_actions and num_chance_outcomes parameters to stochastic_muzero_policy from the root, decision_recurrent_fn, and chance_recurrent_fn parameters, via ChanceRecurrentFnOutput.action_logits.shape and DecisionRecurrentFnOutput.chance_logits.shape. That would reduce the number of parameters that users need to pass in by two.

I suggest allowing num_actions and num_chance_outcomes to be None, making them None by default, and determining their values automatically if they're None. This preserves backward compatibility. Thoughts?
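
A hypothetical sketch of the proposed defaulting logic (the helper itself is only a suggestion; the field names follow the current dataclasses):

import jax.numpy as jnp

def infer_sizes(root, decision_recurrent_fn, params, rng_key,
                num_actions=None, num_chance_outcomes=None):
    # Fill in missing sizes from the shapes of the provided outputs.
    if num_actions is None:
        num_actions = root.prior_logits.shape[-1]
    if num_chance_outcomes is None:
        batch_size = root.value.shape[0]
        dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
        decision_output, _ = decision_recurrent_fn(
            params, rng_key, dummy_action, root.embedding)
        num_chance_outcomes = decision_output.chance_logits.shape[-1]
    return num_actions, num_chance_outcomes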

Few questions about training and num_simulations

In the AlphaZero paper, the Elo for Go exceeds 5000, but in the Gumbel paper the Elo for Go is below 3000. Why?

If an agent is trained with num_simulations==800 and then I continue training with num_simulations==400, what will happen? Will it keep getting stronger, or will it get weaker?

Is evaluation with different num_simulations equivalent? In my training, the speed of evaluation affects the speed of training, so I want to know: if agent1 > agent2 at 100 num_simulations, does that mean agent1 > agent2 at 200 num_simulations? And at 300? And at 400?

How many num_simulations do you recommend for training Go? Is there a big difference between 400, 800, and 1600 searches, or will the agents end up almost equally strong eventually? Or is 1600 searches > 800 searches > 400 searches?

Root replacement with MCTX library

In the AlphaZero paper, for example, the authors state that, after executing an action decided upon through MCTS, they make the next state's node the root of the search tree and continue their search on this subtree.

As I see it, calling the search method inside mctx/search.py always creates a fresh tree, i.e., it calls instantiate_tree_from_root.

I'm pretty surprised by this because, to me, it seems like such a root replacement strategy should have a detrimental impact on the sample efficiency of MCTS. Please correct me if I'm wrong.

Issues with Stochastic MuZero

I'm having issues with mctx.stochastic_muzero_policy. Here's an example:

import jax
import mctx
from jax import numpy as jnp

num_actions = 4
num_chance_outcomes = 2


def decision_recurrent_fn(params, key, action, state):
    return (
        mctx.DecisionRecurrentFnOutput(
            chance_logits=jnp.full(num_chance_outcomes, 0.0),
            afterstate_value=jnp.array(0.0),
        ),
        state,
    )


def chance_recurrent_fn(params, key, action, afterstate):
    return (
        mctx.ChanceRecurrentFnOutput(
            action_logits=jnp.full(num_actions, 0.0),
            value=jnp.array(0.0),
            # reward=jnp.array(1.),
            reward=1 + (action == 0) * 100,
            discount=jnp.array(0.0),
        ),
        afterstate,
    )


def root_fn(state):
    return mctx.RootFnOutput(
        prior_logits=jnp.full(num_actions, 0.0),
        value=jnp.array(0.0),
        embedding=state,
    )


def main():
    root = root_fn(jnp.full(4, 0.0))
    root = jax.tree_map(lambda x: x[None], root)

    key = jax.random.PRNGKey(0)

    output = mctx.stochastic_muzero_policy(
        params=jnp.full(20, 0.0),
        rng_key=key,
        root=root,
        decision_recurrent_fn=jax.vmap(decision_recurrent_fn, [None, None, 0, 0]),
        chance_recurrent_fn=jax.vmap(chance_recurrent_fn, [None, None, 0, 0]),
        num_simulations=1000,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
    )
    assert (output.search_tree.children_rewards == 0).all()
    print(output.action_weights)  # [[0.007 0.451 0.063 0.479]]


if __name__ == "__main__":
    main()

The first issue is that the children_rewards are all 0, despite the fact that chance_recurrent_fn always yields a positive reward.

The second issue is that the final weight of the zeroth action (which receives an additional reward of 100) is not higher than the rest, despite a large number of simulations.

Any idea what might be causing these issues?

Computing both decision and chance branches of recurrent function in Stochastic MuZero is slow

Hello all! Thank you for creating mctx for the community. I added a performance enhancement to my copy of mctx and am wondering if you're interested in adding it to the official repo.

Computing both the decision and chance branches in Stochastic MuZero on every expand is useful if the embeddings have different shapes/dtypes, but it is otherwise slow and adds a good bit of overhead (especially for a high number of simulations or large networks). I modified the recurrent fn so it only has to compute one branch per expand if the decision and chance embeddings have the same struct/shape/dtype.

Overall these are my changes:

In base.py another stochastic recurrent state was added that only holds one embedding:

@chex.dataclass(frozen=True)
class StochasticRecurrentStateEfficient:
    embedding: chex.ArrayTree  # [B, ...]
    is_decision_node: chex.Array  # [B]

The modified version of _make_stochastic_recurrent_fn unrolls the batch using scan so lax.cond can be used to only compute one branch:

def _make_stochastic_recurrent_fn_efficient(
    decision_node_fn: base.DecisionRecurrentFn,
    chance_node_fn: base.ChanceRecurrentFn,
    num_actions: int,
    num_chance_outcomes: int,
) -> base.RecurrentFn:
    """Make Stochastic Recurrent Fn."""

    def stochastic_recurrent_fn(
        params: base.Params,
        rng: chex.PRNGKey,
        action_or_chance: base.Action,  # [B]
        state: base.StochasticRecurrentStateEfficient,
    ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentStateEfficient]:
        def decision_node_branch(action_or_chance, state):
            # Internally we assume that there are `A' = A + C` "actions";
            # action_or_chance can take on values in `{0, 1, ..., A' - 1}`,.
            # To interpret it as an action we can leave it as is:
            action = action_or_chance - 0

            # temporary batch dimension
            action = jnp.expand_dims(action, axis=0)
            embedding = jnp.expand_dims(state.embedding, axis=0)

            decision_output, afterstate_embedding = decision_node_fn(
                params, rng, action, embedding
            )

            decision_output = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), decision_output
            )
            afterstate_embedding = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), afterstate_embedding
            )

            new_state = base.StochasticRecurrentStateEfficient(
                embedding=afterstate_embedding,
                is_decision_node=jnp.logical_not(state.is_decision_node),
            )
            # Outputs from DecisionRecurrentFunction produce chance logits with
            # dim `C`, to respect our internal convention that there are `A' = A + C`
            # "actions" we pad with `A` dummy logits which are ultimately ignored:
            # see `_mask_tree`.
            return (
                base.RecurrentFnOutput(
                    prior_logits=jnp.concatenate(
                        [
                            jnp.full([num_actions], fill_value=-jnp.inf),
                            decision_output.chance_logits,
                        ],
                        axis=-1,
                    ),
                    value=decision_output.afterstate_value,
                    reward=jnp.zeros_like(decision_output.afterstate_value),
                    discount=jnp.ones_like(decision_output.afterstate_value),
                ),
                new_state,
            )

        def chance_node_branch(action_or_chance, state):
            # To interpret it as a chance outcome we subtract num_actions:
            chance_outcome = action_or_chance - num_actions

            # temporary batch dimension
            chance_outcome = jnp.expand_dims(chance_outcome, axis=0)
            embedding = jnp.expand_dims(state.embedding, axis=0)

            chance_output, state_embedding = chance_node_fn(
                params, rng, chance_outcome, embedding
            )

            chance_output = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), chance_output
            )
            state_embedding = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), state_embedding
            )

            new_state = base.StochasticRecurrentStateEfficient(
                embedding=state_embedding,
                is_decision_node=jnp.logical_not(state.is_decision_node),
            )
            # Outputs from ChanceRecurrentFunction produce action logits with dim `A`,
            # to respect our internal convention that there are `A' = A + C` "actions"
            # we pad with `C` dummy logits which are ultimately ignored: see
            # `_mask_tree`.
            return (
                base.RecurrentFnOutput(
                    prior_logits=jnp.concatenate(
                        [
                            chance_output.action_logits,
                            jnp.full([num_chance_outcomes], fill_value=-jnp.inf),
                        ],
                        axis=-1,
                    ),
                    value=chance_output.value,
                    reward=chance_output.reward,
                    discount=chance_output.discount,
                ),
                new_state,
            )

        def scan_body(_, xs):
            action_or_chance, state = xs
            output, state = jax.lax.cond(
                state.is_decision_node,
                decision_node_branch,
                chance_node_branch,
                action_or_chance,
                state,
            )
            return None, (output, state)

        _, (output, new_state) = jax.lax.scan(
            scan_body, None, (action_or_chance, state)
        )

        return output, new_state

    return stochastic_recurrent_fn

Finally in the policy function there's a chex check to see if we can use the efficient version:

    try:
        chex.assert_trees_all_equal_structs(root.embedding, dummy_afterstate_embedding)
        chex.assert_trees_all_equal_shapes_and_dtypes(
            root.embedding, dummy_afterstate_embedding
        )
    except AssertionError:
        embeddings_same_shape_dtype = False
    else:
        embeddings_same_shape_dtype = True

    if embeddings_same_shape_dtype:
        # 74.05 sec for old version vs 48.37 sec for efficient version
        embedding = base.StochasticRecurrentStateEfficient(
            embedding=root.embedding,
            is_decision_node=jnp.ones([batch_size], dtype=bool),
        )
        make_stochastic_recurrent_fn = _make_stochastic_recurrent_fn_efficient
    else:
        embedding = base.StochasticRecurrentState(
            state_embedding=root.embedding,
            afterstate_embedding=dummy_afterstate_embedding,
            is_decision_node=jnp.ones([batch_size], dtype=bool),
        )
        make_stochastic_recurrent_fn = _make_stochastic_recurrent_fn

    root = root.replace(
        # pad action logits with num_chance_outcomes so dim is A + C
        prior_logits=jnp.concatenate(
            [
                root.prior_logits,
                jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf),
            ],
            axis=-1,
        ),
        # replace embedding with wrapper.
        embedding=embedding,
    )

    # Stochastic MuZero Change: We need to be able to tell if different nodes are
    # decision or chance. This is accomplished by imposing a special structure
    # on the embeddings stored in each node. Each embedding is an instance of
    # StochasticRecurrentState which maintains this information.
    recurrent_fn = make_stochastic_recurrent_fn(
        decision_node_fn=decision_recurrent_fn,
        chance_node_fn=chance_recurrent_fn,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
    )

When the efficient version is used it saves a bit of memory from not having to hold both embeddings, and is faster with a wider performance gap the higher the simulation count or larger the networks. Maybe there is a better option than scan for the unroll, I just know vmap can't be used because cond is converted back to select.

Let me know if you're interested in adding this to mctx and I can make a testing colab to compare performance and would be happy to make a pull request :)

Question: what value target should be used in MuZero with Gumbel policy?

Hello and thank you for this nice library!

I have a question regarding using Gumbel policy in Muzero algorithm. What target for the value prediction should I use in this case? If I use the value backpropped to the root as in the original MuZero paper, this value will not correspond to the improved policy. The paper wasn't clear on that.

Thanks in advance!

sampling action from tree-search policy

Thanks for releasing this library!!

In the Gumbel MuZero paper (appendix F), I read:
[screenshot of the relevant passage from Appendix F]

Does 'benefits from the same exploration' mean 'benefits from exploration with the same sampling proportionally to visit counts (also with a temperature, I assume?)' or 'also benefits from exploration, as implemented with Gumbel noise at the root node'?

In other words, with policy_output = gumbel_muzero_policy(...), policy_output.action already contains a selected action, 'sampled' here with the same Gumbel noise that was used for building the tree. Is this selected action supposed to be used for building training trajectories, or should I rather sample actions proportionally to policy_output.search_tree.summary().visit_counts, as Appendix F and Figure 8 seem to suggest?

best practices for speed

Hey, thanks for this amazing library.

Question: are there any best practices for ensuring that MCTX is fast? I find it quite slow and was curious if there's common gotchas I might be missing.
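
For what it's worth, the biggest single win is usually JIT-compiling the entire policy call, as the README suggests. A sketch, assuming a recurrent_fn and root built as in the Quickstart:

import functools
import jax
import mctx

# Compile the whole search once and reuse it; num_simulations must be static.
@functools.partial(jax.jit, static_argnames="num_simulations")
def run_search(params, rng_key, root, num_simulations):
    return mctx.gumbel_muzero_policy(
        params, rng_key, root,
        recurrent_fn=recurrent_fn,  # assumed to be defined as in the Quickstart
        num_simulations=num_simulations)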

Irregular action and chance outcome outputs within search with Stochastic MuZero

Hello! Thank you for recently adding an implementation of stochastic muzero. I was testing it, but it seems chance outcome and action outputs within the search are irregular. I made this test and print the action and chance outcome from within the dynamics functions:

import numpy as np

import jax
import jax.numpy as jnp
import mctx
from mctx import DecisionRecurrentFnOutput, ChanceRecurrentFnOutput, RecurrentFnOutput


num_actions = 3
num_chance_outcomes = 5


def afterstate_pred(afterstate_embedding):
    chance_logits = jnp.zeros([1, num_chance_outcomes]).at[0, 0].set(1.0)
    afterstate_value = jnp.zeros([1])
    return chance_logits, afterstate_value


def pred(embedding):
    policy_logits = jnp.zeros([1, num_actions]).at[0, 0].set(1.0)
    value = jnp.zeros([1])
    return policy_logits, value


def afterstate_dynamics(action, embedding):
    print(f"action: {action}")
    return embedding


def dynamics(chance_outcome, afterstate_embedding):
    print(f"chance_outcome: {chance_outcome}")
    return afterstate_embedding, jnp.zeros([1])


def decision_recurrent_fn(params, rng_key, action, embedding):
    afterstate_embedding = afterstate_dynamics(action, embedding)
    chance_logits, afterstate_value = afterstate_pred(afterstate_embedding)
    decision_recurrent_fn_output = DecisionRecurrentFnOutput(
        chance_logits=chance_logits,
        afterstate_value=afterstate_value,
    )
    return decision_recurrent_fn_output, afterstate_embedding


def chance_recurrent_fn(params, rng_key, chance_outcome, embedding):
    embedding, reward = dynamics(chance_outcome, embedding)
    policy_logits, value = pred(embedding)
    recurrent_fn_output = ChanceRecurrentFnOutput(
        action_logits=policy_logits,
        value=value,
        reward=reward,
        discount=jnp.full_like(reward, 0.99),
    )
    return recurrent_fn_output, embedding


def recurrent_fn(params, rng_key, action, embedding):
    embedding, reward = dynamics(action, embedding)
    policy_logits, value = pred(embedding)
    recurrent_fn_output = RecurrentFnOutput(
        reward=reward,
        discount=jnp.full_like(reward, 0.99),
        prior_logits=policy_logits,
        value=value,
    )
    return recurrent_fn_output, embedding


def stochastic_muzero_policy():
    """Tests that SMZ is equivalent to MZ with a dummy chance function."""
    root = mctx.RootFnOutput(
        prior_logits=jnp.array(
            [
                [-1.0, 0.0, 2.0],
            ]
        ),
        value=jnp.array([0.0]),
        embedding=jnp.zeros([1, 4]),
    )

    num_simulations = 10

    """policy_output = mctx.muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations,
        dirichlet_fraction=0.0,
    )"""

    stochastic_policy_output = mctx.stochastic_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_recurrent_fn,
        chance_recurrent_fn=chance_recurrent_fn,
        num_simulations=2 * num_simulations,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
        dirichlet_fraction=0.0,
    )

    """np.testing.assert_array_equal(stochastic_policy_output.action, policy_output.action)

    np.testing.assert_allclose(
        stochastic_policy_output.action_weights, policy_output.action_weights
    )"""


if __name__ == "__main__":
    with jax.disable_jit():
        stochastic_muzero_policy()

but get outputs like this:

action: [0]
action: [2]
chance_outcome: [-1]
action: [3]
chance_outcome: [0]
action: [0]
chance_outcome: [-3]
action: [4]
chance_outcome: [1]
action: [5]
chance_outcome: [2]
action: [6]
chance_outcome: [3]
action: [7]
chance_outcome: [4]
action: [1]
chance_outcome: [-2]
action: [3]
chance_outcome: [0]
action: [1]
chance_outcome: [-2]
action: [2]
chance_outcome: [-1]
action: [0]
chance_outcome: [-3]
action: [0]
chance_outcome: [-3]
action: [0]
chance_outcome: [-3]
action: [0]
chance_outcome: [-3]
action: [3]
chance_outcome: [0]
action: [0]
chance_outcome: [-3]
action: [4]
chance_outcome: [1]
action: [5]
chance_outcome: [2]
action: [3]
chance_outcome: [0]

I haven't delved into the mctx policy yet to see where the problem might arise, but wanted to start with opening an issue. Do you know what might be causing this behavior?

No pip install?

There currently doesn't seem to be any pip available.

pip install mctx

ERROR: Could not find a version that satisfies the requirement mctx (from versions: none)
ERROR: No matching distribution found for mctx

Using Ubuntu 18, Python 3.6

How to use an image observation shape of (96, 96, 1) with muax, like Atari PongNoFrameskip-v4

I tried mctx/muax and it is good work, but I have an issue when I set a convolutional observation shape like (96, 96, 1). Please guide me on how to solve this problem,
and on the correct way to set train_env and eval_env as inputs to the muax.fit() function.

I use this code:

support_size = 20
embedding_size = 10
full_support_size = int(support_size * 2 + 1)
num_actions = 2

repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(50, 0.999, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.002, peak_value=0.002, end_value=0.0005, warmup_steps=20000, transition_steps=20000)

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=0.999,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1', 
                max_episodes=1000,
                max_training_steps=50000,
                tracer=tracer,
                buffer=buffer,
                k_steps=10,
                sample_per_trajectory=1,
                buffer_warm_up=128,
                num_trajectory=128,
                tensorboard_dir='/content/data/tensorboard/',
                save_name='model_params',
                random_seed=i,
                log_all_metrics=True)

What is "millions of frames" in the plots in the paper?

Hello, thank you again for the paper and the associated code :)

I noticed that in the paper, most plots use "millions of frames" on the x-axis:

[screenshot of a plot from the paper with 'millions of frames' on the x-axis]

What does this metric mean? Is it number of NN queries, number of games, training examples, or something else entirely?

qtransform_by_min_max example

Is there an example of using the qtransform_by_min_max? I'm unsure how to use it without entirely re-implementing the mctx._src.policies.muzero_policy.
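
For what it's worth, a minimal sketch of passing it in without re-implementing the policy (assuming value bounds of [-1, 1] and a params/root/recurrent_fn setup as in the Quickstart):

from functools import partial
import mctx

policy_output = mctx.muzero_policy(
    params, rng_key, root, recurrent_fn,
    num_simulations=32,
    qtransform=partial(mctx.qtransform_by_min_max, min_value=-1.0, max_value=1.0),
)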

end to end gym environment example

Thanks for open sourcing!

This is great and really cool to see the "curtain pulled back" a bit.

Any chance we can have an example using a real environment? Perhaps from openai-gym? Maybe a community member can contribute this

or perhaps if this is too complex this can just be omitted.

Questions on library features

Hi, thank you for your work.

I am considering using it but I had a few questions before diving deeper:

  • Does it support "multi-objective" that is having an array of rewards instead of a scalar (same for the discount factor) ?
  • Does it support asynchronous (or delayed batch inference) mcts (using for instance virtual losses) ?
  • How well does it behave when running on a non-vectorized cpu environment ?
  • How well do you think it could be integrated into a ray parallel framework ? For instance if I have a remote inference server that gets called by several processes (each taking care of their own mcts tree) and batches the requests before redispatching them to the processes
  • Does it perform well in single instance inference ?

Many Thanks

Incompatible types error with increased global precision

Hello,
Thank you for this amazing library. I have been using this package for our project for a while. I have started facing the following issue:
'''
/home/dubey/anaconda3/envs/jaxenv_qiskit/lib/python3.11/site-packages/jax/_src/ops/scatter.py:93: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
'''
I want increased precision for my experiments and so I set the following global variable as:
'''
jax.config.update("jax_enable_x64", True)
'''
Unfortunately, I have tracked down the bug to the gumbel_muzero_policy function, and within it:

  1. action_selection package - Line 140 explicitly sets the dtype to "jnp.int32" which I believe creates the above issue.
  2. search package - Line 172, 288, 364, 367 and 375 also explicitly set the dtype to "int32".

I am not sure how to resolve this issue but kindly let me know if this does not make sense.

Update: I tried keeping "int32" everywhere, but as expected that does not work, since specifying the global x64 config changes the precision of everything.

Question about proof that Gumbel guarantees a policy improvement

Hey all! Thanks for writing the paper and this library :)

I feel like I'm missing something in the Appendix B proof for "Policy Improvement Proof for Planning with Gumbel". Specifically, the paper mentions that we can replace $argmax_a(g(a) + logits(a))$ with $argmax_{a \in A_{topn}}(g(a) + logits(a))$. While the action with the highest $g(a) + logits(a)$ is guaranteed to be in $A_{topn}$, how can we guarantee that $\mathbb{E} [q(argmax_{a \in A_{topn}}(g(a) + logits(a))] \ge \mathbb{E} [q(argmax_a(g(a) + logits(a)))]$? We don't take into account the $q$ function while sampling, so how can we be sure the inequality holds when taking $q$ into account?

There's a short proof that the "select n most probable actions" heuristic does not guarantee a policy improvement, formulated as $q = (0, 0, 1)$ and $\pi = (.5, .3, .2)$. AIUI, in order for Gumbel to guarantee a policy improvement on this example, the selection process would have to pick $\pi[2]$ to include in its final set. How is this guaranteed?

few questions about setting hyperparameters

In the game of Go, the default value for max_num_considered_actions is 16. I want to know whether this value should increase as num_simulations increases. Also maxvisit_init? What value is better for max_num_considered_actions when num_simulations is 400, 800, or 1600? There are a few other hyperparameters that I'm not sure how to set optimally; I hope you can help me check my settings. Do I need to adjust the qtransform?
For Go 19x19, my settings are:
policy_output = mctx.gumbel_muzero_policy(
    ...
    num_simulations=400,
    max_num_considered_actions=32,
    qtransform=partial(
        mctx.qtransform_completed_by_mix_value,
        value_scale=0.1,
        maxvisit_init=50,
        rescale_values=False,
    ),
    gumbel_scale=1.0,
)

Pass State into recurrent_fn

Hello, thank you for open sourcing this great and concise implementation. I think it'd be useful to be able to pass network state into the recurrent function. With something like batch norm it could be done as part of params, but it'd be nice to have it separate. In my case I'm working with a stateful RNN in the recurrent fn, so it would be nice to have the network state separate and passed back into recurrent_fn like the embedding is. I could copy the repo just as easily, but I am wondering if you all wouldn't mind adding the feature to the repo, as I think it could be useful for research, and then we can still simply install the repo and use it as is.
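
One workaround sketch in the meantime: since the embedding can be an arbitrary pytree, the network state can ride along inside it (apply_fn and its return structure below are placeholders, not mctx API):

def recurrent_fn(params, rng_key, action, embedding):
    model_embedding, net_state = embedding
    # Hypothetical stateful apply: returns the search outputs, the next model
    # embedding, and the updated network state.
    recurrent_fn_output, new_model_embedding, new_net_state = apply_fn(
        params, net_state, rng_key, action, model_embedding)
    return recurrent_fn_output, (new_model_embedding, new_net_state)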

Changing which actions are invalid based on state

Hello, thank you for this great library!

I have a setting where the actions available (actually options) change dynamically. Right now, I am thinking that I can learn when options are available.

I see that the muzero policies (e.g. here) allow for setting invalid_actions, but this seems to be fixed throughout the search. If I want to change these functions so that which actions are invalid is a function of the current state, do you have a good sense of where to do this?

Right now I'm thinking that this belongs in the expand step of MCTS (i.e. here). Do you think I'm on the right track?

If not, can you help me find a better solution?

Thank you!
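
One approximation I have seen used (a sketch only; it discourages rather than strictly forbids invalid actions): compute a state-dependent mask inside recurrent_fn and push the prior logits of invalid actions to a large negative value.

import jax.numpy as jnp

def mask_prior_logits(prior_logits, invalid_actions):
    # invalid_actions: [B, num_actions], 1 where the action is invalid.
    return jnp.where(invalid_actions, -1e9, prior_logits)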

Request for Documentation and Assistance with mctx Library for AI Integration in 2048 Game

I only started learning Python this year, and I have now used the tkinter library to create a windowed version of the 2048 game (here: https://github.com/TWJizhanHuang/Novice_Lab_of_2048_in_Python). However, I haven't implemented any animation effects yet.

Currently, I am trying to integrate AI into my 2048 game to enable the "auto-run" option, similar to the web version example available at https://jupiter.xtrp.io/.

I searched and noticed the mctx library that you have released (including https://www.deepmind.com/open-source/monte-carlo-tree-search-in-jax), and I would like to use and apply it. However, I have been struggling to find documentation that explains how to use the mctx library.

Could you please provide me with some documentation or guidance on how to use the mctx library? Alternatively, if you have any advice or suggestions, I would greatly appreciate your input. Thank you in advance for any assistance you can provide.

New package version

Hi,

I'm installing mctx using pip install mctx, and the version available is 0.0.2 from Jul 26, 2022, from commit 8dac10f.
Could you bump the version to 0.0.3 so that the latest commits are included?

Question about Go experiments

Hi! Mctx and the original paper "Policy improvement by planning with Gumbel" are both amazing 👍
I would like to try reproducing them myself. I have a question regarding the original paper.

In the experiments using Go, the evaluation was performed against Pachi. While Pachi supports the japanese|chinese|aga|new_zealand|simplified_ing rules, it does not seem to support the most common ruleset in computer Go, the Tromp-Taylor rules.

So, I have two questions:

(1) What rule was used for scoring in Go during training?
(2) When evaluating with Pachi, which rule was specified for playing?

[Question] invalid action mask in interior action selection?

Hello! First thank you guys for this repo. I learned a lot about JAX and MuZero from reading the codes. I plan to use this library on our project soon to hopefully solve some math problems.

Just a question about the invalid action mask. I noticed that you apply the invalid action mask at root nodes but not at interior nodes. Is not introducing an interior action mask part of being rule-agnostic (there is a sentence in the MuZero paper saying MuZero "learns" the rules)? Or is one still recommended to go ahead and rewrite a few functions to add masks for interior action selection?

I should be able to rewrite the gumbel_muzero_interior_action_selection to achieve this, but since you have such an API only for roots, I'm asking to make sure that this is still the best practice and doesn't go against the purpose of the algorithm.

If it's the latter case, do you plan to expand features in that direction? Would love to help if there is the opportunity.
