Source code for keras_gym.core.actor_critic

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

from ..utils import (
    check_numpy_array, check_tensor, is_vfunction, is_qfunction, is_policy)
from ..base.mixins import ActionSpaceMixin
from ..base.errors import ActionSpaceError
from ..policies.base import BasePolicy

from .base import BaseFunctionApproximator
from .value_v import V
from .value_q import QTypeI, QTypeII


# Public names exported by ``from ... import *``.
__all__ = ('ActorCritic', 'SoftActorCritic')


class BaseActorCritic(BasePolicy, BaseFunctionApproximator, ActionSpaceMixin):
    """
    Shared functionality for combining an updateable policy
    (``self.policy``) with a state value function (``self.v_func``).

    Subclasses are expected to set ``self.policy`` and ``self.v_func`` and to
    build the joint training model that :func:`batch_update` trains through
    ``self._train_on_batch``.

    """
    @property
    def env(self):
        """The environment, taken from the underlying policy."""
        return self.policy.env

    def __call__(self, s):
        """
        Draw an action from the current policy :math:`\\pi(a|s)` and get the
        expected value :math:`v(s)`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        Returns
        -------
        a, v : tuple (action, float)

            Returns a pair representing :math:`(a, v(s))`, where ``a`` is
            drawn by ``self.policy`` and ``v`` is computed by ``self.v_func``.

        """
        return self.policy(s), self.v_func(s)

    def dist_params(self, s):
        """
        Get the distribution parameters under the current policy
        :math:`\\pi(a|s)` and get the expected value :math:`v(s)`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        Returns
        -------
        dist_params, v : tuple (1d array of floats, float)

            Returns a pair representing the distribution parameters of
            :math:`\\pi(a|s)` and the estimated state value :math:`v(s)`.

        """
        return self.policy.dist_params(s), self.v_func(s)

    def batch_eval(self, S, use_target_model=False):
        """
        Evaluate the actor-critic on a batch of state observations.

        Parameters
        ----------
        S : nd array, shape: [batch_size, ...]

            A batch of state observations.

        use_target_model : bool, optional

            Whether to use the :term:`target_model` internally. If False
            (default), the :term:`predict_model` is used. The same setting is
            applied to both the policy and the value function.

        Returns
        -------
        A, V : arrays, shapes: [batch_size, ...] and [batch_size]

            A batch of sampled actions :term:`A` and state values :term:`V`.

        """
        A = self.policy.batch_eval(S, use_target_model=use_target_model)
        V = self.v_func.batch_eval(S, use_target_model=use_target_model)
        return A, V

    def greedy(self, s):
        """
        Draw a greedy action :math:`a=\\arg\\max_{a'}\\pi(a'|s)` and get the
        expected value :math:`v(s)`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        Returns
        -------
        a, v : tuple (action, float)

            Returns a pair representing :math:`(a, v(s))`.

        """
        return self.policy.greedy(s), self.v_func(s)

    def update(self, s, a, r, done):
        """
        Update both actor and critic.

        The transition is added to the value function's cache
        (``self.v_func._cache``). Every batch that the cache then yields is
        popped and passed on to :func:`batch_update`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        a : action

            A single action.

        r : float

            A single observed reward.

        done : bool

            Whether the episode has finished.

        """
        # NOTE: this is an ``assert``, so the check is skipped when Python
        # runs with ``-O``.
        assert self.env.observation_space.contains(s)
        self.v_func._cache.add(s, a, r, done)

        # eager updates
        while self.v_func._cache:
            # pop with batch_size=1
            self.batch_update(*self.v_func._cache.pop())

    def batch_update(self, S, A, Rn, In, S_next, A_next=None):
        """
        Update both actor and critic on a batch of transitions.

        Parameters
        ----------
        S : nd array, shape: [batch_size, ...]

            A batch of state observations.

        A : 2d array, dtype: float32 or float64

            A batch of actions taken. For a discrete action space the shape
            must be ``[batch_size, num_actions]``. For a Box action space it
            must be ``[batch_size, actions_ndim]``.

        Rn : 1d array, dtype: float, shape: [batch_size]

            A batch of partial returns. For example, in n-step bootstrapping
            this is given by:

            .. math::

                R^{(n)}_t\\ =\\ R_t + \\gamma\\,R_{t+1} + \\dots
                    \\gamma^{n-1}\\,R_{t+n-1}

            In other words, it's the non-bootstrapped part of the n-step
            return.

        In : 1d array, dtype: float, shape: [batch_size]

            A batch bootstrapping factor. For instance, in n-step bootstrapping
            this is given by :math:`I^{(n)}_t=\\gamma^n` if the episode is
            ongoing and :math:`I^{(n)}_t=0` otherwise. This allows us to write
            the bootstrapped target as
            :math:`G^{(n)}_t=R^{(n)}_t+I^{(n)}_t\\,v(S_{t+n})`.

        S_next : nd array, shape: [batch_size, ...]

            A batch of next-state observations.

        A_next : 2d array, shape: [batch_size, ...], optional

            A batch of (potential) next actions :term:`A_next`. This argument
            is ignored here. It is accepted so that the signature matches
            the other ``batch_update`` methods.

        Returns
        -------
        losses : dict

            A dict of losses/metrics, of type ``{name <str>: value <float>}``.

        Raises
        ------
        ActionSpaceError

            If the action space is neither discrete nor a Box space.

        """
        # bootstrapped target G = Rn + In * v(S_next); whether the value
        # function's target model is used is decided by the value function
        use_target_model = self.v_func.bootstrap_with_target_model
        V_next = self.v_func.batch_eval(
            S_next, use_target_model=use_target_model)
        G = Rn + In * V_next

        # check shapes / dtypes
        check_numpy_array(G, ndim=1, dtype='float')
        if self.action_space_is_discrete:
            check_numpy_array(
                A, ndim=2, dtype=('float32', 'float64'),
                axis_size=self.num_actions, axis=1)
        elif self.action_space_is_box:
            check_numpy_array(
                A, ndim=2, dtype=('float32', 'float64'),
                axis_size=self.actions_ndim, axis=1)
        else:
            raise ActionSpaceError.feature_request(self.env)

        losses = self._train_on_batch([S, A, G])
        return losses

    def sync_target_model(self, tau=1.0):
        """
        Synchronize the target models of both the policy and the value
        function.

        Parameters
        ----------
        tau : float, optional

            Forwarded unchanged to ``policy.sync_target_model`` and
            ``v_func.sync_target_model``.

        """
        self.policy.sync_target_model(tau=tau)
        self.v_func.sync_target_model(tau=tau)


