Source code for keras_gym.losses.value_based

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

from ..utils import check_tensor
from .base import BaseLoss

# Public names exported by this module (what ``from ... import *`` exposes).
__all__ = (
    'ProjectedSemiGradientLoss',
    'RootMeanSquaredError',
    'LoglossSign',
)


class RootMeanSquaredError(BaseLoss):
    """
    Root-mean-squared error (RMSE) loss.

    Parameters
    ----------
    delta : float, optional

        Unused. Accepted only so that existing callers passing it keep
        working; it has no effect on the computed loss.

    name : str, optional

        Optional name for the op. It is forwarded to the underlying
        :class:`tf.keras.losses.MeanSquaredError` instance.

    """
    name = 'rmse'

    def __init__(self, delta=1.0, name='root_mean_squared_error'):
        # NOTE(review): ``delta`` is never read (looks like a leftover from a
        # Huber-style signature); it is kept for backward compatibility.
        self._func = tf.keras.losses.MeanSquaredError(name=name)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Compute the RMSE loss.

        Parameters
        ----------
        y_true : Tensor, shape: [batch_size, ...]

            Ground truth values.

        y_pred : Tensor, shape: [batch_size, ...]

            The predicted values.

        sample_weight : Tensor, dtype: float, optional

            Tensor whose rank is either 0, or the same rank as ``y_true``, or
            is broadcastable to ``y_true``. ``sample_weight`` acts as a
            coefficient for the loss. If a scalar is provided, then the loss
            is simply scaled by the given value. If ``sample_weight`` is a
            tensor of size ``[batch_size]``, then the total loss for each
            sample of the batch is rescaled by the corresponding element in
            the ``sample_weight`` vector. If the shape of sample_weight
            matches the shape of ``y_pred``, then the loss of each measurable
            element of ``y_pred`` is scaled by the corresponding value of
            ``sample_weight``.

        Returns
        -------
        loss : 0d Tensor (scalar)

            The batch loss, i.e. the square root of the (weighted) MSE.

        """
        mse = self._func(y_true, y_pred, sample_weight=sample_weight)

        # The derivative of sqrt is infinite at 0, so a perfect fit
        # (mse == 0, zero upstream gradient) would yield inf * 0 = NaN.
        # Evaluate sqrt on a surrogate that is never 0 and select 0 where
        # mse == 0: the forward value is exactly sqrt(mse) (including NaN for
        # negative inputs), but the gradient at mse == 0 stays finite.
        is_zero = tf.equal(mse, 0)
        safe_mse = tf.where(is_zero, tf.ones_like(mse), mse)
        return tf.where(is_zero, tf.zeros_like(mse), K.sqrt(safe_mse))
class ProjectedSemiGradientLoss(BaseLoss):
    """
    Loss function for type-II Q-function.

    This loss function projects the predictions :math:`q(s, .)` onto the
    actions for which we actually received a feedback signal.

    Parameters
    ----------
    G : 1d Tensor, dtype: float, shape: [batch_size]

        The returns that we wish to fit our value function on. A 2d tensor of
        shape ``[batch_size, 1]`` is also accepted and is squeezed to 1d.

    base_loss : keras loss function, optional

        Keras loss object that accepts a ``sample_weight`` keyword argument.
        Default: a fresh :class:`keras.losses.Huber` instance, created per
        instance of this class.

    """
    def __init__(self, G, base_loss=None):
        check_tensor(G)
        if K.ndim(G) == 2:
            check_tensor(G, axis_size=1, axis=1)
            G = K.squeeze(G, axis=1)
        check_tensor(G, ndim=1)

        # the targets are treated as constants: no gradient flows into G
        self.G = K.stop_gradient(G)

        # build the default here rather than in the signature, so that it is
        # not created once at import time and shared by every instance
        if base_loss is None:
            base_loss = keras.losses.Huber()
        self.base_loss = base_loss

    def __call__(self, A, Q_pred, sample_weight=None):
        """
        Compute the projected loss.

        Parameters
        ----------
        A : 2d Tensor, dtype: int or float, shape: [batch_size, num_actions]

            A batch of (one-hot encoded) discrete actions :term:`A`. It is
            cast to the dtype of ``Q_pred`` before the projection.

        Q_pred : 2d Tensor, shape: [batch_size, num_actions]

            The predicted values :math:`q(s,.)`, a.k.a. ``y_pred``.

        sample_weight : Tensor, dtype: float, optional

            Tensor whose rank is either 0 or is broadcastable to ``y_true``.
            ``sample_weight`` acts as a coefficient for the loss. If a scalar
            is provided, then the loss is simply scaled by the given value. If
            ``sample_weight`` is a tensor of size ``[batch_size]``, then the
            total loss for each sample of the batch is rescaled by the
            corresponding element in the ``sample_weight`` vector.

        Returns
        -------
        loss : 0d Tensor (scalar)

            The batch loss (with ``base_loss``'s default reduction).

        """
        # check/fix shapes and dtypes
        batch_size = K.int_shape(self.G)[0]
        check_tensor(Q_pred, ndim=2, axis_size=batch_size, axis=0)
        check_tensor(A, ndim=2, axis_size=batch_size, axis=0)

        # einsum requires both operands to share a dtype, while one-hot
        # actions may arrive as an int tensor; K.cast is a no-op when the
        # dtypes already agree
        A = K.cast(A, K.dtype(Q_pred))
        A.set_shape(K.int_shape(Q_pred))

        # project onto actions taken: q(s,.) --> q(s,a)
        Q_pred_projected = tf.einsum('ij,ij->i', Q_pred, A)

        # Keras' built-in losses (e.g. Huber) average over the last axis
        # before sample weights are applied. With 1d inputs that axis is the
        # batch axis, so per-sample weights could not be applied. Give each
        # sample a trailing axis of size 1 so that the per-sample losses have
        # shape [batch_size]; the unweighted mean is unchanged.
        return self.base_loss(
            K.expand_dims(self.G, axis=1),
            K.expand_dims(Q_pred_projected, axis=1),
            sample_weight=sample_weight)
class LoglossSign(BaseLoss):
    """
    Logloss implemented for predicted logits :math:`z\\in\\mathbb{R}` and
    ground truth :math:`y\\in\\{-1, +1\\}`.

    .. math::

        L\\ =\\ \\log\\left( 1 + \\exp(-y\\,z) \\right)

    """
    def __init__(self):
        pass

    def __call__(self, y_true, z_pred, sample_weight=None):
        """
        Compute the elementwise logloss.

        Parameters
        ----------
        y_true : Tensor, shape: [batch_size, ...]

            Ground truth values :math:`y\\in\\{-1, +1\\}`. It is cast to the
            dtype of ``z_pred``.

        z_pred : Tensor, dtype: float32 or float64, shape: [batch_size, ...]

            The predicted logits :math:`z\\in\\mathbb{R}`.

        sample_weight : Tensor, dtype: float, optional

            Not yet implemented; this argument is currently ignored.

        Returns
        -------
        loss : Tensor

            The unreduced, elementwise loss :math:`\\log(1 + \\exp(-y\\,z))`,
            with the broadcast shape of ``y_true`` and ``z_pred``.

        Raises
        ------
        TypeError

            If ``z_pred`` is neither float32 nor float64.

        """
        # TODO: implement sample_weight
        if K.dtype(z_pred) not in ('float32', 'float64'):
            raise TypeError('Expected dtype for z_pred: float32 or float64')

        # labels are often stored as ints; the product below needs one dtype
        y_true = K.cast(y_true, K.dtype(z_pred))

        # softplus(x) == log(1 + exp(x)), evaluated in a numerically stable
        # way for any x. This replaces the earlier clipping of z_pred, which
        # zeroed the gradient for confidently-wrong predictions.
        return tf.nn.softplus(-y_true * z_pred)