tf_agents.drivers.py_driver.PyDriver

A driver that runs a Python policy in a Python environment.

Inherits From: Driver

tf_agents.drivers.py_driver.PyDriver(
    env, policy, observers, transition_observers=None, max_steps=None,
    max_episodes=None
)

Args:

  • env: A py_environment.Base environment.
  • policy: A py_policy.Base policy.
  • observers: A list of observers that are notified after every step in the environment. Each observer is a callable(trajectory.Trajectory).
  • transition_observers: A list of observers that are updated after every step in the environment. Each observer is a callable((TimeStep, PolicyStep, NextTimeStep)). The transition is shaped just as trajectories are for regular observers.
  • max_steps: Optional maximum number of steps for each run() call. See max_episodes for how the two limits interact. Default: None.
  • max_episodes: Optional maximum number of episodes for each run() call. At least one of max_steps or max_episodes must be provided. If both are set, run() terminates as soon as either limit is reached. Default: None.

Attributes:

  • env
  • observers
  • policy
  • transition_observers

Raises:

  • ValueError: If both max_steps and max_episodes are None.
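
The stopping rule and the ValueError described above can be illustrated with a small, library-free sketch. The function `run_toy_driver`, its `episode_length` parameter, and the tuple passed to observers are hypothetical stand-ins, not part of tf_agents; the sketch only mimics the documented behaviour that a run ends as soon as either limit is reached and that omitting both limits is an error.

```python
import math


def run_toy_driver(episode_length, observers, max_steps=None, max_episodes=None):
    """Hypothetical loop that mimics only the documented stopping rule."""
    if max_steps is None and max_episodes is None:
        raise ValueError("At least one of max_steps or max_episodes is required.")
    # A missing limit is treated as unbounded (an assumption of this sketch).
    step_limit = math.inf if max_steps is None else max_steps
    episode_limit = math.inf if max_episodes is None else max_episodes

    num_steps = 0
    num_episodes = 0
    while num_steps < step_limit and num_episodes < episode_limit:
        num_steps += 1
        # The toy environment ends an episode every `episode_length` steps.
        is_last = num_steps % episode_length == 0
        for observer in observers:
            observer((num_steps, is_last))  # stand-in for a Trajectory
        if is_last:
            num_episodes += 1
    return num_steps, num_episodes


seen = []
# Both limits set: the episode limit is hit first (2 episodes of 3 steps).
both_limits = run_toy_driver(3, [seen.append], max_steps=10, max_episodes=2)
# Here the step limit is hit first, partway through the second episode.
step_bound = run_toy_driver(3, [], max_steps=4, max_episodes=5)
```

Each observer is called once per step, so `seen` ends up with one entry for every step taken in the first call.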

Methods

run

run(
    time_step, policy_state=()
)

Runs the policy in the environment, starting from the given time_step and policy_state.

Args:

  • time_step: The initial time_step.
  • policy_state: The initial policy_state.

Returns:

A tuple (final time_step, final policy_state).
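
Because run() returns the final time_step and policy_state, one plausible pattern is to feed that tuple into the next run() call so collection resumes where it stopped. The sketch below shows the pattern with a hypothetical `CountingDriver` whose "environment" is an integer counter and whose "policy state" is a running total; it copies only the shape of the run() signature and is not the tf_agents implementation. With the real library, the initial time_step would typically come from the environment's reset().

```python
class CountingDriver:
    """Hypothetical driver with the same run() signature shape as documented."""

    def __init__(self, max_steps):
        self._max_steps = max_steps

    def run(self, time_step, policy_state=()):
        # An empty policy_state means "start fresh", mirroring the default ().
        total = policy_state[0] if policy_state else 0
        for _ in range(self._max_steps):
            total += time_step  # the toy policy accumulates what it observes
            time_step += 1      # the toy environment advances by one
        return time_step, (total,)


driver = CountingDriver(max_steps=3)

# First call starts from an initial time_step and the default policy_state.
time_step, policy_state = driver.run(0)
first_result = (time_step, policy_state)

# Second call resumes from the values returned by the first.
time_step, policy_state = driver.run(time_step, policy_state)
```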