tf_agents.utils.common.get_episode_mask

Creates a mask that is 0.0 for every final step (step_type == LAST) and 1.0 for all other steps.

tf_agents.utils.common.get_episode_mask(
    time_steps
)

Args:

  • time_steps: A TimeStep namedtuple representing a batch of steps.

Returns:

A float32 Tensor with 0s where step_type == LAST and 1s otherwise.
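The masking rule can be sketched in plain Python. This sketch assumes the tf_agents convention that StepType.FIRST, MID and LAST are 0, 1 and 2. The TimeStep namedtuple and episode_mask_sketch below are hypothetical stand-ins, not the library's code. The real function returns a float32 Tensor rather than nested lists. In TensorFlow, the same comparison could be written as tf.cast(tf.not_equal(step_type, LAST), tf.float32).

```python
from collections import namedtuple

# Assumed stand-ins for tf_agents StepType values (FIRST=0, MID=1, LAST=2).
FIRST, MID, LAST = 0, 1, 2

# Hypothetical minimal TimeStep; only step_type is used to build the mask.
TimeStep = namedtuple("TimeStep", ["step_type", "reward", "discount", "observation"])


def episode_mask_sketch(time_steps):
    """Return 0.0 where step_type == LAST and 1.0 elsewhere, keeping the nesting."""
    def to_mask(step_type):
        # Recurse into nested lists so [batch] and [batch, time] layouts both work.
        if isinstance(step_type, (list, tuple)):
            return [to_mask(s) for s in step_type]
        return 0.0 if step_type == LAST else 1.0
    return to_mask(time_steps.step_type)


# A batch of 2 trajectories with 3 steps each, laid out as [batch, time].
steps = TimeStep(
    step_type=[[FIRST, MID, LAST], [MID, LAST, FIRST]],
    reward=[[0.0, 1.0, 1.0], [0.5, 0.5, 0.0]],
    discount=[[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]],
    observation=[[0, 1, 2], [3, 4, 5]],
)
print(episode_mask_sketch(steps))  # [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
```

Only the LAST entries become 0.0. FIRST and MID steps both map to 1.0, so a step that begins a new episode is not masked.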