Monte Carlo Tree Search¶
class keras_gym.planning.MCTSNode(actor_critic, state_id=None, tau=1.0, v_resign=0.999, c_puct=1.0, random_seed=None)[source]¶

Implementation of Monte Carlo tree search used in AlphaZero.
Parameters: - actor_critic : ActorCritic object
The actor-critic that is used to evaluate leaf nodes.
- state_id : str, optional
The state id of the env, which allows us to set the env to the correct state.
- tau : float, optional
The temperature parameter used in the ‘behavior’ policy:
\[\pi(a|s)\ =\ \frac{N(s,a)^{1/\tau}}{\sum_{a'}N(s,a')^{1/\tau}}\]
- v_resign : float, optional
The value we use to determine whether a player should resign before the game ends. Namely, the player resigns as soon as the predicted value drops below the threshold, i.e. when \(v(s) < v_\text{resign}\).
- c_puct : float, optional
A hyperparameter that determines how to balance exploration and exploitation. It appears in the selection criterion during the search phase:
\[a\ =\ \arg\max_{a'}\left( Q(s,a') + U(s,a') \right)\]
where
\[\begin{split}Q(s,a)\ &=\ \frac1{N(s)} \sum_{s'\in\text{desc}(s,a)} v(s') \\ U(s,a)\ &=\ \color{red}{c_\text{puct}}\,P(s, a)\, \frac{\sqrt{N(s)}}{1+N(s,a)}\end{split}\]
Here \(\text{desc}(s,a)\) denotes the set of all previously evaluated descendant states of the state \(s\) that can be reached by taking action \(a\). The value and prior probabilities \(v(s)\) and \(P(s,a)\) are generated by the actor-critic. Also, we use the short-hand notation for the combined state-action visit counts:
\[N(s)\ =\ \sum_{a'} N(s,a')\]
Note that this is not exactly the state visit count, which would be \(N(s) + 1\) due to the initial selection and expansion of the root node itself.
- random_seed : int, optional
Sets the random state to get reproducible results.
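The temperature policy above is easy to state directly in code. Below is a minimal sketch, not keras_gym's internal implementation; the helper name behavior_policy is hypothetical:

```python
import numpy as np

def behavior_policy(counts, tau=1.0):
    """Turn visit counts N(s,a) into probabilities pi(a|s).

    Hypothetical helper illustrating the formula above; not part of
    the keras_gym API.
    """
    scaled = counts ** (1.0 / tau)  # N(s,a)^(1/tau)
    return scaled / scaled.sum()    # normalize over actions

counts = np.array([10.0, 50.0, 40.0])   # example visit counts N(s,a)
print(behavior_policy(counts, tau=1.0))  # -> [0.1 0.5 0.4]
print(behavior_policy(counts, tau=0.1))  # nearly greedy w.r.t. counts
```

Lowering tau sharpens the distribution towards the most-visited action: tau → 0 approaches a greedy policy, while tau = 1 plays proportionally to the raw counts.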
Attributes: - env : gym-style environment
The main environment of the game.
- state : state observation
The current state of the environment.
- num_actions : int
The total number of actions in the environment, regardless of whether these actions are actually available in the current state.
- is_root : bool
Whether the current node is a root node, i.e. whether it has no parent node.
- is_leaf : bool
Whether the current node is a leaf node. A leaf node is typically a node that was previously unexplored, but it may also be a terminal state node.
- is_terminal : bool
Whether the current state is a terminal state.
- parent_node : MCTSNode object
The parent node. This is used to traverse back up the tree.
- parent_action : int
Which action led to the current state from the parent state. This is used to inform the parent which child is responsible for the update in the backup phase of the search procedure.
- children : dict
A dictionary that contains all the child states accessible from the current state, format:
{action <int>: child <MCTSNode>}
- N : 1d array, dtype: int, shape: [num_actions]
The state-action visit count \(N(s,a)\).
- P : 1d array, dtype: float, shape: [num_actions]
The prior probabilities over the space of actions \(P(s,a)\), which are generated by the actor-critic function approximator.
- U : 1d array, dtype: float, shape: [num_actions]
The UCT exploration term, which is a vector over the space of actions:
\[U(s,a)\ =\ c_\text{puct}\,P(s,a)\, \frac{\sqrt{N(s)}}{1+N(s,a)}\]
- Q : 1d array, dtype: float, shape: [num_actions]
The UCT exploitation term, which is a vector over the space of actions:
\[Q(s,a)\ =\ \frac{W(s,a)}{N(s,a)}\]
- W : 1d array, dtype: float, shape: [num_actions]
This is the accumulator for the numerator of the UCT exploitation term \(Q(s,a)\). It is a sum of all of the values generated by starting from \(s\), taking action \(a\):
\[W(s,a)\ =\ v(s) + \sum_{s'\in\text{desc}(s,a)} v(s')\]
Here \(\text{desc}(s,a)\) denotes the set of all previously evaluated descendant states of the state \(s\) that can be reached by taking action \(a\). The values \(v(s)\) and \(v(s')\) are generated by the actor-critic function approximator.
- D : 1d array, dtype: bool, shape: [num_actions]
This contains the done flags for each child state, i.e. whether each child state is a terminal state.
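The N, W and P arrays above are exactly the statistics the selection criterion needs. A minimal sketch of the PUCT rule follows; select_action is a hypothetical helper, not the library's code:

```python
import numpy as np

def select_action(N, W, P, c_puct=1.0):
    """Pick argmax_a Q(s,a) + U(s,a) from the node's statistics.

    Hypothetical helper illustrating the formulas above.
    """
    # Q(s,a) = W(s,a) / N(s,a); define Q = 0 for unvisited actions
    Q = np.where(N > 0, W / np.maximum(N, 1), 0.0)
    # U(s,a) = c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a))
    U = c_puct * P * np.sqrt(N.sum()) / (1 + N)
    return int(np.argmax(Q + U))

N = np.array([3.0, 1.0, 0.0])   # visit counts N(s,a)
W = np.array([1.5, -0.5, 0.0])  # accumulated values W(s,a)
P = np.array([0.2, 0.3, 0.5])   # priors from the actor-critic
print(select_action(N, W, P))   # -> 2 (unvisited, high-prior action)
```

Note how the U term steers the search towards actions with high prior probability and low visit count, while Q rewards actions that have backed up high values so far.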
backup(v)[source]¶

Back-up the newly found leaf node value up the tree.
Parameters: - v : float
The value of the newly expanded leaf node.
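The backup phase can be pictured with a stripped-down node class. This is a sketch of the general mechanism, using the parent_node / parent_action attributes described above, not keras_gym's exact implementation (in two-player self-play the sign of v is typically flipped at each ply; that detail is omitted here):

```python
class Node:
    """Minimal stand-in for MCTSNode, just enough to show backup()."""
    def __init__(self, num_actions, parent=None, parent_action=None):
        self.N = [0] * num_actions    # visit counts N(s,a)
        self.W = [0.0] * num_actions  # accumulated values W(s,a)
        self.parent_node = parent
        self.parent_action = parent_action

    def backup(self, v):
        # Credit the edge that led here with the leaf value v,
        # then recurse towards the root.
        if self.parent_node is not None:
            a = self.parent_action
            self.parent_node.N[a] += 1
            self.parent_node.W[a] += v
            self.parent_node.backup(v)

root = Node(num_actions=2)
leaf = Node(num_actions=2, parent=root, parent_action=1)
leaf.backup(0.7)
print(root.N, root.W)  # -> [0, 1] [0.0, 0.7]
```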
expand()[source]¶

Expand tree, i.e. promote leaf node to a non-leaf node.
Returns: - v : float
The value of the leaf node as predicted by the actor-critic.
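Conceptually, expanding a leaf means evaluating it with the actor-critic, storing the priors, and returning the predicted value for the backup phase. A hedged sketch with a dummy evaluator; the callable interface here is made up for illustration and is not the real ActorCritic API:

```python
import numpy as np

class DummyActorCritic:
    """Stand-in for the real ActorCritic: uniform priors, zero value."""
    def __call__(self, state):
        num_actions = 3
        return np.full(num_actions, 1.0 / num_actions), 0.0

def expand(node, actor_critic):
    # Evaluate the leaf: get priors P(s,.) and value v(s).
    P, v = actor_critic(node["state"])
    node["P"] = P            # stored for the U(s,a) term
    node["is_leaf"] = False  # the node is now a non-leaf node
    return float(v)          # returned so it can be backed up

leaf = {"state": None, "P": None, "is_leaf": True}
v = expand(leaf, DummyActorCritic())
print(v, leaf["is_leaf"])  # -> 0.0 False
```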
play(tau=None)[source]¶

Play one move/action.
Parameters: - tau : float, optional
The temperature parameter used in the ‘behavior’ policy:
\[\pi(a|s)\ =\ \frac{N(s,a)^{1/\tau}}{\sum_{a'}N(s,a')^{1/\tau}}\]
If left unspecified, tau defaults to the instance setting.
Returns: - s, a, pi, r, done : tuple
The return values represent the following quantities:
- s : state observation
The state \(s\) from which the action was taken.
- a : action
The specific action \(a\) taken from that state \(s\).
- pi : 1d array, dtype: float, shape: [num_actions]
The action probabilities \(\pi(.|s)\) that were used.
- r : float
The reward received in the transition \((s,a)\to s_\text{next}\).
- done : bool
A flag that indicates that either the game has finished or the actor-critic predicted a value that is below the cutoff value \(v(s) < v_\text{resign}\).
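The sampling step inside play() can be sketched as a standalone function. Everything below is hypothetical illustration: the resign check simply follows the v_resign description above, and the real method also handles the environment transition and reward:

```python
import numpy as np

def sample_move(counts, tau, v, v_resign=0.999, rng=None):
    """Sample an action from pi(a|s) and apply the resign check.

    Hypothetical sketch of what play() does internally.
    """
    rng = rng or np.random.default_rng()
    pi = counts ** (1.0 / tau)
    pi = pi / pi.sum()                  # behavior policy pi(.|s)
    a = int(rng.choice(len(pi), p=pi))  # sample one action
    done = bool(v < v_resign)           # resign if the value is too low
    return a, pi, done

counts = np.array([10.0, 50.0, 40.0])
a, pi, done = sample_move(counts, tau=1.0, v=0.2)
print(a, pi.round(2), done)
```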
search(n=512)[source]¶

Perform \(n\) searches.
Each search consists of three consecutive steps: select(), expand() and backup().
Parameters: - n : int, optional
The number of searches to perform.
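The three-phase loop can be sketched generically. The node methods below are hypothetical stand-ins (a degenerate single-node tree) so the loop runs end to end; the real MCTSNode descends a full tree:

```python
def search(root, n=512):
    """Sketch of the search loop: select a leaf, expand it, back up."""
    for _ in range(n):
        node = root
        while not node.is_leaf:
            node = node.select()  # descend via the PUCT criterion
        v = node.expand()         # evaluate with the actor-critic
        node.backup(v)            # propagate v back towards the root

class StubNode:
    """Trivial single-node tree so the loop above is runnable."""
    def __init__(self):
        self.is_leaf = True
        self.backups = 0
    def select(self):
        return self               # single node: selection is a no-op
    def expand(self):
        return 0.0                # dummy leaf value
    def backup(self, v):
        self.backups += 1
        self.is_leaf = True       # stay a leaf for the next iteration

root = StubNode()
search(root, n=5)
print(root.backups)  # -> 5
```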