Projects a batch of (support, weights) onto target_support.
tf_agents.agents.categorical_dqn.categorical_dqn_agent.project_distribution(
    supports: tf_agents.typing.types.Tensor,
    weights: tf_agents.typing.types.Tensor,
    target_support: tf_agents.typing.types.Tensor,
    validate_args: bool = False
) -> tf_agents.typing.types.Tensor
Based on equation (7) in Bellemare et al. (2017), https://arxiv.org/abs/1707.06887. Throughout the source comments, this equation is referred to simply as Eq7.
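For reference, Eq7 as we read it in the paper (reproduced here for convenience; check it against the original) projects the Bellman-updated atoms onto the fixed support, where $[\cdot]_a^b$ denotes clipping to the interval $[a, b]$:

$$
\left(\Phi \hat{\mathcal{T}} Z_\theta(x, a)\right)_i = \sum_{j=0}^{N-1} \left[ 1 - \frac{\left| [\hat{\mathcal{T}} z_j]_{V_{\min}}^{V_{\max}} - z_i \right|}{\Delta z} \right]_0^1 p_j(x', \pi(x'))
$$

Our interpretation of how the symbols map onto this function's arguments: $\hat{\mathcal{T}} z_j$ corresponds to `supports[b, j]`, $p_j$ to `weights[b, j]`, $z_i$ to `target_support[i]`, $V_{\min}$ and $V_{\max}$ to the first and last entries of `target_support`, and $\Delta z$ to the (assumed uniform) spacing between consecutive target atoms.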
This code is not easy to digest, so we will use a running example to clarify what is going on, with the following sample inputs:
- supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
- weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
- target_support = [4, 5, 6, 7, 8]
In the source code, comments prefixed with 'Ex:' refer to the values above.
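To make Eq7 concrete, here is a minimal loop-based sketch in plain Python that applies it to the running example. It is not the TF-Agents implementation (which works on batched TensorFlow tensors); the function name is reused only for illustration, and it assumes that `target_support` is sorted and evenly spaced with at least two atoms, so that `delta_z` can be read from its first gap.

```python
from typing import List, Sequence


def project_distribution(
    supports: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    target_support: Sequence[float],
) -> List[List[float]]:
    """Illustrative pure-Python sketch of the Eq7 projection.

    Assumption: target_support is sorted and evenly spaced (>= 2 atoms).
    """
    if len(target_support) < 2:
        raise ValueError("this sketch needs at least two target atoms")
    delta_z = target_support[1] - target_support[0]  # uniform spacing assumed
    v_min, v_max = target_support[0], target_support[-1]

    projection = []
    for row_support, row_weights in zip(supports, weights):
        # Clip every source atom into [V_min, V_max], as Eq7 does.
        clipped = [min(max(z_j, v_min), v_max) for z_j in row_support]
        row = []
        for z_i in target_support:
            # Target atom z_i receives p_j scaled by [1 - |z_j - z_i| / delta_z]_0^1,
            # so each source atom splits its mass between the nearest target atoms.
            mass = sum(
                min(max(1.0 - abs(z_j - z_i) / delta_z, 0.0), 1.0) * p_j
                for z_j, p_j in zip(clipped, row_weights)
            )
            row.append(mass)
        projection.append(row)
    return projection


# The running example listed above.
supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
target_support = [4, 5, 6, 7, 8]

result = project_distribution(supports, weights, target_support)
print([[round(p, 6) for p in row] for row in result])
# [[0.8, 0.0, 0.1, 0.0, 0.1], [0.8, 0.1, 0.1, 0.0, 0.0]]
```

In the first row, the atoms 0 and 2 lie below V_min = 4, so clipping moves their mass (0.1 + 0.6) onto the atom 4, which already holds 0.1; that is why the first target atom ends up with 0.8. Each output row still sums to 1 because every clipped source atom distributes exactly its own weight.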
Returns | |
---|---|
A Tensor of shape (batch_size, num_dims), where num_dims is the number of atoms in target_support, containing the projection of each (support, weights) pair in the batch onto target_support. |
Raises | |
---|---|
ValueError | If target_support has no dimensions, or if the shapes of supports, weights, and target_support are incompatible. |
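The shape conditions under which a ValueError is raised can be illustrated with a small hypothetical helper, `check_projection_shapes`, operating on nested Python lists. It is a sketch of our reading of the conditions above, not the checks TF-Agents actually runs: in particular, treating "incompatible" as "each row of supports must have one entry per atom of target_support" is an assumption, and the validate_args flag is not modeled.

```python
from typing import Sequence, Union


def check_projection_shapes(
    supports: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    target_support: Union[float, Sequence[float]],
) -> None:
    """Hypothetical helper mirroring the ValueError conditions listed above."""
    # A bare number plays the role of a tensor with no dimensions.
    if isinstance(target_support, (int, float)):
        raise ValueError("target_support has no dimensions")
    # supports and weights must share the same (batch_size, num_dims) shape.
    if len(supports) != len(weights) or any(
        len(s) != len(w) for s, w in zip(supports, weights)
    ):
        raise ValueError("supports and weights have incompatible shapes")
    # Assumption: each support row has one entry per target atom.
    if any(len(s) != len(target_support) for s in supports):
        raise ValueError("supports and target_support have incompatible shapes")


# The running example passes silently.
check_projection_shapes(
    [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]],
    [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]],
    [4, 5, 6, 7, 8],
)

try:
    check_projection_shapes([[0, 2, 4]], [[0.5, 0.5]], [4, 5, 6])
except ValueError as err:
    print(err)  # supports and weights have incompatible shapes
```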