FunctionApproximator class

class keras_gym.FunctionApproximator(env, optimizer=None, **optimizer_kwargs)[source]

A generic function approximator.

This is the central object that provides an interface between a gym-type environment and function approximators such as value functions and updateable policies.

In order to create a valid function approximator, you need to implement the body method. For example, to implement a simple multi-layer perceptron function approximator you would do something like:

import gym
import keras_gym as km
from tensorflow.keras.layers import Flatten, Dense

class MLP(km.FunctionApproximator):
    """ multi-layer perceptron with one hidden layer """
    def body(self, S):
        X = Flatten()(S)
        X = Dense(units=4)(X)
        return X

# environment
env = gym.make(...)

# generic function approximator
mlp = MLP(env, lr=0.001)

# policy and value function
pi, v = km.SoftmaxPolicy(mlp), km.V(mlp)

The default heads are simple (multi-output) linear regression layers, which can be overridden by your own implementation.
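As an illustration of what such a default head computes, here is a minimal numpy sketch of a linear regression layer mapping body features to a scalar output (the feature values, weights, and shapes are made up for the example and are not the library's actual initialization):

```python
import numpy as np

# hypothetical feature output X of body(), shape [batch_size, num_features]
X = np.array([[0.5, -1.0, 2.0]])

# a linear-regression head: one weight per feature, plus a bias
w = np.array([[0.1], [0.2], [0.3]])  # shape [num_features, 1]
b = np.array([0.0])

V = X @ w + b  # shape [batch_size, 1]
print(V)       # [[0.45]]
```

The actual heads are Keras layers with trainable weights; this sketch only shows the arithmetic they perform at forward-pass time.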

Parameters:
env : environment

A gym-style environment.

optimizer : keras.optimizers.Optimizer, optional

If left unspecified (optimizer=None), the function approximator's DEFAULT_OPTIMIZER is used. See the Keras documentation for more details.

**optimizer_kwargs : keyword arguments

Keyword arguments for the optimizer. This is useful when you want to use the default optimizer with a different setting, e.g. changing the learning rate.

DEFAULT_OPTIMIZER

alias of tensorflow.python.keras.optimizer_v2.adam.Adam

body(self, S)[source]

This is the part of the computation graph that may be shared between, e.g., the policy (actor) and the value function (critic). It is typically the part of the neural net that does most of the heavy lifting. One may think of body() as an elaborate automatic feature extractor.

Parameters:
S : nd Tensor, shape: [batch_size, …]

The input state observation.

Returns:
X : nd Tensor, shape: [batch_size, …]

The intermediate keras tensor.

body_q1(self, S, A)[source]

This is similar to body(), except that it takes a state-action pair as input instead of only state observations.

Parameters:
S : nd Tensor, shape: [batch_size, …]

The input state observation.

A : nd Tensor, shape: [batch_size, …]

The input actions.

Returns:
X : nd Tensor, shape: [batch_size, …]

The intermediate keras tensor.
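As a rough illustration of what a body_q1() forward pass typically does, here is a plain numpy sketch (with made-up shapes and random weights, not the library's actual internals): flatten the state and action inputs, concatenate them, and feed the result through a dense layer.

```python
import numpy as np

rng = np.random.default_rng(0)

S = rng.normal(size=(2, 2, 2))   # batch of 2 state observations
A = np.array([[1.0, 0.0],        # batch of 2 one-hot encoded actions
              [0.0, 1.0]])

# flatten the states and concatenate with the actions along the feature axis
X = np.concatenate([S.reshape(len(S), -1), A], axis=1)  # shape (2, 6)

# one dense layer (no bias or activation, for brevity)
W = rng.normal(size=(6, 4))
X = X @ W                        # shape (2, 4)
```

In an actual subclass you would express the same steps with Keras layers (e.g. Flatten, Concatenate, Dense), mirroring the MLP example above.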

head_pi(self, X)[source]

This is the policy head. It returns logits, i.e. unnormalized scores rather than probabilities; apply a softmax to turn the output into probabilities.

Parameters:
X : nd Tensor, shape: [batch_size, …]

X is an intermediate tensor in the full forward-pass of the computation graph; it’s the output of the last layer of the body() method.

Returns:
*params : Tensor or tuple of Tensors, shape: [batch_size, …]

These constitute the raw policy distribution parameters.
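Since head_pi() returns logits rather than probabilities, downstream code applies a softmax to normalize them. A minimal, numerically stable numpy sketch (the logit values are made up for the example):

```python
import numpy as np

def softmax(logits, axis=-1):
    # subtract the row-wise max before exponentiating, for numerical stability
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

logits = np.array([[2.0, 1.0, 0.1]])  # hypothetical policy logits
probs = softmax(logits)               # each row sums to 1
```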

head_q1(self, X)[source]

This is the type-I Q-value head. It returns a scalar Q-value \(q(s,a)\in\mathbb{R}\).

Parameters:
X : nd Tensor, shape: [batch_size, …]

X is an intermediate tensor in the full forward-pass of the computation graph; it’s the output of the last layer of the body() method.

Returns:
Q_sa : 2d Tensor, shape: [batch_size, 1]

The output type-I Q-values \(q(s,a)\in\mathbb{R}\).

head_q2(self, X)[source]

This is the type-II Q-value head. It returns a vector of Q-values \(q(s,\cdot)\in\mathbb{R}^n\).

Parameters:
X : nd Tensor, shape: [batch_size, …]

X is an intermediate tensor in the full forward-pass of the computation graph; it’s the output of the last layer of the body() method.

Returns:
Q_s : 2d Tensor, shape: [batch_size, num_actions]

The output type-II Q-values \(q(s,\cdot)\in\mathbb{R}^n\).
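The practical difference from the type-I head is that a type-II Q-function computes values for all actions at once; to recover \(q(s,a)\) for specific actions, one projects the output onto the chosen actions. A numpy sketch with made-up values:

```python
import numpy as np

# type-II output: one Q-value per action, shape [batch_size, num_actions]
Q_s = np.array([[0.2, 1.5, -0.3],
                [0.9, 0.1,  0.4]])

A = np.array([1, 2])  # action taken in each transition

# project onto the chosen actions to get q(s,a), shape [batch_size]
Q_sa = Q_s[np.arange(len(A)), A]
print(Q_sa)  # [1.5 0.4]
```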

head_v(self, X)[source]

This is the state value head. It returns a scalar V-value \(v(s)\in\mathbb{R}\).

Parameters:
X : nd Tensor, shape: [batch_size, …]

X is an intermediate tensor in the full forward-pass of the computation graph; it’s the output of the last layer of the body() method.

Returns:
V : 2d Tensor, shape: [batch_size, 1]

The output state values \(v(s)\in\mathbb{R}\).