tf_agents.bandits.agents.utils.sum_reward_weighted_observations

Calculates an update used by some Bandit algorithms.

Given an observation x and its corresponding reward r, the weighted observations vector (denoted b here) should be updated as b = b + r * x. For batched observations x, this function calculates the sum of the reward-weighted observations over the batch, that is, the sum over i of r[i] * x[i].

r a Tensor of shape [batch_size]. These are the rewards of the batched observations.
x a Tensor of shape [batch_size, context_dim]. This is the matrix of (batched) observations, one observation per row.

The update to be added to b, with the same shape as b. If the observation matrix x is empty, a zero vector is returned.
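
The computation described above can be illustrated with a minimal NumPy sketch. This is not the library's implementation: the function name `sum_reward_weighted_observations_np` is hypothetical, and NumPy arrays stand in for Tensors. It only mirrors the documented math, including the zero vector returned for an empty batch.

```python
import numpy as np


def sum_reward_weighted_observations_np(r, x):
    """Hypothetical NumPy reference for the documented update.

    r: array of shape [batch_size] (rewards).
    x: array of shape [batch_size, context_dim] (observations).
    Returns an array of shape [context_dim]: the sum over i of r[i] * x[i].
    """
    r = np.asarray(r, dtype=np.float64)
    x = np.asarray(x, dtype=np.float64)
    # Scale each observation row by its reward, then sum over the batch axis.
    # For an empty batch, the sum over axis 0 yields a zero vector.
    return (r[:, None] * x).sum(axis=0)


# Two observations with context_dim = 3.
r = np.array([1.0, 2.0])
x = np.array([[1.0, 0.0, 2.0],
              [0.5, 1.0, 0.0]])

b = np.zeros(3)
update = sum_reward_weighted_observations_np(r, x)
b = b + update  # 1 * [1, 0, 2] + 2 * [0.5, 1, 0] = [2, 2, 2]

# Empty batch: the update is a zero vector of length context_dim.
empty_update = sum_reward_weighted_observations_np(np.zeros(0), np.zeros((0, 3)))
```

Summing over the batch in a single call gives the same result as applying b = b + r * x once per observation, which is why the result can be added to b directly.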