tf_agents.bandits.agents.utils.sum_reward_weighted_observations

Calculates an update used by some Bandit algorithms.

tf_agents.bandits.agents.utils.sum_reward_weighted_observations(
    r, x
)

Given an observation x and its corresponding reward r, the weighted observations vector (denoted b here) should be updated as b = b + r * x. For batched observations x, this function calculates the sum of the reward-weighted observations over the batch, that is, the sum over i of r[i] * x[i].

Args:

  • r: a Tensor of shape [batch_size]. These are the rewards of the batched observations.
  • x: a Tensor of shape [batch_size, context_dim]. This is the matrix with the (batched) observations.

Returns:

The update that needs to be added to b. Has the same shape as b. If the observation matrix x is empty, a zero vector is returned.
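
The arithmetic described above can be sketched in plain Python. This is only an illustration of the formula and not the library's implementation. The helper name `reward_weighted_sum` and its `context_dim` parameter are hypothetical, introduced here because plain lists, unlike Tensors, carry no shape when they are empty.

```python
def reward_weighted_sum(r, x, context_dim):
    """Illustrative reference: returns sum over i of r[i] * x[i].

    r: list of batch_size rewards.
    x: list of batch_size observations, each a list of length context_dim.
    context_dim: hypothetical extra argument, needed for the empty-batch case.
    """
    if len(r) != len(x):
        raise ValueError("r and x must have the same batch size")
    # Start from a zero vector, so an empty batch yields all zeros.
    update = [0.0] * context_dim
    for reward, observation in zip(r, x):
        for j in range(context_dim):
            update[j] += reward * observation[j]
    return update


rewards = [1.0, 2.0]
observations = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0]]

# 1.0 * [1.0, 0.0, 2.0] + 2.0 * [0.5, 1.0, 0.0] = [2.0, 2.0, 2.0]
update = reward_weighted_sum(rewards, observations, context_dim=3)

# Apply the update to b, as in b = b + r * x.
b = [0.0, 0.0, 0.0]
b = [b_j + u_j for b_j, u_j in zip(b, update)]

# An empty batch contributes a zero vector.
empty_update = reward_weighted_sum([], [], context_dim=3)
```

With Tensors, the same result should be obtainable as a single matrix-vector product of the transposed observation matrix with the reward vector, which avoids the explicit loops.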