# Source code for keras_gym.core.base

from abc import ABC, abstractmethod

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

from ..base.mixins import RandomStateMixin, ActionSpaceMixin, LoggerMixin
from ..utils import check_tensor
from ..policies.base import BasePolicy


# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = (
    'BaseFunctionApproximator',
    'BaseUpdateablePolicy',
)


class BaseFunctionApproximator(ABC, LoggerMixin, ActionSpaceMixin, RandomStateMixin):  # noqa: E501
    """
    Abstract base class for keras-model based function approximators.

    Concrete subclasses must implement :func:`__call__`, :func:`update`,
    :func:`batch_eval` and :func:`batch_update`. This base class only provides
    shared helpers: checking for required attributes, running a single
    training step, creating a target model and synchronizing target models
    with their primary counterparts.

    """
    @abstractmethod
    def __call__(self, *args, **kwargs):
        pass

    @abstractmethod
    def update(self, *args, **kwargs):
        pass

    @abstractmethod
    def batch_eval(self, *args, **kwargs):
        pass

    @abstractmethod
    def batch_update(self, *args, **kwargs):
        pass

    def _check_attrs(self, skip=None):
        """
        Check that all required attributes are set on this instance.

        Parameters
        ----------
        skip : container of str, optional

            Names of required attributes that should *not* be checked.

        Raises
        ------
        AttributeError

            If one or more of the required attributes are missing. The error
            message lists all missing attribute names.

        """
        required_attrs = [
            'env', 'gamma', 'bootstrap_n', 'train_model', 'predict_model',
            'target_model', '_cache']

        if skip is None:
            skip = []

        missing_attrs = ", ".join(
            attr for attr in required_attrs
            if attr not in skip and not hasattr(self, attr))

        if missing_attrs:
            raise AttributeError(
                "missing attributes: {}".format(missing_attrs))

    def _train_on_batch(self, inputs, outputs=None):
        """
        Run self.train_model.train_on_batch(inputs, outputs) and return the
        losses as a dict of type: {loss_name <str>: loss_value <float>}.

        If ``self.env`` has a ``record_losses`` method, it is called with the
        resulting dict before returning.

        Parameters
        ----------
        inputs
            Passed on as-is to ``train_model.train_on_batch``.

        outputs : optional
            Passed on as-is to ``train_model.train_on_batch``.

        Returns
        -------
        losses : dict

            A dict of losses/metrics, keyed by the names listed in
            ``train_model.metrics_names``.

        """
        losses = self.train_model.train_on_batch(inputs, outputs)

        # add metric names
        if len(self.train_model.metrics_names) > 1:
            # multiple losses/metrics: one value per metric name
            assert len(self.train_model.metrics_names) == len(losses)
            losses = dict(zip(self.train_model.metrics_names, losses))
        else:
            # single loss: train_on_batch returned a bare scalar
            assert isinstance(losses, (float, np.float32, np.float64))
            assert len(self.train_model.metrics_names) == 1
            losses = {self.train_model.metrics_names[0]: losses}

        if hasattr(self.env, 'record_losses'):
            self.env.record_losses(losses)

        return losses

    @staticmethod
    def _create_target_model(model):
        """
        Create a non-trainable clone of ``model`` via
        ``keras.models.clone_model``.

        Parameters
        ----------
        model : keras Model

            The model to clone.

        Returns
        -------
        target_model : keras Model

            The clone, with ``trainable`` set to False.

        """
        target_model = keras.models.clone_model(model)
        target_model.trainable = False  # exclude from trainable weights
        return target_model

    def sync_target_model(self, tau=1.0):
        """
        Synchronize the target model with the primary model.

        For each of the suffixes ``model``, ``param_model`` and
        ``greedy_model``, if this object has a ``target_<suffix>`` attribute,
        its weights are updated from the matching ``predict_<suffix>``
        attribute.

        Parameters
        ----------
        tau : float between 0 and 1, optional

            The amount of exponential smoothing to apply in the target update:

            .. math::

                w_\\text{target}\\ \\leftarrow\\ (1 - \\tau)\\,w_\\text{target}
                + \\tau\\,w_\\text{primary}

            The default ``tau=1.0`` amounts to a full copy of the primary
            weights.

        Raises
        ------
        ValueError

            If ``tau`` does not lie on the unit interval [0, 1].

        """
        # the exception must actually be raised; merely constructing it would
        # let invalid values of tau pass through silently
        if not 0 <= tau <= 1:
            raise ValueError("tau must lie on the unit interval [0,1]")

        for m in ('model', 'param_model', 'greedy_model'):
            if hasattr(self, 'target_' + m):
                p = getattr(self, 'predict_' + m)
                t = getattr(self, 'target_' + m)
                Wp = p.get_weights()
                Wt = t.get_weights()
                # equivalent to (1 - tau) * wt + tau * wp
                Wt = [wt + tau * (wp - wt) for wt, wp in zip(Wt, Wp)]
                t.set_weights(Wt)