class ActorCritic(BaseActorCritic):
    """
    A generic actor-critic, combining an :term:`updateable policy` with a
    :term:`value function <state value function>`.

    The added value of using an :class:`ActorCritic` to combine a policy with
    a value function is that it avoids having to feed in :term:`S`
    (potentially very large) three times at training time. Instead, it only
    feeds it in once.

    Parameters
    ----------
    policy : Policy object

        An :term:`updateable policy`.

    v_func : value-function object

        A :term:`state value function` :math:`v(s)`.

    value_loss_weight : float, optional

        Relative weight to give to the value-function loss:

        .. code:: python

            loss = policy_loss + value_loss_weight * value_loss

    """
    def __init__(self, policy, v_func, value_loss_weight=1.0):
        self.policy = policy
        self.v_func = v_func
        self.value_loss_weight = value_loss_weight
        self._check_function_types()
        self._init_models()

    @classmethod
    def from_func(
            cls, function_approximator,
            gamma=0.9,
            bootstrap_n=1,
            bootstrap_with_target_model=False,
            entropy_beta=0.01,
            update_strategy='vanilla',
            random_seed=None,
            value_loss_weight=1.0):
        """
        Create instance directly from a :class:`FunctionApproximator
        <keras_gym.FunctionApproximator>` object.

        Parameters
        ----------
        function_approximator : FunctionApproximator object

            The main :term:`function approximator`.

        gamma : float, optional

            The discount factor for discounting future rewards.

        bootstrap_n : positive int, optional

            The number of steps in n-step bootstrapping. It specifies the
            number of steps over which we're willing to delay bootstrapping.
            Large :math:`n` corresponds to Monte Carlo updates and :math:`n=1`
            corresponds to TD(0).

        bootstrap_with_target_model : bool, optional

            Whether to use the :term:`target_model` when constructing a
            bootstrapped target. If False (default), the primary
            :term:`predict_model` is used.

        entropy_beta : float, optional

            The coefficient of the entropy bonus term in the policy objective.

        update_strategy : str, callable, optional

            The strategy for updating our policy. This determines the loss
            function that we use for our policy function approximator. If you
            wish to use a custom policy loss, you can override the
            :func:`policy_loss_with_metrics` method.

            Provided options are:

                'vanilla'
                    Plain vanilla policy gradient. The corresponding
                    (surrogate) loss function that we use is:

                    .. math::

                        J(\\theta)\\ =\\
                            -\\mathcal{A}(s,a)\\,\\ln\\pi(a|s,\\theta)

                'ppo'
                    `Proximal policy optimization
                    <https://arxiv.org/abs/1707.06347>`_ uses a clipped
                    proximal loss:

                    .. math::

                        J(\\theta)\\ =\\ \\min\\Big(
                            r(\\theta)\\,\\mathcal{A}(s,a)\\,,\\
                            \\text{clip}\\big(
                                r(\\theta), 1-\\epsilon, 1+\\epsilon\\big)
                                    \\,\\mathcal{A}(s,a)\\Big)

                    where :math:`r(\\theta)` is the probability ratio:

                    .. math::

                        r(\\theta)\\ =\\ \\frac
                            {\\pi(a|s,\\theta)}
                            {\\pi(a|s,\\theta_\\text{old})}

                'cross_entropy'
                    Straightforward categorical cross-entropy (from logits).
                    This loss function does *not* make use of the advantages
                    :term:`Adv`. Instead, it minimizes the cross entropy
                    between the behavior policy :math:`\\pi_b(a|s)` and the
                    learned policy :math:`\\pi_\\theta(a|s)`:

                    .. math::

                        J(\\theta)\\ =\\ \\hat{\\mathbb{E}}_t\\left\\{
                            -\\sum_a \\pi_b(a|S_t)\\, \\log \\pi_\\theta(a|S_t)
                        \\right\\}

        random_seed : int, optional

            Sets the random state to get reproducible results.

        value_loss_weight : float, optional

            Relative weight to give to the value-function loss. This is
            passed on to the constructor.

        """
        func = function_approximator  # just an abbreviation
        policy_cls = func._get_policy_class()
        # forward update_strategy; previously it was accepted but dropped
        pi = policy_cls(
            func,
            entropy_beta=entropy_beta,
            update_strategy=update_strategy,
            random_seed=random_seed)
        v = V(
            func,
            gamma=gamma,
            bootstrap_n=bootstrap_n,
            bootstrap_with_target_model=bootstrap_with_target_model)
        return cls(pi, v, value_loss_weight=value_loss_weight)

    def _check_function_types(self):
        """
        Validate the policy and value function passed to the constructor.

        Raises
        ------
        NotImplementedError
            If ``v_func`` is a q-function (not supported yet).
        TypeError
            If ``v_func`` is not a v-function, or ``policy`` is not an
            updateable policy.
        ValueError
            If the policy and the value function use different envs.

        """
        if not is_vfunction(self.v_func):
            if is_qfunction(self.v_func):
                raise NotImplementedError(
                    "ActorCritic hasn't been yet implemented for Q-functions, "
                    "please let me know if you need this; for the time being, "
                    "please use V-function instead.")
            # neither a v-function nor a q-function: reject instead of
            # silently accepting it
            raise TypeError("'v_func' must be a v-function: v(s)")
        if not is_policy(self.policy, check_updateable=True):
            raise TypeError("expected an updateable policy")
        if self.policy.env != self.v_func.env:
            raise ValueError(
                "the envs of policy and v_func do not match")

    def _init_models(self):
        """
        Build and compile the joint ``train_model`` with inputs
        ``[S, A, G]``.

        Its loss is ``policy_loss + value_loss_weight * value_loss``. It is
        compiled with the optimizer of the policy's ``train_model``.

        """
        # inputs: reuse the policy's own training inputs, so S is fed once
        S, A = self.policy.train_model.inputs[:2]
        G = keras.Input(name='G', shape=(1,), dtype='float')

        # get TD advantages; stop_gradient keeps the policy loss from
        # back-propagating into the value function through Adv
        V = self.v_func.predict_model(S)
        Adv = K.stop_gradient(G - V)

        # update loss with advantage coming directly from graph
        policy_loss, metrics = self.policy.policy_loss_with_metrics(Adv, A)
        value_loss = self.v_func.train_model([S, G])
        metrics['policy/loss'] = policy_loss
        metrics['value/loss'] = value_loss
        loss = policy_loss + self.value_loss_weight * value_loss

        # joint model
        self.train_model = keras.Model([S, A, G], loss)
        self.train_model.add_loss(loss)
        for name, metric in metrics.items():
            self.train_model.add_metric(metric, name=name, aggregation='mean')
        self.train_model.compile(optimizer=self.policy.train_model.optimizer)
