tf_agents.agents.categorical_dqn.categorical_dqn_agent.project_distribution

Projects a batch of (support, weights) onto target_support.

tf_agents.agents.categorical_dqn.categorical_dqn_agent.project_distribution(
    supports, weights, target_support, validate_args=False
)

Based on equation (7) in Bellemare et al. (2017), https://arxiv.org/abs/1707.06887. The comments in the source code refer to this equation simply as Eq7.
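For reference, Eq7 can be restated in terms of this function's arguments. The notation is our own: s_j and w_j are the j-th entries of one row of supports and weights, z_i is the i-th entry of target_support, N is num_dims, and [x]_a^b means x clipped to the interval [a, b]. In the paper, s_j corresponds to the Bellman-updated atom \hat{T} z_j = r + \gamma z_j and w_j to p_j(x', \pi(x')).

```latex
(\Phi \hat{\mathcal{T}} Z)_i
  = \sum_{j=0}^{N-1}
    \left[ 1 - \frac{\bigl| [s_j]_{V_{\min}}^{V_{\max}} - z_i \bigr|}{\Delta z} \right]_0^1 w_j,
\qquad
V_{\min} = z_0, \quad V_{\max} = z_{N-1}, \quad \Delta z = z_1 - z_0
```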

The implementation is not easy to digest, so we follow a running example with these sample inputs:

  • supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
  • weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
  • target_support = [4, 5, 6, 7, 8]

In the source code, comments prefixed with 'Ex:' refer to these values.
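To make the running example concrete, here is a minimal plain-Python sketch of Eq7 applied to the sample inputs. The function name project_distribution_reference is hypothetical; it is a readable reference loop, not the vectorized TensorFlow code in TF-Agents, and it assumes target_support is equally spaced with at least two atoms.

```python
def project_distribution_reference(supports, weights, target_support):
    """Hypothetical plain-Python version of Eq7 (not the TF-Agents code)."""
    v_min, v_max = target_support[0], target_support[-1]
    delta_z = target_support[1] - target_support[0]  # assumes equal spacing
    projected = []
    for row_support, row_weights in zip(supports, weights):
        # Clip every source atom into [v_min, v_max].
        clipped = [min(max(s, v_min), v_max) for s in row_support]
        row = []
        for z in target_support:
            total = 0.0
            for s, w in zip(clipped, row_weights):
                # Triangular kernel: 1 at z, falling linearly to 0 one spacing away.
                total += min(max(1 - abs(s - z) / delta_z, 0.0), 1.0) * w
            row.append(total)
        projected.append(row)
    return projected


supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
target_support = [4, 5, 6, 7, 8]

result = project_distribution_reference(supports, weights, target_support)
print([[round(p, 6) for p in row] for row in result])
# prints [[0.8, 0.0, 0.1, 0.0, 0.1], [0.8, 0.1, 0.1, 0.0, 0.0]]
```

In the first row, the atoms 0, 2 and 4 all clip to Vmin = 4, so their weights (0.1 + 0.6 + 0.1) accumulate on the first target atom.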

Args:

  • supports: Tensor of shape (batch_size, num_dims) defining supports for the distribution.
  • weights: Tensor of shape (batch_size, num_dims) defining weights on the original support points. For the CategoricalDQN agent these weights are probabilities, but this is not required.
  • target_support: Tensor of shape (num_dims) defining the support of the projected distribution. The values must be monotonically increasing and equally spaced. Vmin and Vmax are inferred from the first and last elements of this tensor, respectively.
  • validate_args: Whether to verify the contents of the target_support parameter.
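The constraints on target_support can be illustrated with a small check that derives Vmin, Vmax and the spacing Δz. This is a hypothetical helper written against the documented requirements, not the checks TF-Agents performs when validate_args=True; the tolerance and the two-atom minimum are assumptions of the sketch.

```python
def infer_target_support_params(target_support, tol=1e-6):
    """Hypothetical helper: return (v_min, v_max, delta_z) or raise ValueError."""
    if len(target_support) < 2:
        raise ValueError("target_support needs at least two atoms")
    deltas = [b - a for a, b in zip(target_support, target_support[1:])]
    delta_z = deltas[0]
    # Monotonically increasing: every spacing must be positive.
    if any(d <= 0 for d in deltas):
        raise ValueError("target_support must be monotonically increasing")
    # Equally spaced: every spacing must match the first one (within tol).
    if any(abs(d - delta_z) > tol for d in deltas):
        raise ValueError("target_support must be equally spaced")
    # Vmin and Vmax come from the first and last elements.
    return target_support[0], target_support[-1], delta_z


print(infer_target_support_params([4, 5, 6, 7, 8]))  # (4, 8, 1)
```

A support such as [4, 5, 7] would be rejected, because its spacings (1 and 2) differ.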

Returns:

A Tensor of shape (batch_size, num_dims) with the projection of a batch of (support, weights) onto target_support.
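When a source atom falls between two target atoms, Eq7 splits its weight linearly between those two neighbors, and atoms outside [Vmin, Vmax] are first clipped onto the end atoms. As a result, each output row sums to the same total as the corresponding row of weights. The single-row function project_row and its inputs below are hypothetical, written in plain Python to illustrate this behavior.

```python
def project_row(support, weights, target_support):
    """Hypothetical single-row Eq7 sketch (not the TF-Agents code)."""
    v_min, v_max = target_support[0], target_support[-1]
    delta_z = target_support[1] - target_support[0]  # assumes equal spacing
    out = []
    for z in target_support:
        out.append(sum(
            # Clip the source atom, then apply the triangular kernel around z.
            min(max(1 - abs(min(max(s, v_min), v_max) - z) / delta_z, 0.0), 1.0) * w
            for s, w in zip(support, weights)))
    return out


# 3.0 clips to 4.0, 9.0 clips to 8.0; 5.25 and 6.5 lie between target atoms.
projected = project_row([3.0, 5.25, 6.5, 7.0, 9.0],
                        [0.2, 0.4, 0.2, 0.1, 0.1],
                        [4.0, 5.0, 6.0, 7.0, 8.0])
print([round(p, 6) for p in projected])  # [0.2, 0.3, 0.2, 0.2, 0.1]
print(round(sum(projected), 6))          # 1.0
```

Here the atom at 5.25 sends 0.75 of its weight (0.3) to 5 and 0.25 (0.1) to 6.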

Raises:

  • ValueError: If target_support has no dimensions, or if shapes of supports, weights, and target_support are incompatible.
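The documented ValueError conditions can be sketched as a shape check on plain tuples. The helper check_projection_shapes is hypothetical; it mirrors the conditions listed above, not the exact assertions made inside TF-Agents.

```python
def check_projection_shapes(supports_shape, weights_shape, target_support_shape):
    """Hypothetical check; shapes are tuples such as (batch_size, num_dims)."""
    # target_support must be a rank-1 tensor of shape (num_dims,).
    if len(target_support_shape) != 1:
        raise ValueError("target_support must have shape (num_dims,)")
    # supports and weights must agree exactly.
    if supports_shape != weights_shape:
        raise ValueError("supports and weights must have the same shape")
    # supports must be (batch_size, num_dims) with num_dims matching target_support.
    if len(supports_shape) != 2 or supports_shape[1] != target_support_shape[0]:
        raise ValueError("supports must have shape (batch_size, num_dims)")


check_projection_shapes((2, 5), (2, 5), (5,))  # running example: no error
```

Passing () for target_support_shape, the "no dimensions" case, raises ValueError.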