tf_agents.bandits.agents.utils.sum_reward_weighted_observations
Calculates an update used by some Bandit algorithms.
tf_agents.bandits.agents.utils.sum_reward_weighted_observations(
    r: tf_agents.typing.types.Tensor,
    x: tf_agents.typing.types.Tensor
) -> tf_agents.typing.types.Tensor
Given an observation x and corresponding reward r, the weighted observations vector (denoted b here) should be updated as b = b + r * x.
This function calculates the sum of the reward-weighted observations over a batch of observations x, that is, the total update to b contributed by the batch.
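The batched update described above can be sketched in plain NumPy. This is an illustration of the arithmetic only, not the library's implementation; the helper name `sum_reward_weighted_observations_np` and the sample values are assumptions made for this example.

```python
import numpy as np


def sum_reward_weighted_observations_np(r, x):
    """NumPy sketch (hypothetical helper): returns sum_i r[i] * x[i]."""
    r = np.asarray(r, dtype=np.float64)  # shape [batch_size]
    x = np.asarray(x, dtype=np.float64)  # shape [batch_size, context_dim]
    # Scale each observation row by its reward, then sum over the batch axis.
    return np.sum(r[:, np.newaxis] * x, axis=0)


# Three observations with context_dim = 2 (illustrative values).
r = np.array([1.0, 0.5, 2.0])
x = np.array([[1.0, 0.0],
              [2.0, 4.0],
              [0.0, 1.0]])

update = sum_reward_weighted_observations_np(r, x)

# Apply the update b = b + sum_i r[i] * x[i], starting from a zero vector.
b = np.zeros(2)
b = b + update
```

The same quantity can be written as the matrix-vector product `x.T @ r`, which makes it clear that the result has one entry per context dimension.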
Args |
r | a Tensor of shape [batch_size]. These are the rewards of the batched observations. |
x | a Tensor of shape [batch_size, context_dim]. This is the matrix with the (batched) observations. |
Returns |
The update that needs to be added to b. Has the same shape as b. If the observation matrix x is empty, a zero vector is returned. |
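The return contract can be checked with a small NumPy sketch. It assumes that b has shape [context_dim], which follows from b = b + r * x for a single observation x, and it uses the matrix-vector form `x.T @ r` as a stand-in for the function rather than calling the library itself. The variable names and values are illustrative.

```python
import numpy as np

context_dim = 3  # illustrative value

# Non-empty batch: the update has shape [context_dim], matching b.
r = np.array([2.0, -1.0])
x = np.array([[1.0, 2.0, 3.0],
              [0.0, 1.0, 0.0]])
update = x.T @ r

# Empty batch: zero rows in x and zero rewards in r.
r_empty = np.zeros((0,))
x_empty = np.zeros((0, context_dim))
update_empty = x_empty.T @ r_empty  # a sum over no rows yields a zero vector
```

Because an empty batch produces a zero vector, adding the result to b leaves b unchanged, so callers do not need a special case for empty batches.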
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.