tf_agents.drivers.py_driver.PyDriver

A driver that runs a Python policy in a Python environment.

Inherits From: Driver

Args
env A py_environment.Base environment.
policy A py_policy.Base policy.
observers A list of observers that are notified after every step in the environment. Each observer is a callable(trajectory.Trajectory).
transition_observers A list of observers that are notified after every step in the environment. Each observer is a callable((TimeStep, PolicyStep, NextTimeStep)). The transition is shaped just as trajectories are for regular observers.
max_steps Optional maximum number of steps for each run() call. See max_episodes for how the two limits combine. Default: 0.
max_episodes Optional maximum number of episodes for each run() call. At least one of max_steps or max_episodes must be provided. If both are set, run() terminates as soon as either limit is reached. Default: 0.

Raises
ValueError If neither max_steps nor max_episodes is provided.
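
The interaction between the two limits can be illustrated with a small, library-independent sketch. It is not the PyDriver source: the helper names resolve_limits and keep_running are hypothetical, and the sketch assumes that a limit left unset (None or 0) means "no limit on that quantity".

```python
import math
from typing import Optional, Tuple


def resolve_limits(max_steps: Optional[int] = None,
                   max_episodes: Optional[int] = None) -> Tuple[float, float]:
    """Validate the limits; an unset limit is treated as infinity (assumption)."""
    steps = max_steps or 0
    episodes = max_episodes or 0
    if steps < 1 and episodes < 1:
        # Mirrors the documented requirement that at least one limit is given.
        raise ValueError("Provide max_steps or max_episodes greater than 0.")
    step_limit = steps if steps > 0 else math.inf
    episode_limit = episodes if episodes > 0 else math.inf
    return step_limit, episode_limit


def keep_running(num_steps: int, num_episodes: int,
                 limits: Tuple[float, float]) -> bool:
    """Return True while neither limit has been reached."""
    step_limit, episode_limit = limits
    return num_steps < step_limit and num_episodes < episode_limit
```

With both limits set, a loop guarded by keep_running stops at whichever limit is hit first, which matches the documented behavior of run() when both arguments are supplied.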

Attributes
env
observers
policy
transition_observers

Methods

run

Runs the policy in the environment, starting from the given time_step and policy_state.

Args
time_step The initial time_step.
policy_state The initial policy_state.

Returns
A tuple (final time_step, final policy_state).
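
Because run() returns the final time_step and policy_state, the pair can be passed to a later run() call to continue from where the previous call stopped. The sketch below illustrates that contract with toy stand-ins only. Step, CountingEnv, toy_policy, and toy_run are hypothetical names, not TF-Agents classes, and the observer here receives a plain (time_step, action, next_time_step) tuple, loosely modeled on what transition_observers are documented to receive.

```python
from typing import Callable, List, NamedTuple, Sequence, Tuple


class Step(NamedTuple):
    """Hypothetical stand-in for a TimeStep."""
    observation: int
    is_last: bool


class CountingEnv:
    """Toy environment whose episodes end after `episode_length` steps."""

    def __init__(self, episode_length: int = 3):
        self._episode_length = episode_length
        self._t = 0

    def reset(self) -> Step:
        self._t = 0
        return Step(observation=0, is_last=False)

    def step(self, action: int) -> Step:
        # Stepping after the episode has ended starts a new episode.
        if self._t >= self._episode_length:
            return self.reset()
        self._t += 1
        return Step(self._t, self._t >= self._episode_length)


def toy_policy(time_step: Step, policy_state: int) -> Tuple[int, int]:
    """Return (action, next policy_state); the state counts actions taken."""
    return 1, policy_state + 1


def toy_run(env: CountingEnv,
            policy: Callable[[Step, int], Tuple[int, int]],
            observers: Sequence[Callable[[tuple], None]],
            time_step: Step,
            policy_state: int,
            max_steps: int) -> Tuple[Step, int]:
    """Take `max_steps` steps, notifying every observer after each one."""
    for _ in range(max_steps):
        action, policy_state = policy(time_step, policy_state)
        next_step = env.step(action)
        for observer in observers:
            observer((time_step, action, next_step))
        time_step = next_step
    return time_step, policy_state


env = CountingEnv(episode_length=3)
log: List[tuple] = []
# First call: two steps from a fresh episode.
time_step, policy_state = toy_run(env, toy_policy, [log.append],
                                  env.reset(), 0, max_steps=2)
# Second call: resume from the returned pair instead of resetting.
time_step, policy_state = toy_run(env, toy_policy, [log.append],
                                  time_step, policy_state, max_steps=1)
```

After the second call the toy episode has reached its final step, and the observer log holds one entry per environment step across both calls.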