Source code for keras_gym.utils.misc

import os
import time
import logging

from PIL import Image


__all__ = (
    'enable_logging',
    'generate_gif',
    'get_env_attr',
    'get_transition',
    'has_env_attr',
    'is_policy',
    'is_qfunction',
    'is_vfunction',
    'render_episode',
    'set_tf_loglevel',
)


def enable_logging(level=logging.INFO, level_tf=logging.ERROR):
    """
    Enable logging output.

    With the default arguments, this executes the following two lines of
    code:

    .. code:: python

        import logging
        logging.basicConfig(level=logging.INFO)
        set_tf_loglevel(logging.ERROR)

    Note that :func:`set_tf_loglevel` is another keras-gym utility function.

    Parameters
    ----------
    level : int, optional

        Log level for native python logging. For instance, if you'd like to
        see more verbose logging messages you might set
        ``level=logging.DEBUG``.

    level_tf : int, optional

        Log level for tensorflow-specific logging (logs coming from the C++
        layer).

    """
    logging.basicConfig(level=level)
    set_tf_loglevel(level_tf)
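
# Usage sketch (an assumption: these helpers are importable from
# ``keras_gym.utils``, mirroring how keras-gym re-exports this module):
#
#     import logging
#     from keras_gym.utils import enable_logging
#
#     # verbose python logs, but only error-level tensorflow logs
#     enable_logging(level=logging.DEBUG, level_tf=logging.ERROR)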

def get_transition(env):
    """
    Generate a transition from the environment.

    This basically does a single step on the environment and then closes it.

    Parameters
    ----------
    env : gym environment

        A gym environment.

    Returns
    -------
    s, a, r, s_next, a_next, done, info : tuple

        A single transition. Note that the order and the number of items
        returned differ from what ``env.step()`` returns.

    """
    try:
        s = env.reset()
        a = env.action_space.sample()
        a_next = env.action_space.sample()
        s_next, r, done, info = env.step(a)
        return s, a, r, s_next, a_next, done, info
    finally:
        env.close()
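
# Usage sketch (assumes the ``gym`` package and its 'CartPole-v0'
# environment are available):
#
#     import gym
#     env = gym.make('CartPole-v0')
#     s, a, r, s_next, a_next, done, info = get_transition(env)
#     # note: a_next is sampled independently of the transition itself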

def render_episode(env, policy, step_delay_ms=0):
    """
    Run a single episode with env.render() calls at each time step.

    Parameters
    ----------
    env : gym environment

        A gym environment.

    policy : callable

        A policy object that is used to pick actions: ``a = policy(s)``.

    step_delay_ms : non-negative float

        The number of milliseconds to wait between consecutive time steps.
        This can be used to slow down the rendering.

    """
    s = env.reset()
    env.render()

    for t in range(int(1e9)):
        a = policy(s)
        s_next, r, done, info = env.step(a)

        env.render()
        time.sleep(step_delay_ms / 1e3)

        if done:
            break

        s = s_next

    time.sleep(5 * step_delay_ms / 1e3)
    env.close()
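
# Usage sketch (uses a random policy; any callable ``s -> a`` works):
#
#     import gym
#     env = gym.make('CartPole-v0')
#     render_episode(env, lambda s: env.action_space.sample(),
#                    step_delay_ms=30)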

def has_env_attr(env, attr, max_depth=100):
    """
    Check if a potentially wrapped environment has a given attribute.

    Parameters
    ----------
    env : gym environment

        A potentially wrapped environment.

    attr : str

        The attribute name.

    max_depth : positive int, optional

        The maximum depth of wrappers to traverse.

    """
    e = env
    for i in range(max_depth):
        if hasattr(e, attr):
            return True
        if not hasattr(e, 'env'):
            break
        e = e.env

    return False

def get_env_attr(env, attr, default='__ERROR__', max_depth=100):
    """
    Get the given attribute from a potentially wrapped environment.

    Note that the wrapped envs are traversed from the outside in. Once the
    attribute is found, the search stops. This means that an inner wrapped
    env may carry the same (possibly conflicting) attribute. This situation
    is *not* resolved by this function.

    Parameters
    ----------
    env : gym environment

        A potentially wrapped environment.

    attr : str

        The attribute name.

    default : object, optional

        The value to return if the attribute is not found. If left
        unspecified, an :class:`AttributeError` is raised instead.

    max_depth : positive int, optional

        The maximum depth of wrappers to traverse.

    """
    e = env
    for i in range(max_depth):
        if hasattr(e, attr):
            return getattr(e, attr)
        if not hasattr(e, 'env'):
            break
        e = e.env

    if default == '__ERROR__':
        raise AttributeError("env is missing attribute: {}".format(attr))

    return default
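
# Usage sketch for the two attribute helpers (assumes gym's default
# TimeLimit wrapper, which defines ``_max_episode_steps``):
#
#     import gym
#     env = gym.make('CartPole-v0')            # TimeLimit-wrapped env
#     has_env_attr(env, '_max_episode_steps')  # True
#     get_env_attr(env, '_max_episode_steps')  # e.g. 200
#     get_env_attr(env, 'no_such_attr', default=None)  # None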

def generate_gif(env, policy, filepath, resize_to=None, duration=50):
    """
    Store a gif from the episode frames.

    Parameters
    ----------
    env : gym environment

        The environment to record from.

    policy : keras-gym policy object

        The policy that is used to take actions.

    filepath : str

        Location of the output gif file.

    resize_to : tuple of ints, optional

        The size of the output frames, ``(width, height)``. Notice the
        ordering: first **width**, then **height**. This is the convention
        that PIL uses.

    duration : float, optional

        Time between frames in the animated gif, in milliseconds.

    """
    logger = logging.getLogger('generate_gif')

    # collect frames
    frames = []
    s = env.reset()
    for t in range(env.spec.max_episode_steps or 10000):
        a = policy(s)
        s_next, r, done, info = env.step(a)

        # store frame
        frame = env.render(mode='rgb_array')
        frame = Image.fromarray(frame)
        frame = frame.convert('P', palette=Image.ADAPTIVE)
        if resize_to is not None:
            if not (isinstance(resize_to, tuple) and len(resize_to) == 2):
                raise TypeError("expected a tuple of size 2, resize_to=(w, h)")
            frame = frame.resize(resize_to)

        frames.append(frame)

        if done:
            break

        s = s_next

    # store last frame
    frame = env.render(mode='rgb_array')
    frame = Image.fromarray(frame)
    frame = frame.convert('P', palette=Image.ADAPTIVE)
    if resize_to is not None:
        frame = frame.resize(resize_to)
    frames.append(frame)

    # generate gif
    # note: fall back to '.' when filepath has no directory component, so
    # that os.makedirs doesn't choke on an empty string
    os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
    frames[0].save(
        fp=filepath, format='GIF', append_images=frames[1:],
        save_all=True, duration=duration, loop=0)

    logger.info("recorded episode to: {}".format(filepath))
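
# Usage sketch (assumes an env whose render() supports mode='rgb_array';
# the output path is hypothetical):
#
#     import gym
#     env = gym.make('CartPole-v0')
#     generate_gif(
#         env, policy=lambda s: env.action_space.sample(),
#         filepath='./data/cartpole.gif', resize_to=(320, 240))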

def is_vfunction(obj):
    """
    Check whether an object is a :term:`state value function`, or V-function.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is a V-function.

    """
    # import at runtime to avoid circular dependence
    from ..core.value_v import V
    return isinstance(obj, V)

def is_qfunction(obj, qtype=None):
    """
    Check whether an object is a :term:`state-action value function
    <type-I state-action value function>`, or Q-function.

    Parameters
    ----------
    obj

        Object to check.

    qtype : 1 or 2, optional

        Check for a specific Q-function type, i.e. :term:`type-I <type-I
        state-action value function>` or :term:`type-II <type-II
        state-action value function>`.

    Returns
    -------
    bool

        Whether ``obj`` is a (type-I/II) Q-function.

    """
    # import at runtime to avoid circular dependence
    from ..core.value_q import QTypeI, QTypeII

    if qtype is None:
        return isinstance(obj, (QTypeI, QTypeII))
    elif qtype in (1, 1., '1', 'i', 'I'):
        return isinstance(obj, QTypeI)
    elif qtype in (2, 2., '2', 'ii', 'II'):
        return isinstance(obj, QTypeII)
    else:
        raise ValueError("unexpected qtype: {}".format(qtype))

def is_policy(obj, check_updateable=False):
    """
    Check whether an object is an :term:`(updateable) policy
    <updateable policy>`.

    Parameters
    ----------
    obj

        Object to check.

    check_updateable : bool, optional

        If the obj is a policy, also check whether or not the policy is
        updateable.

    Returns
    -------
    bool

        Whether ``obj`` is an (updateable) policy.

    """
    # import at runtime to avoid circular dependence
    from ..policies.base import BasePolicy
    from ..core.base import BaseUpdateablePolicy

    if check_updateable:
        return isinstance(obj, BaseUpdateablePolicy)
    return isinstance(obj, BasePolicy)
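
# Usage sketch for the three type-check helpers (hypothetical: ``v``, ``q``
# and ``pi`` stand in for keras-gym function approximators created
# elsewhere):
#
#     is_vfunction(v)                        # True for a V instance
#     is_qfunction(q)                        # True for type-I or type-II
#     is_qfunction(q, qtype=2)               # True only for type-II
#     is_policy(pi, check_updateable=True)   # also checks updateability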

def set_tf_loglevel(level):
    """
    Set the logging level for the Tensorflow logger. This also sets the
    logging level of the underlying C++ layer.

    Parameters
    ----------
    level : int

        A logging level as provided by the builtin :mod:`logging` module,
        e.g. ``level=logging.INFO``.

    """
    # note: these branches must be mutually exclusive (elif); otherwise a
    # FATAL level would fall through and end up setting the C++ level to '1'
    if level >= logging.FATAL:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    elif level >= logging.ERROR:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    elif level >= logging.WARNING:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    logging.getLogger('tensorflow').setLevel(level)
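
# Usage sketch (suppresses tensorflow INFO/WARNING output on both the
# python and the C++ side):
#
#     import logging
#     set_tf_loglevel(logging.ERROR)   # sets TF_CPP_MIN_LOG_LEVEL='2'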