class BaseUpdateablePolicy(BasePolicy, BaseFunctionApproximator):
    """
    Base class for modeling :term:`updateable policies <updateable policy>`.

    Parameters
    ----------
    function_approximator : FunctionApproximator

        The main :class:`FunctionApproximator <keras_gym.FunctionApproximator>`
        object.

    update_strategy : str, callable, optional

        The strategy for updating our policy. This determines the loss function
        that we use for our policy function approximator. If you wish to use a
        custom policy loss, you can override the
        :func:`policy_loss_with_metrics` method.

        Provided options are:

            'vanilla'
                Plain vanilla policy gradient. The corresponding (surrogate)
                loss function that we use is:

                .. math::

                    J(\\theta)\\ =\\ -\\mathcal{A}(s,a)\\,\\ln\\pi(a|s,\\theta)

            'ppo'
                `Proximal policy optimization
                <https://arxiv.org/abs/1707.06347>`_ uses a clipped proximal
                loss:

                .. math::

                    J(\\theta)\\ =\\ \\min\\Big(
                        r(\\theta)\\,\\mathcal{A}(s,a)\\,,\\
                        \\text{clip}\\big(
                            r(\\theta), 1-\\epsilon, 1+\\epsilon\\big)
                                \\,\\mathcal{A}(s,a)\\Big)

                where :math:`r(\\theta)` is the probability ratio:

                .. math::

                    r(\\theta)\\ =\\ \\frac
                        {\\pi(a|s,\\theta)}
                        {\\pi(a|s,\\theta_\\text{old})}

            'cross_entropy'
                Straightforward categorical cross-entropy (from logits). This
                loss function does *not* make use of the advantages
                :term:`Adv`. Instead, it minimizes the cross entropy between
                the behavior policy :math:`\\pi_b(a|s)` and the learned policy
                :math:`\\pi_\\theta(a|s)`:

                .. math::

                    J(\\theta)\\ =\\ \\hat{\\mathbb{E}}_t\\left\\{
                        -\\sum_a \\pi_b(a|S_t)\\, \\log \\pi_\\theta(a|S_t)
                    \\right\\}

                Note that the default :func:`policy_loss_with_metrics` of this
                base class raises a ``NotImplementedError`` for this option.

            'sac'
                The loss is minus the batch-averaged input ``Adv``; the actions
                ``A`` are ignored. This option is handled by
                :func:`policy_loss_with_metrics` even though it is not listed
                in ``UPDATE_STRATEGIES``.

    ppo_clip_eps : float, optional

        The clipping parameter :math:`\\epsilon` in the PPO clipped surrogate
        loss. This option is only applicable if ``update_strategy='ppo'``.

    entropy_beta : float, optional

        The coefficient of the entropy bonus term in the policy objective.

    random_seed : int, optional

        Sets the random state to get reproducible results.

    """
    UPDATE_STRATEGIES = ('vanilla', 'ppo', 'cross_entropy')

    def __init__(
            self, function_approximator,
            update_strategy='vanilla',
            ppo_clip_eps=0.2,
            entropy_beta=0.01,
            random_seed=None):

        self.function_approximator = function_approximator
        self.env = self.function_approximator.env
        self.update_strategy = update_strategy
        self.ppo_clip_eps = float(ppo_clip_eps)
        self.entropy_beta = float(entropy_beta)
        self.random_seed = random_seed  # sets self.random via RandomStateMixin

        # _init_models is not defined in this class; it is expected to be
        # provided elsewhere (e.g. by a subclass) and to set up the models
        # that _check_attrs verifies below
        self._init_models()
        self._check_attrs()

    def __call__(self, s, use_target_model=False):
        """
        Draw an action from the current policy :math:`\\pi(a|s)`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        use_target_model : bool, optional

            Whether to use the :term:`target_model` internally. If False
            (default), the :term:`predict_model` is used.

        Returns
        -------
        a : action

            A single action proposed under the current policy.

        """
        S = np.expand_dims(s, axis=0)
        A = self.batch_eval(S, use_target_model)
        return A[0]

    def dist_params(self, s, use_target_model=False):
        """

        Get the parameters of the (conditional) probability distribution
        :math:`\\pi(a|s)`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        use_target_model : bool, optional

            Whether to use the :term:`target_model` internally. If False
            (default), the :term:`predict_model` is used.

        Returns
        -------
        \\*params : tuple of arrays

            The raw distribution parameters.

        Raises
        ------
        TypeError

            If the underlying model returns something other than a list or a
            numpy array.

        """
        assert self.env.observation_space.contains(s)
        S = np.expand_dims(s, axis=0)
        if use_target_model:
            params = self.target_param_model.predict(S)
        else:
            params = self.predict_param_model.predict(S)

        # extract single instance
        if isinstance(params, list):
            params = [arr[0] for arr in params]
        elif isinstance(params, np.ndarray):
            params = params[0]
        else:
            # the exception must be raised, not merely constructed
            raise TypeError(f"params have unexpected type: {type(params)}")

        return params

    def greedy(self, s, use_target_model=False):
        """
        Draw the greedy action, i.e. :math:`\\arg\\max_a\\pi(a|s)`.

        Parameters
        ----------
        s : state observation

            A single state observation.

        use_target_model : bool, optional

            Whether to use the :term:`target_model` internally. If False
            (default), the :term:`predict_model` is used.

        Returns
        -------
        a : action

            A single action proposed under the current policy.

        """
        assert self.env.observation_space.contains(s)
        S = np.expand_dims(s, axis=0)
        if use_target_model:
            A = self.target_greedy_model.predict(S)
        else:
            A = self.predict_greedy_model.predict(S)
        return A[0]

    def update(self, s, a, advantage):
        """
        Update the policy.

        Parameters
        ----------
        s : state observation

            A single state observation.

        a : action

            A single action.

        advantage : float

            A value for the advantage :math:`\\mathcal{A}(s,a) = q(s,a) -
            v(s)`. This might be sampled and/or estimated version of the true
            advantage.

        """
        assert self.env.observation_space.contains(s)
        assert self.env.action_space.contains(a)
        # add a leading batch axis of size 1 and defer to batch_update
        S = np.expand_dims(s, axis=0)
        A = np.expand_dims(a, axis=0)
        Adv = np.expand_dims(advantage, axis=0)
        self.batch_update(S, A, Adv)

    def batch_eval(self, S, use_target_model=False):
        """
        Evaluate the policy on a batch of state observations.

        Parameters
        ----------
        S : nd array, shape: [batch_size, ...]

            A batch of state observations.

        use_target_model : bool, optional

            Whether to use the :term:`target_model` internally. If False
            (default), the :term:`predict_model` is used.

        Returns
        -------
        A : nd array, shape: [batch_size, ...]

            A batch of sampled actions.

        """
        if use_target_model:
            A = self.target_model.predict(S)
        else:
            A = self.predict_model.predict(S)
        return A

    def batch_update(self, S, A, Adv):
        """
        Update the policy on a batch of transitions.

        Parameters
        ----------
        S : nd array, shape: [batch_size, ...]

            A batch of state observations.

        A : nd array, shape: [batch_size, ...]

            A batch of actions taken by the behavior policy.

        Adv : 1d array, dtype: float, shape: [batch_size]

            A value for the :term:`advantage <Adv>` :math:`\\mathcal{A}(s,a) =
            q(s,a) - v(s)`. This might be sampled and/or estimated version of
            the true advantage.

        Returns
        -------
        losses : dict

            A dict of losses/metrics, of type ``{name <str>: value <float>}``.

        """
        losses = self._train_on_batch([S, A, Adv])
        return losses

    def policy_loss_with_metrics(self, Adv, A=None):
        """
        This method constructs the policy loss as a scalar-valued Tensor,
        together with a dictionary of metrics (also scalars).

        This method may be overridden to construct a custom policy loss and/or
        to change the accompanying metrics.

        Parameters
        ----------
        Adv : 1d Tensor, shape: [batch_size]

            A batch of advantages. A 2d Tensor of shape ``[batch_size, 1]`` is
            also accepted; its trailing axis is squeezed out.

        A : nd Tensor, shape: [batch_size, ...]

            A batch of actions taken under the behavior policy. For some
            choices of policy loss, e.g. ``update_strategy='sac'`` this input
            is ignored.

        Returns
        -------
        loss, metrics : (Tensor, dict of Tensors)

            The policy loss along with some metrics, which is a dict of type
            ``{name <str>: metric <Tensor>}``. The loss and each of the metrics
            (dict values) are scalar Tensors, i.e. Tensors with ``ndim=0``.

            The ``loss`` is passed to a keras Model using
            ``train_model.add_loss(loss)``. Similarly, each metric in the
            metric dict is passed to the model using
            ``train_model.add_metric(metric, name=name, aggregation='mean')``.

        Raises
        ------
        NotImplementedError

            If ``update_strategy='cross_entropy'``.

        ValueError

            If ``update_strategy`` is not one of the handled options.

        """
        if K.ndim(Adv) == 2:
            check_tensor(Adv, axis_size=1, axis=1)
            Adv = K.squeeze(Adv, axis=1)
        check_tensor(Adv, ndim=1)

        if self.update_strategy == 'vanilla':
            assert A is not None

            log_pi = self.dist.log_proba(A)
            check_tensor(log_pi, same_as=Adv)

            entropy = K.mean(self.dist.entropy())

            # flip sign to get loss from objective; the entropy term is a
            # *bonus* in the objective, so it enters the loss with a minus
            # sign (same convention as in the 'ppo' branch below)
            loss = -(K.mean(Adv * log_pi) + self.entropy_beta * entropy)

            # no metrics related to behavior_dist since its not used in loss
            metrics = {'policy/entropy': entropy}

        elif self.update_strategy == 'ppo':
            assert A is not None

            log_pi = self.dist.log_proba(A)
            log_pi_old = K.stop_gradient(self.target_dist.log_proba(A))
            check_tensor(log_pi, same_as=Adv)
            check_tensor(log_pi_old, same_as=Adv)

            eps = self.ppo_clip_eps
            ratio = K.exp(log_pi - log_pi_old)
            ratio_clip = K.clip(ratio, 1 - eps, 1 + eps)
            check_tensor(ratio, same_as=Adv)
            check_tensor(ratio_clip, same_as=Adv)

            clip_objective = K.mean(K.minimum(Adv * ratio, Adv * ratio_clip))
            entropy = K.mean(self.dist.entropy())
            kl_div = K.mean(self.target_dist.kl_divergence(self.dist))

            # flip sign to get loss from objective
            loss = -(clip_objective + self.entropy_beta * entropy)
            metrics = {'policy/entropy': entropy, 'policy/kl_div': kl_div}

        elif self.update_strategy == 'sac':
            self.logger.debug("using update_strategy 'sac'")
            loss = -K.mean(Adv)
            metrics = {'policy/entropy': K.mean(self.dist.entropy())}

        elif self.update_strategy == 'cross_entropy':
            raise NotImplementedError('cross_entropy')

        else:
            raise ValueError(
                "unknown update_strategy '{}'".format(self.update_strategy))

        # rename
        check_tensor(loss, ndim=0)
        loss = tf.identity(loss, name='policy/loss')

        return loss, metrics

    def _check_attrs(self, skip=None):
        """
        Check that all attributes required by an updateable policy are set.

        Parameters
        ----------
        skip : container of str, optional

            Names of required attributes that should *not* be checked. This
            mirrors the signature of the base-class implementation.

        Raises
        ------
        AttributeError

            If one or more of the required attributes are missing. The error
            message lists all missing attribute names.

        """
        model = [
            'predict_model',
            'target_model',
            'predict_greedy_model',
            'target_greedy_model',
            'predict_param_model',
            'target_param_model',
            'train_model',
        ]
        misc = [
            'function_approximator',
            'env',
            'update_strategy',
            'ppo_clip_eps',
            'entropy_beta',
            'random_seed',
        ]

        if skip is None:
            skip = ()

        missing_attrs = [
            a for a in model + misc
            if a not in skip and not hasattr(self, a)]

        if missing_attrs:
            raise AttributeError(
                "missing attributes: {}".format(", ".join(missing_attrs)))