Source code for keras_gym.planning.monte_carlo_tree_search

from copy import deepcopy

import numpy as np

from ..base.mixins import ActionSpaceMixin, RandomStateMixin
from ..base.errors import LeafNodeError, NotLeafNodeError, EpisodeDoneError
from ..utils import argmax, one_hot

# Names exported by ``from ... import *``: only the tree-search node class.
__all__ = ('MCTSNode',)


class MCTSNode(ActionSpaceMixin, RandomStateMixin):
    """
    Implementation of Monte Carlo tree search used in AlphaZero.

    Parameters
    ----------
    actor_critic : ActorCritic object

        The actor-critic that is used to evaluate leaf nodes. Its ``env`` is
        deep-copied, so every node owns a private copy of the environment.

    state_id : str, optional

        The state id of the env, which allows us to set the env to the correct
        state. If left unspecified, the (copied) env is reset instead.

    tau : float, optional

        The temperature parameter used in the 'behavior' policy:

        .. math::

            \\pi(a|s)\\ =\\ \\frac{N(s,a)^{1/\\tau}}{\\sum_{a'}N(s,a')^{1/\\tau}}

    v_resign : float, optional

        The value we use to determine whether a player should resign before a
        game ends. Namely, the player will resign if the predicted value drops
        below :math:`v(s) < v_\\text{resign}`. Note that this cutoff is
        currently stored but not applied by :func:`play`.

    c_puct : float, optional

        A hyperparameter that determines how to balance exploration and
        exploitation. It appears in the selection criterion during the
        *search* phase:

        .. math::

            a\\ =\\ \\arg\\max_{a'}\\left( Q(s,a') + U(s,a') \\right)

        where

        .. math::

            Q(s,a)\\ &=\\ \\frac{W(s,a)}{N(s,a)} \\\\
            U(s,a)\\ &=\\ \\color{red}{c_\\text{puct}}\\,P(s, a)\\,
                \\frac{\\sqrt{N(s)}}{1+N(s,a)}

        The value and prior probabilities :math:`v(s)` and :math:`P(s,a)` are
        generated by the actor-critic. Also, we use the short-hand notation
        for the combined state-action visit counts:

        .. math::

            N(s)\\ =\\ \\sum_{a'} N(s,a')

        Note that this is not exactly the state visit count, which would be
        :math:`N(s) + 1` due to the initial selection and expansion of the
        root node itself.

    random_seed : int, optional

        Sets the random state to get reproducible results.

    Attributes
    ----------
    env : gym-style environment

        This node's private copy of the game environment.

    state : state observation

        The state observation of the environment at this node.

    num_actions : int

        The number of actions of the environment, i.e. regardless of whether
        these actions are actually available in the current state.

    is_root : bool

        Whether the current node is a root node, i.e. whether it has no parent
        node.

    is_leaf : bool

        Whether the current node is a leaf node. A leaf node is typically a
        node that was previously unexplored, but it may also be a terminal
        state node.

    is_terminal : bool

        Whether the current state is a terminal state.

    parent_node : MCTSNode object

        The parent node. This is used to traverse back up the tree.

    parent_action : int

        Which action led to the current state from the parent state. This is
        used to inform the parent which child is responsible for the update in
        the *backup* phase of the search procedure.

    children : dict

        A dictionary that contains all the child states accessible from the
        current state, format: ``{action <int>: child <MCTSNode>}``.

    v : float or None

        The value of this node: predicted by the actor-critic on expansion,
        or set to the negated final reward by the parent when this node is a
        terminal state.

    v_abs_max : float

        The largest absolute value that was backed up through this node.

    N : 1d array, dtype: int, shape: [num_actions]

        The state-action visit count :math:`N(s,a)`.

    P : 1d array, dtype: float, shape: [num_actions]

        The prior probabilities over the space of actions :math:`P(s,a)`,
        which are generated by the actor-critic function approximator.

    U : 1d array, dtype: float, shape: [num_actions]

        The PUCT exploration term, which is a vector over the space of
        actions (set to zero for terminal children and unavailable actions):

        .. math::

            U(s,a)\\ =\\ c_\\text{puct}\\,P(s,a)\\,
                \\frac{\\sqrt{N(s)}}{1+N(s,a)}

    Q : 1d array, dtype: float, shape: [num_actions]

        The PUCT exploitation term, which is a vector over the space of
        actions (set to the env's ``win_reward`` for terminal children and to
        its ``loss_reward`` for unavailable actions):

        .. math::

            Q(s,a)\\ =\\ \\frac{W(s,a)}{N(s, a)}

    W : 1d array, dtype: float, shape: [num_actions]

        This is the accumulator for the numerator of the exploitation term
        :math:`Q(s,a)`. It sums all values backed up through the child that
        is reached by taking action :math:`a` in state :math:`s`. Each value
        is expressed from the perspective of the player to move in :math:`s`;
        the sign is flipped at every level of the tree (see :func:`backup`).
        Entries for unavailable actions are initialized to :math:`-\\infty`.

    D : 1d array, dtype: bool, shape: [num_actions]

        This contains the ``done`` flags for each child state, i.e. whether
        each child state is a terminal state.

    """
    def __init__(
            self, actor_critic, state_id=None, tau=1.0, v_resign=0.999,
            c_puct=1.0, random_seed=None):

        self.actor_critic = actor_critic
        self.tau = tau
        self.v_resign = v_resign
        self.c_puct = c_puct
        self.random_seed = random_seed  # also sets self.random

        # The N, W, D arrays are created lazily by their properties. When
        # __init__ is re-run through reset(), drop any leftovers so that the
        # new game does not inherit the previous game's statistics.
        self._clear_cached_arrays()

        # set/reset env
        self.env = deepcopy(self.actor_critic.env)
        if state_id is None:
            self.env.reset()
        else:
            self.env.set_state(state_id)
        self.state_id = self.env.state_id
        self.state = self.env.state
        self.is_terminal = self.env.done

        # these are set/updated dynamically
        self.parent_node = None
        self.parent_action = None
        self.children = {}
        self.is_leaf = True
        self.v_abs_max = 0
        self.v = None
        self.P = None

    def __repr__(self):
        s = "MCTSNode('{}', v={:s} done={}".format(
            self.state_id,
            self._str(self.v, length=5, suffix=','),
            self._str(self.is_terminal, suffix=')', length=5))
        return s

    def reset(self):
        """
        Reset the node to a fresh root node of a new episode, keeping the
        hyperparameters (``tau``, ``v_resign``, ``c_puct``, ``random_seed``).

        """
        self.__init__(
            actor_critic=self.actor_critic,
            state_id=None,
            tau=self.tau,
            v_resign=self.v_resign,
            c_puct=self.c_puct,
            random_seed=self.random_seed)

    def _clear_cached_arrays(self):
        # Forget the lazily created N, W and D arrays; the corresponding
        # properties rebuild them (zeroed, from this node's env) on next use.
        for attr in ('_N', '_W', '_D'):
            self.__dict__.pop(attr, None)

    def search(self, n=512):
        """
        Perform :math:`n` searches.

        Each search consists of three consecutive steps: :func:`select`,
        :func:`expand` and :func:`backup`.

        Parameters
        ----------
        n : int, optional

            The number of searches to perform.

        """
        for _ in range(n):
            leaf_node = self.select()
            # terminal leaves cannot be expanded; their value was already set
            # by the parent during its own expansion
            v = leaf_node.v if leaf_node.is_terminal else leaf_node.expand()
            leaf_node.backup(v)

    def play(self, tau=None):
        """
        Play one move/action.

        After the move, this node turns itself into the selected child node
        (which becomes the new root of the search tree).

        Parameters
        ----------
        tau : float, optional

            The temperature parameter used in the 'behavior' policy:

            .. math::

                \\pi(a|s)\\ =\\ \\frac{N(s,a)^{1/\\tau}}{\\sum_{a'}N(s,a')^{1/\\tau}}

            If left unspecified, ``tau`` defaults to the instance setting.
            For ``tau < 0.1`` the action with the largest visit count is
            picked greedily.

        Returns
        -------
        s, pi, r, done : tuple

            The return values represent the following quantities:

            s : state observation

                The state :math:`s` from which the action was taken.

            pi : 1d array, dtype: float, shape: [num_actions]

                The action probabilities :math:`\\pi(.|s)` that were used.

            r : float

                The reward received in the transition
                :math:`(s,a)\\to s_\\text{next}`

            done : bool

                A flag that indicates that the game has finished. (The
                :math:`v_\\text{resign}` cutoff is not applied here.)

        Raises
        ------
        LeafNodeError

            If this node has not been expanded yet, i.e. no search was done.

        """
        if self.is_leaf:
            raise LeafNodeError(
                "cannot play from a leaf node; must search first")

        if tau is None:
            tau = self.tau

        # construct pi(a|s) ~ N(s,a)^1/tau
        if tau < 0.1:
            # no need to compute pi if tau is very small
            a = argmax(self.N, random_state=self.random)
            pi = one_hot(a, self.num_actions)
        else:
            pi = np.power(self.N, 1 / tau)
            pi /= np.sum(pi)
            a = self.random.choice(self.num_actions, p=pi)

        # this will become the new root node
        child = self.children[a]

        # update env
        s = self.state
        _, r, done, info = self.env.step(a)
        assert child.state_id == info['state_id']

        # switch to new root node
        child.parent_node = None
        child.parent_action = None
        # A child whose statistics were never accessed has no _N/_W/_D in its
        # __dict__; without clearing ours first, update() would leave the old
        # root's arrays in place and corrupt the new root's PUCT scores.
        self._clear_cached_arrays()
        self.__dict__.update(child.__dict__)

        return s, pi, r, done

    def select(self):
        """
        Traverse down the tree to find a leaf node to expand.

        Returns
        -------
        leaf_node : MCTSNode object

            The selected leaf node.

        """
        node = self
        while not node.is_leaf:
            # pick action according to PUCT algorithm; build the score vector
            # once per node rather than recomputing Q and U per candidate
            scores = node.Q + node.U
            a = max(node.children, key=lambda b: scores[b])
            node = node.children[a]
        return node

    def expand(self):
        """
        Expand tree, i.e. promote leaf node to a non-leaf node.

        One child node is created for every available action. Children that
        correspond to terminal states get their value set to the negated
        reward and are flagged in :attr:`D`.

        Returns
        -------
        v : float

            The value of the leaf node as predicted by the actor-critic.

        Raises
        ------
        NotLeafNodeError

            If the node was already expanded.

        EpisodeDoneError

            If the node is a terminal state.

        """
        if not self.is_leaf:
            raise NotLeafNodeError(
                "node is not a leaf node; cannot expand node more than once")
        if self.is_terminal:
            raise EpisodeDoneError("cannot expand further; episode is done")

        self.P, v = self.actor_critic.dist_params(self.state)
        if self.v is None:
            self.v = float(v)

        # make TrainMonitor quiet while we probe the child states
        has_quiet_flag = hasattr(self.env, 'quiet')
        if has_quiet_flag:
            quiet_orig, self.env.quiet = self.env.quiet, True

        try:
            for a in self.env.available_actions:
                _, r, done, info = self.env.step(a)
                child = MCTSNode(
                    self.actor_critic,
                    state_id=info['state_id'],
                    tau=self.tau,
                    v_resign=self.v_resign,
                    c_puct=self.c_puct,
                    random_seed=self.random_seed)
                # share a single random state across the whole tree
                child.random = self.random
                child.parent_node = self
                child.parent_action = a
                if done:
                    self.D[a] = True
                    child.v = -r  # note: flip sign for 'opponent'
                self.children[a] = child
                self.env.set_state(self.state_id)  # reset state to root
        finally:
            # reinstate original 'quiet' flag in TrainMonitor, even on error
            if has_quiet_flag:
                self.env.quiet = quiet_orig

        # after expansion, this is no longer a leaf node
        self.is_leaf = False

        return self.v

    def backup(self, v):
        """
        Back-up the newly found leaf node value up the tree.

        Parameters
        ----------
        v : float

            The value of the newly expanded leaf node.

        Raises
        ------
        LeafNodeError

            If this node is a non-terminal leaf that has not been expanded.

        """
        if self.is_leaf and not self.is_terminal:
            raise LeafNodeError(
                "node is a leaf node; cannot backup before expanding")

        self.v_abs_max = max(self.v_abs_max, np.abs(v))

        # recursively traverse up the tree
        if not self.is_root:
            # notice that we flip sign for 'opponent'
            self.parent_node.N[self.parent_action] += 1
            self.parent_node.W[self.parent_action] += -v
            self.parent_node.backup(-v)

    @property
    def U(self):
        if self.is_leaf:
            U = None
        else:
            # PUCT: U(s,a) = P(s,a) sqrt(sum_b N(s,b)) / (1 + N(s,a))
            U = self.c_puct * self.P * np.sqrt(np.sum(self.N)) / (1 + self.N)
            # no exploration bonus for terminal children / illegal actions
            U[self.D | (~self.env.available_actions_mask)] = 0
        return U

    @property
    def Q(self):
        if self.is_leaf:
            Q = None
        else:
            # the tiny epsilon avoids 0/0 for unvisited actions (Q = 0)
            Q = self.W / (self.N + 1e-16)
            Q[self.D] = self.env.win_reward
            Q[~self.env.available_actions_mask] = self.env.loss_reward
        return Q

    @property
    def N(self):
        if not hasattr(self, '_N'):
            self._N = np.zeros(self.num_actions, dtype='int')
        return self._N

    @property
    def W(self):
        if not hasattr(self, '_W'):
            self._W = np.zeros(self.num_actions, dtype='float')
            self._W[~self.env.available_actions_mask] = -np.inf
        return self._W

    @property
    def D(self):
        if not hasattr(self, '_D'):
            self._D = np.zeros(self.num_actions, dtype='bool')
        return self._D

    @property
    def is_root(self):
        return self.parent_node is None

    def show(self, max_depth=None):
        """
        Visualize the search tree. Prints to stdout.

        Parameters
        ----------
        max_depth : positive int, optional

            The maximal depth to visualize. If left unspecified, the full
            search tree is shown.

        """
        if max_depth is None:
            max_depth = np.inf
        self._show(depth=max_depth, prefix='', suffix='')

    def _show(self, depth, prefix, suffix):
        # Recursive helper of show(): print this node, then (depth
        # permitting) its children one indentation level deeper, each
        # annotated with the Q, U, N statistics of the edge leading to it.
        if depth == 0:
            return
        print(prefix + str(self) + suffix)
        if self.children and depth > 1:
            print()
            Q, U, N = self.Q, self.U, self.N  # compute once, not per child
            for a, child in self.children.items():
                child._show(
                    depth=(depth - 1),
                    prefix=(prefix + " "),
                    suffix=(
                        " a={:d} Q={:s} U={:s} N={:s}"
                        .format(
                            a,
                            self._str(Q[a]),
                            self._str(U[a]),
                            self._str(N[a]))))
            # blank line closes this node's group of children, regardless of
            # how many actions the environment has
            print()

    @staticmethod
    def _str(x, suffix='', length=5):
        # Render x truncated to `length` characters plus `suffix`, padded
        # with spaces to a fixed width so that columns line up.
        if isinstance(x, (float, np.floating)):
            x = '{:g}'.format(x)
        s = str(x)[:length].strip() + suffix
        s += ' ' * max(0, length + len(suffix) - len(s))
        return s