tf_agents.agents.categorical_dqn.categorical_dqn_agent.project_distribution
Projects a batch of (support, weights) onto target_support.
tf_agents.agents.categorical_dqn.categorical_dqn_agent.project_distribution(
    supports: tf_agents.typing.types.Tensor,
    weights: tf_agents.typing.types.Tensor,
    target_support: tf_agents.typing.types.Tensor,
    validate_args: bool = False
) -> tf_agents.typing.types.Tensor
Based on equation (7) in Bellemare et al. (2017), https://arxiv.org/abs/1707.06887.
The comments in the source code refer to this equation simply as Eq7.
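For reference, Eq7 is reproduced below in the paper's notation (up to typesetting), where [.]_a^b clips its argument to the interval [a, b] and z_i = V_min + i * Δz are the atoms. Reading it against this function's arguments is our interpretation rather than something the docstring states: the clipped term [T z_j] corresponds to a clipped entry supports[b, j], p_j to weights[b, j], z_i to target_support[i], and N to num_dims.

```latex
\[
\left(\Phi \hat{\mathcal{T}} Z_\theta(x, a)\right)_i
  = \sum_{j=0}^{N-1}
    \left[ 1 - \frac{\left| \left[\hat{\mathcal{T}} z_j\right]_{V_{\min}}^{V_{\max}} - z_i \right|}{\Delta z} \right]_0^1
    p_j\bigl(x', \pi(x')\bigr),
\qquad
\Delta z = \frac{V_{\max} - V_{\min}}{N - 1}
\]
```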
The code is not easy to digest, so its comments use a running example to clarify
what is going on, with the following sample inputs:
- supports = [[0, 2, 4, 6, 8],
              [1, 3, 4, 5, 6]]
- weights = [[0.1, 0.6, 0.1, 0.1, 0.1],
             [0.1, 0.2, 0.5, 0.1, 0.1]]
- target_support = [4, 5, 6, 7, 8]
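The projection for the running example above can be reproduced with a short pure-Python sketch of Eq7. The name project_distribution_sketch is hypothetical: it loops over batch rows and target atoms instead of using TensorFlow ops like the library function, and it assumes target_support holds at least two equally spaced, increasing values.

```python
from typing import List, Sequence


def project_distribution_sketch(
    supports: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    target_support: Sequence[float],
) -> List[List[float]]:
    """Hypothetical pure-Python sketch of the Eq7 projection.

    Assumes target_support has at least two equally spaced, increasing values.
    """
    v_min, v_max = target_support[0], target_support[-1]
    delta_z = target_support[1] - target_support[0]
    projection = []
    for row_support, row_weights in zip(supports, weights):
        row = []
        for z_i in target_support:
            total = 0.0
            for s_j, w_j in zip(row_support, row_weights):
                # Clip the support point to [v_min, v_max].
                clipped = min(max(s_j, v_min), v_max)
                # Triangular coefficient: 1 when clipped == z_i, 0 once the
                # distance reaches delta_z, then clipped to [0, 1] as in Eq7.
                coeff = min(max(1.0 - abs(clipped - z_i) / delta_z, 0.0), 1.0)
                total += coeff * w_j
            row.append(total)
        projection.append(row)
    return projection


# The running example from the text.
supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
target_support = [4, 5, 6, 7, 8]

result = project_distribution_sketch(supports, weights, target_support)
print([[round(v, 6) for v in row] for row in result])
# → [[0.8, 0.0, 0.1, 0.0, 0.1], [0.8, 0.1, 0.1, 0.0, 0.0]]
```

In this example every point that falls inside [4, 8] sits exactly on an atom, so its whole weight goes to that atom. The points 0 and 2 (first row) and 1 and 3 (second row) are clipped to 4, which is why both rows put 0.8 on the first atom.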
In the source code, comments prefixed with 'Ex:' refer to the values above.
Args:
  supports: Tensor of shape (batch_size, num_dims) defining supports for the
    distribution.
  weights: Tensor of shape (batch_size, num_dims) defining weights on the
    original support points. Although for the CategoricalDQN agent these
    weights are probabilities, it is not required that they are.
  target_support: Tensor of shape (num_dims) defining support of the projected
    distribution. The values must be monotonically increasing. Vmin and Vmax
    will be inferred from the first and last elements of this tensor,
    respectively. The values in this tensor must be equally spaced.
  validate_args: Whether we will verify the contents of the target_support
    parameter.
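The requirements on target_support (increasing, equally spaced) can be illustrated with a small pure-Python check. check_target_support is a hypothetical helper, not the check that validate_args enables inside TF-Agents; it assumes strictly increasing values (a zero spacing would make Δz zero), and the tolerance tol is our own addition to cope with floating-point grids.

```python
from typing import Sequence


def check_target_support(target_support: Sequence[float], tol: float = 1e-6) -> None:
    """Hypothetical check for the documented target_support requirements.

    Raises ValueError if the values are not strictly increasing or are not
    equally spaced (within tol).
    """
    if len(target_support) < 2:
        raise ValueError("target_support needs at least two values")
    deltas = [b - a for a, b in zip(target_support, target_support[1:])]
    if any(d <= 0 for d in deltas):
        raise ValueError(f"target_support must be increasing; deltas={deltas}")
    if any(abs(d - deltas[0]) > tol for d in deltas):
        raise ValueError(f"target_support must be equally spaced; deltas={deltas}")


check_target_support([4, 5, 6, 7, 8])  # the running example passes silently
for bad in ([4, 6, 5, 7, 8], [4, 5, 7, 8, 9]):
    try:
        check_target_support(bad)
    except ValueError as err:
        print("rejected:", err)
```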
Returns:
  A Tensor of shape (batch_size, num_dims) with the projection of a batch of
  (support, weights) onto target_support.
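Because the clipped weight of each support point is shared only between its two nearest atoms, and those two shares sum to 1 on an equally spaced grid, each row of the returned projection has the same total as the corresponding row of weights, whether or not the weights are probabilities. The sketch below uses this per-point splitting form, which follows the projection step of Algorithm 1 in the same paper, instead of the pairwise form of Eq7. project_by_neighbor_split is a hypothetical name, and the sketch assumes at least two equally spaced, increasing target values.

```python
import math
from typing import List, Sequence


def project_by_neighbor_split(
    supports: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    target_support: Sequence[float],
) -> List[List[float]]:
    """Hypothetical helper: split each weight between its two nearest atoms."""
    v_min, v_max = target_support[0], target_support[-1]
    delta_z = target_support[1] - target_support[0]
    last = len(target_support) - 1
    projection = []
    for row_support, row_weights in zip(supports, weights):
        row = [0.0] * len(target_support)
        for s, w in zip(row_support, row_weights):
            # Fractional position of the clipped point on the target grid.
            b = (min(max(s, v_min), v_max) - v_min) / delta_z
            lower = min(math.floor(b), last)
            upper = min(lower + 1, last)
            frac = b - lower
            if upper == lower:
                row[lower] += w  # the point sits on the last atom
            else:
                row[lower] += w * (1.0 - frac)
                row[upper] += w * frac
        projection.append(row)
    return projection


supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
target_support = [4, 5, 6, 7, 8]

result = project_by_neighbor_split(supports, weights, target_support)
for row, w in zip(result, weights):
    print(round(sum(row), 6), round(sum(w), 6))  # row totals are preserved

# A point between atoms is split proportionally: 4.25 puts 75% on 4, 25% on 5.
print(project_by_neighbor_split([[4.25, 4.0, 4.0, 4.0, 4.0]],
                                [[2.0, 0.0, 0.0, 0.0, 0.0]], target_support))
# → [[1.5, 0.5, 0.0, 0.0, 0.0]]
```

On the running example this gives the same values as Eq7; the splitting form touches only two atoms per support point, whereas Eq7 evaluates every (atom, support point) pair.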
Raises:
  ValueError: If target_support has no dimensions, or if shapes of supports,
    weights, and target_support are incompatible.
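The documented ValueError conditions can be mirrored by a small shape pre-check on plain tuples, where None marks an unknown dimension (as in a TensorFlow static shape). check_projection_shapes is hypothetical and is derived only from the shapes documented above: it requires target_support to have shape (num_dims), which is slightly stricter than the "no dimensions" wording, and supports to have shape (batch_size, num_dims) matching both weights and target_support.

```python
from typing import Optional, Sequence

Shape = Sequence[Optional[int]]  # None marks an unknown dimension


def _compatible(a: Shape, b: Shape) -> bool:
    """Shapes are compatible if ranks match and all known dims agree."""
    return len(a) == len(b) and all(
        x is None or y is None or x == y for x, y in zip(a, b)
    )


def check_projection_shapes(supports: Shape, weights: Shape, target_support: Shape) -> None:
    """Hypothetical pre-check for the documented ValueError conditions."""
    if len(target_support) != 1:
        raise ValueError(f"target_support must have rank 1, got {tuple(target_support)}")
    if not _compatible(supports, weights):
        raise ValueError(f"supports {tuple(supports)} and weights {tuple(weights)} differ")
    if len(supports) != 2 or not _compatible(supports[1:], target_support):
        raise ValueError(
            f"supports {tuple(supports)} does not match target_support {tuple(target_support)}"
        )


check_projection_shapes((2, 5), (2, 5), (5,))      # running example: OK
check_projection_shapes((None, 5), (32, 5), (5,))  # unknown batch size: OK
```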
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2024-04-26 UTC.