
tf_agents.environments.trajectory_replay.TrajectoryReplay


A helper that replays a policy against given Trajectory observations.

tf_agents.environments.trajectory_replay.TrajectoryReplay(
    policy, time_major=False
)

Args:

  • policy: A tf_policy.Base policy.
  • time_major: If True, the tensors in the trajectory passed to the run method are assumed to have shape [time, batch, ...]. Otherwise (the default), they are assumed to have shape [batch, time, ...].
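The two layouts differ only in the order of the two outer axes. As a minimal NumPy sketch (not tf_agents code), assuming a trajectory-like nest represented as a plain dict with hypothetical keys observation and reward, batch-major data can be converted to the time-major layout expected with time_major=True by swapping axes 0 and 1 of every leaf:

```python
import numpy as np


def to_time_major(nest):
    """Swap the two outer axes of every array in a flat dict.

    Assumption: `nest` is a plain dict of NumPy arrays shaped
    [batch, time, ...]. A real Trajectory nest would be mapped with
    tf.nest.map_structure instead of a dict comprehension.
    """
    return {key: np.swapaxes(value, 0, 1) for key, value in nest.items()}


# Hypothetical batch-major data: 4 episodes x 10 time steps.
batch_major = {
    "observation": np.zeros((4, 10, 3), dtype=np.float32),  # 3 features per step
    "reward": np.zeros((4, 10), dtype=np.float32),
}
time_major = to_time_major(batch_major)
print(time_major["observation"].shape)  # (10, 4, 3) -> [time, batch, ...]
print(time_major["reward"].shape)       # (10, 4)    -> [time, batch]
```

Because only the two outer axes move, trailing feature dimensions are untouched, which is why leaves with different ranks can be converted with the same rule.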

Raises:

  • ValueError: If policy is not an instance of tf_policy.Base.

Methods

run


run(
    trajectory, policy_state=None
)

Applies the policy to each step of trajectory and returns the resulting actions and policy info.

If self.time_major == True, the tensors in trajectory are assumed to have shape [time, batch, ...]. Otherwise they are assumed to have shape [batch, time, ...].

Args:

  • trajectory: The Trajectory to run against. If the replay class was created with time_major=True, then the tensors in trajectory must be shaped [time, batch, ...]. Otherwise they must be shaped [batch, time, ...].
  • policy_state: (optional) A nest of Tensors holding the initial policy state.

Returns:

  • output_actions: A nest of the actions that the policy took. If the replay class was created with time_major=True, then the tensors here will be shaped [time, batch, ...]. Otherwise they'll be shaped [batch, time, ...].
  • output_policy_info: A nest of the policy info that the policy emitted. If the replay class was created with time_major=True, then the tensors here will be shaped [time, batch, ...]. Otherwise they'll be shaped [batch, time, ...].
  • policy_state: A nest of Tensors holding the policy state after the final step.
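The replay described above can be sketched in plain NumPy. This is not the tf_agents implementation, which operates on TensorFlow tensors and nests; toy_policy and replay are hypothetical names, the policy is assumed to be a plain function returning (action, info, new_state), and unlike run, the sketch requires an explicit initial policy_state instead of treating it as optional.

```python
import numpy as np


def toy_policy(observation, state):
    """Hypothetical stateful policy standing in for a tf_policy.Base.

    action: index of the largest observation feature per batch entry.
    info:   the largest feature value (plays the role of policy info).
    state:  a per-batch step counter, incremented on every call.
    """
    action = np.argmax(observation, axis=-1)
    info = np.max(observation, axis=-1)
    return action, info, state + 1


def replay(policy_fn, observations, policy_state, time_major=False):
    """Apply policy_fn at every time step and stack the outputs.

    observations: [batch, time, ...], or [time, batch, ...] if time_major.
    Returns (actions, infos, final_state); actions and infos use the same
    outer layout as the input, mirroring the Returns section above.
    """
    if not time_major:
        observations = np.swapaxes(observations, 0, 1)  # -> [time, batch, ...]
    actions, infos = [], []
    state = policy_state
    for obs_t in observations:  # iterate over the time axis
        action, info, state = policy_fn(obs_t, state)  # thread the state forward
        actions.append(action)
        infos.append(info)
    actions = np.stack(actions)  # [time, batch]
    infos = np.stack(infos)      # [time, batch]
    if not time_major:
        actions = np.swapaxes(actions, 0, 1)  # back to [batch, time]
        infos = np.swapaxes(infos, 0, 1)
    return actions, infos, state


obs = np.random.default_rng(0).random((2, 5, 3))  # [batch=2, time=5, features=3]
actions, infos, final_state = replay(toy_policy, obs, np.zeros(2))
print(actions.shape, infos.shape, final_state)  # (2, 5) (2, 5) [5. 5.]
```

Iterating over a time-major view keeps the state-threading loop identical for both layouts; only the conversion at the boundaries depends on time_major.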

Raises:

  • TypeError: If policy_state structure doesn't match self.policy.policy_state_spec, or trajectory structure doesn't match self.policy.trajectory_spec.
  • ValueError: If policy_state doesn't match self.policy.policy_state_spec, or trajectory structure doesn't match self.policy.trajectory_spec.
  • ValueError: If trajectory lacks two outer dims.
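The last ValueError concerns the two outer dimensions ([batch, time] or [time, batch]). A hypothetical helper, check_two_outer_dims, sketches that check for a trajectory represented as a plain dict of NumPy arrays; it is an illustration under that assumption, not the library's own validation code.

```python
import numpy as np


def check_two_outer_dims(nest):
    """Raise ValueError if any array in a flat dict has fewer than 2 dims.

    Hypothetical helper: every leaf of a trajectory must carry both
    outer dimensions, e.g. [batch, time, ...] or [time, batch, ...].
    """
    for key, value in nest.items():
        if np.ndim(value) < 2:
            raise ValueError(
                f"{key!r} has shape {np.shape(value)}; "
                "expected at least two outer dims"
            )


check_two_outer_dims({"observation": np.zeros((4, 10, 3))})  # passes silently
try:
    check_two_outer_dims({"reward": np.zeros(10)})  # only a time axis
except ValueError as err:
    print(err)
```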