Source code for keras_gym.proba_dists.categorical

import tensorflow as tf
from tensorflow.keras import backend as K

from ..utils.tensor import log_softmax_tf, check_tensor

from .base import BaseProbaDist


class CategoricalDist(BaseProbaDist):
    """
    Differentiable implementation of a categorical distribution.

    Parameters
    ----------
    logits : 2d Tensor, dtype: float, shape: [batch_size, num_categories]

        A batch of logits :math:`z\\in \\mathbb{R}^n` with :math:`n=`
        ``num_categories``.

    boltzmann_tau : float, optional

        The Boltzmann temperature that is used in generating near one-hot
        propensities in :func:`sample`. A smaller number means closer to
        deterministic, one-hot encoded samples. A larger number means better
        numerical stability. A good value for :math:`\\tau` is one that
        offers a good trade-off between these two desired properties.

    name : str, optional

        Name scope of the distribution.

    random_seed : int, optional

        To get reproducible results.

    """
    PARAM_NAMES = ('logits',)

    def __init__(
            self, logits, boltzmann_tau=0.2, name='categorical_dist',
            random_seed=None):

        check_tensor(logits, ndim=2)
        self.num_categories = K.int_shape(logits)[1]
        self.name = name
        self.logits = logits
        self.boltzmann_tau = boltzmann_tau
        self.random_seed = random_seed  # also sets self.random (RandomState)

    def sample(self):
        """
        Sample from the probability distribution.

        In order to return a differentiable sample, this method uses the
        approach outlined in `[ArXiv:1611.01144]
        <https://arxiv.org/abs/1611.01144>`_.

        Returns
        -------
        sample : 2d array, shape: [batch_size, num_categories]

            The sampled variates. The returned arrays are near one-hot
            encoded versions of deterministic variates.

        """
        logp = log_softmax_tf(self.logits)
        u = tf.random.uniform(
            shape=K.shape(logp),
            dtype=K.dtype(logp),
            seed=self.random_seed)
        g = -K.log(-K.log(u))  # g ~ Gumbel(0,1)
        return K.softmax((g + logp) / self.boltzmann_tau)
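
    # Added note (not part of the original keras-gym source): the sampling
    # trick above, sketched in plain NumPy for illustration. Draw standard
    # Gumbel noise, add it to the log-probabilities, and take a
    # temperature-scaled softmax; the result is a near one-hot row per sample.
    #
    #     import numpy as np
    #     logits = np.array([[1.0, 2.0, 3.0]])
    #     logp = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    #     g = -np.log(-np.log(np.random.uniform(size=logp.shape)))
    #     z = (g + logp) / 0.2                         # tau = 0.2
    #     sample = np.exp(z - z.max(axis=1, keepdims=True))
    #     sample /= sample.sum(axis=1, keepdims=True)  # rows are near one-hot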

    def log_proba(self, x):
        # if x is given as (a column vector of) integer class labels,
        # convert it to one-hot encoded rows first
        if K.ndim(x) == 2 and K.int_shape(x)[1] == 1:
            x = K.squeeze(x, axis=1)
        if K.ndim(x) == 1:
            x = K.one_hot(x, self.num_categories)
        check_tensor(x, same_as=self.logits)
        logp = tf.einsum('ij,ij->i', x, log_softmax_tf(self.logits))
        return self._rename(logp, 'log_proba')

    def entropy(self):
        p = K.softmax(self.logits)
        logp = log_softmax_tf(self.logits)
        h = tf.einsum('ij,ij->i', p, -logp)  # H[p] = -sum_j p_j log p_j
        return self._rename(h, 'entropy')

    def cross_entropy(self, other):
        self._check_other(other)
        p_self = K.softmax(self.logits)
        logp_other = log_softmax_tf(other.logits)
        # H[p, q] = -sum_j p_j log q_j
        ce = tf.einsum('ij,ij->i', p_self, -logp_other)
        return self._rename(ce, 'cross_entropy')
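
    # Added note (not part of the original keras-gym source): for two
    # distributions p and q over the same categories, entropy, cross_entropy
    # and kl_divergence (below) are related by the standard identity
    # KL(p || q) = H(p, q) - H(p), so, up to floating-point error,
    # p.kl_divergence(q) == p.cross_entropy(q) - p.entropy().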

    def kl_divergence(self, other):
        self._check_other(other)
        p_self = K.softmax(self.logits)
        logp_self = log_softmax_tf(self.logits)
        logp_other = log_softmax_tf(other.logits)
        # KL(p || q) = sum_j p_j (log p_j - log q_j)
        kl_div = tf.einsum('ij,ij->i', p_self, logp_self - logp_other)
        return self._rename(kl_div, 'kl_divergence')
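

# ----------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original keras-gym
# module): a minimal example of exercising CategoricalDist on a small batch of
# logits. It assumes `check_tensor` and the Keras backend accept plain
# TensorFlow tensors here; in keras-gym proper the logits would typically come
# out of a Keras model rather than a constant.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    logits = tf.constant([[1.0, 2.0, 3.0],
                          [0.0, 0.0, 0.0]])
    dist = CategoricalDist(logits, boltzmann_tau=0.2, random_seed=13)

    x = dist.sample()         # shape: [2, 3], rows approximately one-hot
    h = dist.entropy()        # shape: [2]
    logp = dist.log_proba(x)  # shape: [2], log-probability of the samples

    print(K.eval(x))
    print(K.eval(h))
    print(K.eval(logp))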