tf_agents.utils.common.shift_values

Shifts batch-major values in time by num_steps steps, discounting the shifted values by gamma.

tf_agents.utils.common.shift_values(
    values, gamma, num_steps, final_values=None
)

Args:

  • values: A Tensor of shape [batch_size, total_steps] and dtype float32.
  • gamma: A float discount value.
  • num_steps: A nonnegative integer amount to shift values by.
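  • final_values: A float32 Tensor of shape [batch_size] holding the values at the step immediately after the last step of values (index total_steps), as implied by the gamma^(total_steps - j) term in the return value. Defaults to None, which is treated as all zeros.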

Returns:

A Tensor of shape [batch_size, total_steps], where each entry (i, j) is gamma^num_steps * values[i, j + num_steps] if j + num_steps < total_steps; gamma^(total_steps - j) * final_values[i] otherwise.
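
The return formula above can be checked by hand with a small pure-Python sketch. This is not the library's implementation: `shift_values_reference` is a hypothetical helper that works on nested lists rather than Tensors and simply evaluates the documented per-entry rule, assuming zero-based indices i and j.

```python
def shift_values_reference(values, gamma, num_steps, final_values=None):
    """Evaluates the documented return formula entry by entry.

    Hypothetical illustration only: operates on nested Python lists,
    not Tensors, and is not the tf_agents implementation.
    """
    batch_size = len(values)
    total_steps = len(values[0]) if batch_size else 0
    if final_values is None:
        # Documented default: None is treated as all zeros.
        final_values = [0.0] * batch_size

    shifted = []
    for i in range(batch_size):
        row = []
        for j in range(total_steps):
            if j + num_steps < total_steps:
                # A value num_steps ahead exists: discount it by gamma^num_steps.
                row.append(gamma ** num_steps * values[i][j + num_steps])
            else:
                # Past the end of the sequence: fall back to the final value.
                row.append(gamma ** (total_steps - j) * final_values[i])
        shifted.append(row)
    return shifted


values = [[1.0, 2.0, 3.0, 4.0]]  # batch_size = 1, total_steps = 4

# Shift by two steps with a final value of 8.0 for the single batch entry.
result = shift_values_reference(values, gamma=0.5, num_steps=2, final_values=[8.0])
print(result)  # [[0.75, 1.0, 2.0, 4.0]]

# With final_values left as None, the padded entries are zero.
result_default = shift_values_reference(values, gamma=0.5, num_steps=2)
print(result_default)  # [[0.75, 1.0, 0.0, 0.0]]
```

In this example the first two entries are 0.5^2 times values[0][2] and values[0][3], while the last two entries come from final_values discounted by 0.5^2 and 0.5^1 respectively.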

Raises:

  • ValueError: If values is not of rank 2.