tf_agents.drivers.dynamic_episode_driver.DynamicEpisodeDriver

A driver that takes N episodes in an environment using a tf.while_loop.

Inherits From: Driver

The while loop will run num_episodes episodes in the environment, counting transitions that result in ending an episode.

As environments run batched time_steps, the counters for all batch elements are summed, and execution stops when the total exceeds num_episodes.

This termination condition can be overridden in subclasses by implementing the self._loop_condition_fn() method.
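For illustration, the following is a minimal sketch of constructing and running the driver. It assumes the Gym CartPole task loaded via suite_gym, a RandomTFPolicy, and a TFUniformReplayBuffer whose add_batch method serves as an observer; these particular choices are examples, not requirements, and any TFEnvironment/TFPolicy pair with matching specs would work.

```python
# A minimal usage sketch; CartPole, the random policy and the replay buffer
# are illustrative choices, not part of the DynamicEpisodeDriver API.
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.metrics import tf_metrics
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
policy = random_tf_policy.RandomTFPolicy(
    env.time_step_spec(), env.action_spec())

# Each observer is called with a Trajectory after every step; a replay
# buffer's add_batch method and the built-in step/episode metrics are
# common observers.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=policy.trajectory_spec,
    batch_size=env.batch_size,
    max_length=1000)
num_episodes_metric = tf_metrics.NumberOfEpisodes()
env_steps_metric = tf_metrics.EnvironmentSteps()

driver = dynamic_episode_driver.DynamicEpisodeDriver(
    env,
    policy,
    observers=[replay_buffer.add_batch, num_episodes_metric, env_steps_metric],
    num_episodes=2)

# Runs until 2 episodes (summed across the batch) have completed.
final_time_step, final_policy_state = driver.run()
```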

Args
env A tf_environment.Base environment.
policy A tf_policy.TFPolicy policy.
observers A list of observers that are updated after every step in the environment. Each observer is a callable(Trajectory).
transition_observers A list of observers that are updated after every step in the environment. Each observer is a callable((TimeStep, PolicyStep, NextTimeStep)).
num_episodes The number of episodes to take in the environment. For batched or parallel environments, this is the total number of episodes summed across all environments.

Raises
ValueError If env is not a tf_environment.Base or policy is not an instance of tf_policy.TFPolicy.

Attributes

env

info_observers

observers

policy

transition_observers

Methods

run


Takes episodes in the environment using the policy and updates observers.

If time_step and policy_state are not provided, run will reset the environment and request an initial state from the policy.

Args
time_step Optional initial time_step. If None, it is obtained by resetting the environment. Elements should have shape [batch_size, ...].
policy_state Optional initial state for the policy. If None, it is obtained from policy.get_initial_state().
num_episodes Optional number of episodes to take in the environment. If None, the num_episodes given at construction is used.
maximum_iterations Optional maximum number of iterations of the while loop to run. If provided, the loop condition is AND-ed with an additional condition ensuring that the number of iterations executed does not exceed maximum_iterations.

Returns
time_step TimeStep named tuple with final observation, reward, etc.
policy_state Tensor with final step policy state.
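As a sketch of the calling convention described above, a later run call can continue from the state returned by an earlier call instead of resetting the environment. The driver variable below is assumed to be the one constructed in the earlier example.

```python
# First call: resets the environment and fetches the policy's initial state.
time_step, policy_state = driver.run()

# Later calls may continue from the returned state rather than resetting,
# and may override the episode count for just this call.
time_step, policy_state = driver.run(
    time_step=time_step,
    policy_state=policy_state,
    num_episodes=1)
```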