Source code for keras_gym.caching.experience_replay

import numpy as np

from ..base.mixins import RandomStateMixin, ActionSpaceMixin
from ..base.errors import NumpyArrayCheckError, InsufficientCacheError
from ..utils import check_numpy_array, get_env_attr

# Public API of this module: the replay buffer class is the only export.
__all__ = (
    'ExperienceReplayBuffer',
)

class ExperienceReplayBuffer(RandomStateMixin, ActionSpaceMixin):
    """
    A simple numpy implementation of an experience replay buffer. This is
    written primarily with computer game environments (Atari) in mind.

    It implements a generic experience replay buffer for environments in which
    individual observations (frames) are stacked to represent the state.

    Internally the buffer is a ring of ``capacity + bootstrap_n`` slots. Only
    the most recent frame of each (stacked) observation is stored; the stacks
    are re-assembled when sampling.

    Parameters
    ----------
    env : gym environment

        The main gym environment. This is needed to infer the number of stacked
        frames ``num_frames`` as well as the number of actions ``num_actions``.

    capacity : positive int

        The capacity of the experience replay buffer. DQN typically uses
        ``capacity=1000000``.

    batch_size : positive int, optional

        The desired batch size of the sample.

    bootstrap_n : positive int

        The number of steps over which to delay bootstrapping, i.e. n-step
        bootstrapping.

    gamma : float between 0 and 1

        Reward discount factor.

    random_seed : int or None

        To get reproducible results.

    """
    def __init__(
            self, env, capacity,
            batch_size=32,
            bootstrap_n=1,
            gamma=0.99,
            random_seed=None):

        self.env = env
        self.capacity = int(capacity)
        self.batch_size = int(batch_size)
        self.num_frames = get_env_attr(env, 'num_frames', 1)
        self.bootstrap_n = int(bootstrap_n)
        self.gamma = float(gamma)
        self.random_seed = random_seed

        # internal
        self._initialized = False
        # The counters exist from the start so that len(buffer) and
        # bool(buffer) work before the first transition has been added; the
        # storage arrays themselves are allocated lazily in _init_cache().
        self._i = 0
        self._num_transitions = 0

    @classmethod
    def from_value_function(cls, value_function, capacity, batch_size=32):
        """
        Create a new instance by extracting some settings from a Q-function.

        The settings that are extracted from the value function are:
        ``gamma``, ``bootstrap_n`` and ``num_frames``. The latter is taken from
        the value function's ``env`` attribute.

        Parameters
        ----------
        value_function : value-function object

            A state value function or a state-action value function.

        capacity : positive int

            The capacity of the experience replay buffer. DQN typically uses
            ``capacity=1000000``.

        batch_size : positive int, optional

            The desired batch size of the sample.

        Returns
        -------
        experience_replay_buffer

            A new instance.

        """
        self = cls(
            env=value_function.env,
            capacity=capacity,
            batch_size=batch_size,
            gamma=value_function.gamma,
            bootstrap_n=value_function.bootstrap_n)
        return self

    def add(self, s, a, r, done, episode_id):
        """
        Add a transition to the experience replay buffer.

        The storage arrays are allocated on the first call, using the shape
        and dtype of the first frame that is passed in.

        Parameters
        ----------
        s : state

            A single state observation.

        a : action

            A single action.

        r : float

            The observed rewards associated with this transition.

        done : bool

            Whether the episode has finished.

        episode_id : int

            The episode in which the transition took place. This is needed for
            generating consistent samples.

        """
        s = self._extract_last_frame(s)

        if not self._initialized:
            self._s_dtype = s.dtype
            self._s_shape = s.shape
            if self.action_space_is_discrete:
                self._a_shape = (self.num_actions,)  # we do one-hot encoding
                self._a_dtype = 'float'
            else:
                self._a_shape = self.env.action_space.shape
                self._a_dtype = self.env.action_space.dtype
            self._init_cache()

        if self.action_space_is_discrete:
            a = self._one_hot_encode_discrete(a)

        self._s[self._i] = s
        self._a[self._i] = a
        self._r[self._i] = r
        self._d[self._i] = done
        self._e[self._i] = episode_id

        # advance the write head around the ring
        self._i = (self._i + 1) % (self.capacity + self.bootstrap_n)
        if self._num_transitions < self.capacity + self.bootstrap_n:
            self._num_transitions += 1

    def sample(self):
        """
        Get a batch of transitions to be used for bootstrapped updates.

        Candidate start indices are drawn at random and rejected when they
        cannot be turned into a consistent n-step transition; at most
        ``10 * batch_size`` candidates are tried.

        Returns
        -------
        S, A, Rn, In, S_next, A_next : tuple of arrays

            The returned tuple represents a batch of preprocessed transitions:

                (:term:`S`, :term:`A`, :term:`Rn`, :term:`In`, :term:`S_next`, :term:`A_next`)

            These are typically used for bootstrapped updates, e.g. minimizing
            the bootstrapped MSE:

            .. math::

                \\left(
                    R^{(n)}_t
                    + I^{(n)}_t\\,\\sum_aP(a|S_{t+n})\\,Q(S_{t+n},a)
                    - \\sum_aP(a|S_t)\\,Q(S_t,a) \\right)^2

        Raises
        ------
        InsufficientCacheError

            If fewer than ``batch_size`` transitions are available.

        RuntimeError

            If no valid batch could be constructed within the attempt budget.

        """  # noqa: E501
        if not self._initialized or len(self) < self.batch_size:
            raise InsufficientCacheError(
                "insufficient cached data to sample from")

        size = self.capacity + self.bootstrap_n  # number of ring slots

        S = []
        A = []
        Rn = []
        In = []
        S_next = []
        A_next = []

        for attempt in range(10 * self.batch_size):
            if not self._num_transitions:
                raise RuntimeError(
                    "please insert more transitions before sampling")
            j = self.random.randint(self._num_transitions)
            ep = self._e[j]  # current episode id

            # number of transitions that were recorded after slot j; slots
            # further ahead than this are unwritten or about to be overwritten
            newer = (self._i - 1 - j) % size

            js = j - np.arange(self.num_frames - 1, -1, -1)  # indices for s
            ks = js + self.bootstrap_n                  # indices for s_next
            ls = j + np.arange(self.bootstrap_n)   # indices for n-step rewards

            # wrap around
            js %= size
            ks %= size
            ls %= size

            # check if S indices are all from the same episode
            if np.any(self._e[js[:-1]] > ep):
                # Check if all js are from the current episode or from the
                # immediately preceding episodes. Otherwise, we would generate
                # spurious data because it would probably mean that 'js' spans
                # the overwrite-boundary.
                continue
            for i in range(1, self.num_frames):
                # if j is from a previous episode, replace it by its successor
                if self._e[js[~i]] < ep:
                    js[~i] = js[~i + 1]
                if self._e[ks[~i]] < ep:
                    ks[~i] = ks[~i + 1]

            # gather partial returns
            rn = np.zeros(1)
            done = False
            for t, idx in enumerate(ls):
                if t > newer:
                    # step j + t has not been recorded yet
                    break
                rn[0] += pow(self.gamma, t) * self._r[idx]
                done = self._d[idx]
                if done:
                    break

            if not done and newer < self.bootstrap_n:
                # The n-step window reaches past the write head: the reward
                # and next-state slots would hold uninitialised or stale data.
                continue

            if not done and np.any(self._e[ks] > ep):
                # this shouldn't happen (TODO: investigate)
                continue

            # permutation to transpose 'num_frames' axis to axis=-1
            perm = np.roll(np.arange(self._s.ndim), -1)
            S.append(self._s[js].transpose(perm))
            A.append(self._a[js[~0:]])
            Rn.append(rn)
            S_next.append(self._s[ks].transpose(perm))
            A_next.append(self._a[ks[~0:]])
            if done:
                # no bootstrapping past a terminal transition
                In.append(np.zeros(1))
            else:
                In.append(
                    np.power([self.gamma], self.bootstrap_n))

            if len(S) == self.batch_size:
                break

        if len(S) < self.batch_size:
            raise RuntimeError("couldn't construct valid sample")

        S = np.stack(S, axis=0)
        A = np.concatenate(A, axis=0)
        Rn = np.concatenate(Rn, axis=0)
        In = np.concatenate(In, axis=0)
        S_next = np.stack(S_next, axis=0)
        A_next = np.concatenate(A_next, axis=0)

        if self.num_frames == 1:
            # drop the trailing frame axis when frames are not stacked
            S = np.squeeze(S, axis=-1)
            S_next = np.squeeze(S_next, axis=-1)

        return S, A, Rn, In, S_next, A_next

    def clear(self):
        """
        Clear the experience replay buffer.

        Only the write head and the transition counter are reset; the storage
        arrays are kept and overwritten by subsequent calls to :func:`add`.

        """
        self._i = 0
        self._num_transitions = 0

    def __len__(self):
        # the newest 'bootstrap_n' transitions have no complete n-step window
        return max(0, self._num_transitions - self.bootstrap_n)

    def __bool__(self):
        return bool(len(self))

    def _init_cache(self):
        """ Allocate the storage arrays and reset the counters. """
        self._i = 0
        self._num_transitions = 0

        # construct appropriate shapes
        n = (self.capacity + self.bootstrap_n,)

        # create cache attrs
        self._s = np.empty(n + self._s_shape, self._s_dtype)  # frames
        self._a = np.zeros(n + self._a_shape, self._a_dtype)  # actions
        self._r = np.zeros(n, 'float')                        # rewards
        self._d = np.zeros(n, 'bool')                         # done?
        self._e = np.zeros(n, 'int32')                        # episode id
        self._initialized = True

    def _extract_last_frame(self, s):
        """
        Return the most recent frame of a stacked observation.

        When ``num_frames == 1`` the observation is returned unchanged.
        Otherwise the last axis must have size ``num_frames`` and the array
        must have 3 or 4 dimensions.

        Raises
        ------
        NumpyArrayCheckError

            If the stacked observation has neither 3 nor 4 dimensions.

        """
        if self.num_frames == 1:
            return s

        check_numpy_array(s, axis_size=self.num_frames, axis=-1)
        if s.ndim == 3:
            s = s[:, :, -1]
        elif s.ndim == 4:
            s = s[:, :, :, -1]
        else:
            raise NumpyArrayCheckError(
                "expected ndim equal to 3 or 4, got shape: {}".format(s.shape))
        return s