keras-gym

Monte Carlo Tree Search

class keras_gym.planning.MCTSNode(actor_critic, state_id=None, tau=1.0, v_resign=0.999, c_puct=1.0, random_seed=None)

Implementation of Monte Carlo tree search used in AlphaZero.

Parameters:
actor_critic : ActorCritic object

The actor-critic that is used to evaluate leaf nodes.

state_id : str, optional

The state id of the environment, which is used to set the environment to the correct state.

tau : float, optional

The temperature parameter used in the ‘behavior’ policy:

\[\pi(a|s)\ =\ \frac{N(s,a)^{1/\tau}}{\sum_{a'}N(s,a')^{1/\tau}}\]
v_resign : float, optional

The value threshold used to decide whether a player should resign before the game ends. Namely, the player resigns if the predicted value drops below this threshold, i.e. if \(v(s) < v_\text{resign}\).

c_puct : float, optional

A hyperparameter that determines how to balance exploration and exploitation. It appears in the selection criterion during the search phase:

\[a\ =\ \arg\max_{a'}\left( Q(s,a) + U(s,a) \right)\]

where

\[\begin{split}Q(s,a)\ &=\ \frac1{N(s,a)} \sum_{s'\in\text{desc}(s,a)} v(s') \\ U(s,a)\ &=\ \color{red}{c_\text{puct}}\,P(s, a)\, \frac{\sqrt{N(s)}}{1+N(s,a)}\end{split}\]

Here \(\text{desc}(s,a)\) denotes the set of all previously evaluated descendant states of the state \(s\) that can be reached by taking action \(a\). The value \(v(s)\) and the prior probabilities \(P(s,a)\) are generated by the actor-critic. We also use the shorthand notation for the combined state-action visit count:

\[N(s)\ =\ \sum_{a'} N(s,a')\]

Note that this is not exactly the state visit count, which would be \(N(s) + 1\) due to the initial selection and expansion of the root node itself.

random_seed : int, optional

Sets the random state to get reproducible results.
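The selection criterion above can be sketched in NumPy for a single node. This is a minimal illustration, not the keras-gym implementation: the function name puct_select is hypothetical, and treating unvisited actions as \(Q(s,a)=0\) is an assumption.

```python
import numpy as np


def puct_select(N, W, P, c_puct=1.0):
    """Return argmax_a [Q(s,a) + U(s,a)] for a single node.

    N: visit counts N(s,a); W: accumulated values W(s,a);
    P: prior probabilities P(s,a). All have shape [num_actions].
    """
    N = np.asarray(N, dtype=float)
    W = np.asarray(W, dtype=float)
    P = np.asarray(P, dtype=float)
    # exploitation term Q(s,a) = W(s,a) / N(s,a); unvisited actions get Q = 0
    Q = np.divide(W, N, out=np.zeros_like(W), where=N > 0)
    # exploration term U(s,a) = c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a)),
    # with N(s) = sum_a' N(s,a')
    U = c_puct * P * np.sqrt(N.sum()) / (1.0 + N)
    return int(np.argmax(Q + U))


# With c_puct = 0 the rule is greedy in Q; with c_puct = 1 the
# less-visited action wins despite its lower Q.
print(puct_select([10, 2], [8.0, 1.0], [0.5, 0.5], c_puct=0.0))  # prints 0
print(puct_select([10, 2], [8.0, 1.0], [0.5, 0.5], c_puct=1.0))  # prints 1
```

Note that at a node where \(N(s)=0\), the exploration term vanishes for every action, so in this sketch the tie is broken by np.argmax, which picks the first index.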

Attributes:
env : gym-style environment

The main environment of the game.

state : state observation

The current state of the environment.

num_actions : int

The number of actions of the environment, including actions that are not available in the current state.

is_root : bool

Whether the current node is a root node, i.e. whether it has no parent node.

is_leaf : bool

Whether the current node is a leaf node. A leaf node is typically a node that was previously unexplored, but it may also be a terminal state node.

is_terminal : bool

Whether the current state is a terminal state.

parent_node : MCTSNode object

The parent node. This is used to traverse back up the tree.

parent_action : int

Which action led to the current state from the parent state. This is used to inform the parent which child is responsible for the update in the backup phase of the search procedure.

children : dict

A dictionary that contains all the child states accessible from the current state, format: {action <int>: child <MCTSNode>}.

N : 1d array, dtype: int, shape: [num_actions]

The state-action visit count \(N(s,a)\).

P : 1d array, dtype: float, shape: [num_actions]

The prior probabilities over the space of actions \(P(s,a)\), which are generated by the actor-critic function approximator.

U : 1d array, dtype: float, shape: [num_actions]

The UCT exploration term, which is a vector over the space of actions:

\[U(s,a)\ =\ c_\text{puct}\,P(s,a)\, \frac{\sqrt{N(s)}}{1+N(s,a)}\]
Q : 1d array, dtype: float, shape: [num_actions]

The UCT exploitation term, which is a vector over the space of actions:

\[Q(s,a)\ =\ \frac{W(s,a)}{N(s, a)}\]
W : 1d array, dtype: float, shape: [num_actions]

This is the accumulator for the numerator of the UCT exploitation term \(Q(s,a)\). It is the sum of all values obtained by starting from \(s\) and taking action \(a\):

\[W(s,a)\ =\ v(s) + \sum_{s'\in\text{desc}(s,a)} v(s')\]

Here \(\text{desc}(s,a)\) denotes the set of all previously evaluated descendant states of the state \(s\) that can be reached by taking action \(a\). The prior values \(v(s)\) and \(v(s')\) are generated by the actor-critic function approximator.

D : 1d array, dtype: bool, shape: [num_actions]

This contains the done flags for each child state, i.e. whether each child state is a terminal state.

backup(self, v)

Propagate the value of the newly expanded leaf node back up the tree.

Parameters:
v : float

The value of the newly expanded leaf node.
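A minimal sketch of the backup phase, assuming each node stores its parent_node, its parent_action, and NumPy arrays N and W as described under Attributes. The Node class and the backup function are hypothetical. This version adds the same \(v\) at every level; for two-player zero-sum games, implementations commonly negate \(v\) at each level instead, which this sketch does not do.

```python
import numpy as np


class Node:
    """Hypothetical minimal node with only the fields that backup() needs."""

    def __init__(self, num_actions, parent_node=None, parent_action=None):
        self.parent_node = parent_node      # None for the root node
        self.parent_action = parent_action  # action that led here from the parent
        self.N = np.zeros(num_actions, dtype=int)  # N(s,a)
        self.W = np.zeros(num_actions)             # W(s,a)


def backup(leaf, v):
    """Propagate the leaf value v up the tree, one parent at a time."""
    node = leaf
    while node.parent_node is not None:
        parent, a = node.parent_node, node.parent_action
        parent.N[a] += 1  # one more visit through (s, a)
        parent.W[a] += v  # accumulate the value for Q(s,a) = W(s,a) / N(s,a)
        node = parent


root = Node(2)
child = Node(2, parent_node=root, parent_action=1)
leaf = Node(2, parent_node=child, parent_action=0)
backup(leaf, 0.5)
print(root.N.tolist(), root.W.tolist())  # [0, 1] [0.0, 0.5]
```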

expand(self)

Expand the tree, i.e. promote a leaf node to a non-leaf node.

Returns:
v : float

The value of the leaf node as predicted by the actor-critic.
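A sketch of the expansion step, assuming the actor-critic can be wrapped in a callable that maps a state to the pair \((P(s,\cdot), v(s))\). The evaluate callable and the LeafNode class are hypothetical stand-ins; the actual keras-gym actor-critic interface is not shown here, and the handling of terminal leaf nodes is omitted.

```python
import numpy as np


class LeafNode:
    """Hypothetical node with only the fields that expand() touches."""

    def __init__(self, state, num_actions):
        self.state = state
        self.is_leaf = True
        self.P = np.zeros(num_actions)  # filled in on expansion


def expand(node, evaluate):
    """Promote a leaf to a non-leaf node and return v(s).

    `evaluate` stands in for the actor-critic: state -> (P(s, .), v(s)).
    """
    P, v = evaluate(node.state)
    node.P = np.asarray(P, dtype=float)  # priors used later in U(s,a)
    node.is_leaf = False
    return float(v)


# hypothetical evaluator: fixed priors and a fixed value estimate
node = LeafNode(state="s0", num_actions=3)
v = expand(node, lambda s: ([0.2, 0.3, 0.5], 0.1))
print(v, node.is_leaf)  # 0.1 False
```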

play(self, tau=None)

Play one move/action.

Parameters:
tau : float, optional

The temperature parameter used in the ‘behavior’ policy:

\[\pi(a|s)\ =\ \frac{N(s,a)^{1/\tau}}{\sum_{a'}N(s,a')^{1/\tau}}\]

If left unspecified, tau defaults to the instance setting.
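The behavior policy can be computed from the visit counts as follows. In this sketch (the name behavior_policy is hypothetical), the counts are divided by their maximum before exponentiating; the factor cancels in the ratio, so \(\pi\) is unchanged, but overflow is avoided for small \(\tau\). The tau == 0 branch, which takes the greedy limit \(\tau\to0\), is an assumption not described above.

```python
import numpy as np


def behavior_policy(N, tau=1.0):
    """pi(a|s) = N(s,a)^(1/tau) / sum_a' N(s,a')^(1/tau)."""
    N = np.asarray(N, dtype=float)
    if tau == 0:
        # greedy limit tau -> 0: uniform over the most-visited action(s)
        pi = (N == N.max()).astype(float)
    else:
        # dividing by max(N) first cancels in the ratio but avoids overflow
        pi = (N / N.max()) ** (1.0 / tau)
    return pi / pi.sum()


# sample the move to play from the visit-count policy
rng = np.random.default_rng(0)
pi = behavior_policy([10, 30, 60], tau=0.5)
a = int(rng.choice(len(pi), p=pi))
```

Lower temperatures sharpen \(\pi\) toward the most-visited action, while \(\tau=1\) plays in proportion to the visit counts.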

Returns:
s, a, pi, r, done : tuple

The return values represent the following quantities:

s : state observation

The state \(s\) from which the action was taken.

a : action

The specific action \(a\) taken from that state \(s\).

pi : 1d array, dtype: float, shape: [num_actions]

The action probabilities \(\pi(\cdot|s)\) that were used.

r : float

The reward received in the transition \((s,a)\to s_\text{next}\).

done : bool

A flag indicating that either the game has finished or the actor-critic predicted a value below the resignation threshold, i.e. \(v(s) < v_\text{resign}\).

search(self, n=512)

Perform \(n\) searches.

Each search consists of three consecutive steps: select(), expand() and backup().

Parameters:
n : int, optional

The number of searches to perform.
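To show how select(), expand() and backup() fit together, here is a self-contained toy search. Everything in it is hypothetical and stands in for the environment and actor-critic used by MCTSNode: the toy game (three binary moves, with the terminal value being the fraction of moves that chose action 1), the uniform priors, and the fixed value \(v(s)=0.5\) returned for non-terminal leaves.

```python
import numpy as np


class ToyNode:
    """Hypothetical MCTS node for a toy game.

    The state is the tuple of actions taken so far; the game ends after
    `depth` moves, and the terminal value is the fraction of 1-actions.
    """

    def __init__(self, state=(), depth=3, num_actions=2, c_puct=1.0,
                 parent_node=None, parent_action=None):
        self.state = state
        self.depth = depth
        self.num_actions = num_actions
        self.c_puct = c_puct
        self.parent_node = parent_node
        self.parent_action = parent_action
        self.is_terminal = len(state) == depth
        self.is_leaf = True
        self.children = {}
        self.N = np.zeros(num_actions, dtype=int)
        self.W = np.zeros(num_actions)
        # uniform priors stand in for the actor-critic's P(s, a)
        self.P = np.full(num_actions, 1.0 / num_actions)

    def select(self):
        """Descend via argmax(Q + U) until a leaf node is reached."""
        node = self
        while not node.is_leaf:
            Q = np.divide(node.W, node.N, out=np.zeros_like(node.W),
                          where=node.N > 0)
            U = node.c_puct * node.P * np.sqrt(node.N.sum()) / (1 + node.N)
            a = int(np.argmax(Q + U))
            if a not in node.children:  # create child nodes lazily
                node.children[a] = ToyNode(node.state + (a,), node.depth,
                                           node.num_actions, node.c_puct,
                                           parent_node=node, parent_action=a)
            node = node.children[a]
        return node

    def expand(self):
        """Promote a non-terminal leaf and return its value estimate."""
        if self.is_terminal:
            return sum(self.state) / self.depth  # exact terminal value
        self.is_leaf = False
        return 0.5  # fixed stand-in for the actor-critic's v(s)

    def backup(self, v):
        """Add v to N(s,a) and W(s,a) along the path back to the root."""
        node = self
        while node.parent_node is not None:
            parent, a = node.parent_node, node.parent_action
            parent.N[a] += 1
            parent.W[a] += v
            node = parent

    def search(self, n=512):
        """Run n rounds of select -> expand -> backup."""
        for _ in range(n):
            leaf = self.select()
            v = leaf.expand()
            leaf.backup(v)


root = ToyNode()
root.search(256)
```

After root.search(n), root.N.sum() equals n - 1 in this sketch: the first search selects and expands the root itself and backs up nothing, which matches the remark that the state visit count is \(N(s)+1\). Child nodes are created lazily the first time their action is selected, which keeps the tree limited to states the search has actually reached.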

select(self)

Traverse down the tree to find a leaf node to expand.

Returns:
leaf_node : MCTSNode object

The selected leaf node.

show(self, max_depth=None)

Visualize the search tree. Prints to stdout.

Parameters:
max_depth : positive int, optional

The maximal depth to visualize. If left unspecified, the full search tree is shown.
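A text rendering in the spirit of show() can be sketched as below. The TreeNode class, the tree_lines helper, and the output layout are assumptions for illustration, not the format keras-gym prints; in this sketch, max_depth=0 shows only the root.

```python
import numpy as np


class TreeNode:
    """Hypothetical node: N holds N(s,a), children maps action -> TreeNode."""

    def __init__(self, num_actions):
        self.N = np.zeros(num_actions, dtype=int)
        self.children = {}


def tree_lines(node, max_depth=None, depth=0, label="root"):
    """One line per node, indented by depth, stopping below max_depth."""
    lines = ["  " * depth + f"{label} N(s)={int(node.N.sum())}"]
    if max_depth is None or depth < max_depth:
        for a, child in sorted(node.children.items()):
            lines += tree_lines(child, max_depth, depth + 1, label=f"a={a}")
    return lines


def show(node, max_depth=None):
    print("\n".join(tree_lines(node, max_depth)))


root = TreeNode(2)
root.N[:] = [2, 1]
root.children[0] = TreeNode(2)
root.children[0].N[:] = [1, 0]
root.children[1] = TreeNode(2)
root.children[0].children[0] = TreeNode(2)
show(root, max_depth=1)
# root N(s)=3
#   a=0 N(s)=1
#   a=1 N(s)=0
```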


© Copyright 2018, Kristian Holsheimer Revision 12d83ec9.