class SoftActorCritic(BaseActorCritic):
    """
    Implementation of a `soft actor-critic <https://arxiv.org/abs/1801.01290>`_
    (SAC), which uses entropy regularization in the value function as well as
    in its policy updates.

    Parameters
    ----------
    policy : a policy object

        An :term:`updateable policy` object :math:`\\pi(a|s)`.

    v_func : v-function object

        A :term:`state value function` :math:`v(s)`. This is used as the
        entropy-regularized value function (critic).

    q_func1 : q-function object

        A state-action value function :math:`q(s,a)`, either
        :term:`type-I <type-I state-action value function>` or
        :term:`type-II <type-II state-action value function>`. Its target
        model provides the q-values used in the targets for both the policy
        (actor) and the state value function (critic).

    q_func2 : q-function object

        Same as ``q_func1``. SAC uses two q-functions to avoid overfitting due
        to overly optimistic value estimates.

    value_loss_weight : float, optional

        Relative weight to give to the value-function loss:

        .. code:: python

            loss = policy_loss + value_loss_weight * value_loss

    """
    def __init__(
            self, policy, v_func, q_func1, q_func2, value_loss_weight=1.0):
        self.policy = policy
        self.v_func = v_func
        self.q_func1 = q_func1
        self.q_func2 = q_func2
        self.value_loss_weight = value_loss_weight
        self._check_function_types()
        self._init_models()

    @classmethod
    def from_func(
            cls, function_approximator,
            gamma=0.9,
            bootstrap_n=1,
            q_type=None,
            entropy_beta=0.01,
            random_seed=None,
            value_loss_weight=1.0):
        """
        Create instance directly from a :class:`FunctionApproximator
        <keras_gym.FunctionApproximator>` object.

        Parameters
        ----------
        function_approximator : FunctionApproximator object

            The main :term:`function approximator`.

        gamma : float, optional

            The discount factor for discounting future rewards.

        bootstrap_n : positive int, optional

            The number of steps in n-step bootstrapping. It specifies the
            number of steps over which we're willing to delay bootstrapping.
            Large :math:`n` corresponds to Monte Carlo updates and :math:`n=1`
            corresponds to TD(0).

        q_type : 1 or 2, optional

            Whether to model the q-function as :term:`type-I <type-I
            state-action value function>` or :term:`type-II <type-II
            state-action value function>`. This defaults to type-II for
            discrete action spaces and type-I otherwise.

        entropy_beta : float, optional

            The coefficient of the entropy bonus term in the policy objective.

        random_seed : int, optional

            Sets the random state to get reproducible results.

        value_loss_weight : float, optional

            Relative weight to give to the value-function loss. This is
            passed on to the constructor.

        Raises
        ------
        ValueError

            If ``q_type`` is given but is neither 1 nor 2.

        """
        func = function_approximator  # just an abbreviation
        if q_type is None:
            q_type = 2 if func.action_space_is_discrete else 1
        if q_type not in (1, 2):
            # previously any other value silently fell back to QTypeI
            raise ValueError(
                "q_type must be 1 or 2, got: {!r}".format(q_type))
        q_func_cls = QTypeII if q_type == 2 else QTypeI
        policy_cls = func._get_policy_class()
        pi = policy_cls(
            func, entropy_beta=entropy_beta, random_seed=random_seed,
            update_strategy='sac')
        v = V(
            func, gamma=gamma, bootstrap_n=bootstrap_n,
            bootstrap_with_target_model=True)
        q1 = q_func_cls(func, gamma=gamma, bootstrap_n=bootstrap_n)
        q2 = q_func_cls(func, gamma=gamma, bootstrap_n=bootstrap_n)
        return cls(pi, v, q1, q2, value_loss_weight=value_loss_weight)

    def _check_function_types(self):
        """
        Validate the policy, value function and q-functions.

        Raises
        ------
        TypeError
            If any function has the wrong kind.
        ValueError
            If the functions do not all share the policy's env.

        """
        if not is_vfunction(self.v_func):
            raise TypeError("'v_func' must be a v-function: v(s)")
        if not is_qfunction(self.q_func1):
            raise TypeError("'q_func1' must be a q-function: q(s,a)")
        if not is_qfunction(self.q_func2):
            raise TypeError("'q_func2' must be a q-function: q(s,a)")
        if not is_policy(self.policy, check_updateable=True):
            raise TypeError("'policy' must be an updateable policy")
        funcs = (self.policy, self.v_func, self.q_func1, self.q_func2)
        if not all(f.env == self.env for f in funcs):
            raise ValueError(
                "the envs of policy and value function(s) do not match")

    @staticmethod
    def _get_q_value(q_func, S, A):
        """
        Evaluate :math:`q(s,a)` with the *target* model of ``q_func``.

        For a type-II q-function, the row of q-values is contracted with
        ``A`` (a weighted sum over actions, which reduces to selecting
        :math:`q(s,a)` when ``A`` is one-hot). For a type-I q-function,
        ``[S, A]`` is fed in directly. The result has shape
        ``[batch_size, 1]``.

        """
        if is_qfunction(q_func, qtype=2):
            Q = q_func.target_model(S)
            check_tensor(Q, ndim=2, axis_size=q_func.num_actions, axis=1)
            check_tensor(A, ndim=2, axis_size=q_func.num_actions, axis=1)
            Q = tf.expand_dims(tf.einsum('ij,ij->i', Q, A), axis=1)
        else:
            Q = q_func.target_model([S, A])
        check_tensor(Q, ndim=2, axis_size=1, axis=1)
        return Q

    def _init_models(self):
        """
        Build and compile the joint ``train_model`` with inputs
        ``[S, A, G]``.

        Its loss is ``policy_loss + value_loss_weight * value_loss``, where
        ``value_loss`` averages the v-function loss and the two q-function
        losses. It is compiled with the optimizer of the policy's
        ``train_model``.

        """
        # make sure that the policy loss is set to 'sac'
        if self.policy.update_strategy != 'sac':
            self.policy.update_strategy = 'sac'
            # Logger.warn is a deprecated alias (removed in Python 3.13)
            self.logger.warning("policy.update_strategy has been set to 'sac'")

        # inputs
        S, A = self.policy.train_model.inputs[:2]
        G = keras.Input(name='G', shape=(1,), dtype='float')

        # construct log(pi(a_sampled, s))
        A_sampled = self.policy.dist.sample()  # differentiable
        log_pi = self.policy.dist.log_proba(A_sampled)

        # use target models for q-values, because they're non-trainable
        Q1 = self._get_q_value(self.q_func1, S, A_sampled)
        Q2 = self._get_q_value(self.q_func2, S, A_sampled)
        Q_both = keras.layers.Concatenate()([Q1, Q2])
        check_tensor(Q_both, ndim=2, axis_size=2, axis=1)

        # construct entropy-corrected target for state value function
        Q_min = keras.layers.Lambda(lambda x: K.min(x, axis=1))(Q_both)
        V_target = K.stop_gradient(Q_min - self.policy.entropy_beta * log_pi)
        check_tensor(V_target, ndim=1)

        # compute advantages from q-function
        V = self.v_func.predict_model(S)
        check_tensor(V, axis_size=1, axis=1)
        V = K.stop_gradient(K.squeeze(V, axis=1))
        Q = keras.layers.Lambda(lambda x: K.mean(x, axis=1))(Q_both)
        Adv = Q - self.policy.entropy_beta * log_pi - V

        # update loss with advantage coming directly from graph
        policy_loss, metrics = self.policy.policy_loss_with_metrics(Adv)
        v_loss = self.v_func.train_model([S, V_target])
        q_loss1 = self.q_func1.train_model([S, A, G])
        q_loss2 = self.q_func2.train_model([S, A, G])
        value_loss = (v_loss + q_loss1 + q_loss2) / 3.

        # add losses to metrics dict
        metrics.update({
            'policy/loss': policy_loss,
            'v_func/loss': v_loss,
            'q_func1/loss': q_loss1,
            'q_func2/loss': q_loss2,
            'value/loss': value_loss,
        })

        # combined loss function
        loss = policy_loss + self.value_loss_weight * value_loss
        check_tensor(loss, ndim=0)  # should be a scalar

        # joint model
        self.train_model = keras.Model([S, A, G], loss)
        self.train_model.add_loss(loss)
        for name, metric in metrics.items():
            self.train_model.add_metric(metric, name=name, aggregation='mean')
        self.train_model.compile(optimizer=self.policy.train_model.optimizer)

    def batch_update(self, S, A, Rn, In, S_next, A_next=None):
        """
        Update actor and critics on a batch of transitions, then sync the
        q-functions' target models.

        The arguments are the same as for the base-class ``batch_update``.
        The target models are synced because :func:`_get_q_value` reads the
        q-values from them inside the training graph.

        Returns
        -------
        losses : dict

            The losses/metrics dict produced by the base-class update.

        """
        losses = super().batch_update(S, A, Rn, In, S_next, A_next)
        self.q_func1.sync_target_model(tau=1.0)
        self.q_func2.sync_target_model(tau=1.0)
        # previously the losses were dropped and None was returned
        return losses