tf_agents.bandits.policies.policy_utilities.masked_argmax


Computes the argmax where the allowed elements are given by a mask.

Args:
input_tensor: Rank-2 Tensor of floats.
mask: 0-1 valued Tensor with the same shape as input_tensor.
output_type: Integer type of the output.

Returns:
A rank-1 Tensor of type output_type containing the masked argmax of each row of input_tensor.
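To make the semantics concrete, here is a minimal NumPy sketch of a row-wise masked argmax. It is not the TF-Agents implementation: masked_argmax_sketch is a hypothetical name, and raising ValueError for a row whose mask is all zeros is an assumption, because the description above does not say what the library returns in that case. The idea is to replace disallowed entries with negative infinity so they can never be selected, then take the ordinary argmax along each row.

```python
import numpy as np


def masked_argmax_sketch(input_tensor, mask, output_type=np.int32):
    """Hypothetical NumPy sketch of a row-wise masked argmax.

    input_tensor: rank-2 array of (finite) floats.
    mask: 0/1 array with the same shape as input_tensor.
    Returns a rank-1 array holding one column index per row.
    """
    values = np.asarray(input_tensor, dtype=np.float64)
    allowed = np.asarray(mask).astype(bool)
    if values.ndim != 2 or values.shape != allowed.shape:
        raise ValueError("input_tensor must be rank 2 and match mask's shape")
    # Assumption: every row must allow at least one element; the behavior
    # for an all-zero mask row is not specified in the docs above.
    if not allowed.any(axis=1).all():
        raise ValueError("every row of mask needs at least one 1")
    # Disallowed entries become -inf, so (for finite scores) they never win.
    masked_values = np.where(allowed, values, -np.inf)
    # Ordinary argmax along each row, cast to the requested integer type.
    return np.argmax(masked_values, axis=1).astype(output_type)


scores = [[0.1, 0.9, 0.5],
          [0.7, 0.2, 0.3]]
mask = [[1, 0, 1],
        [0, 1, 1]]
print(masked_argmax_sketch(scores, mask))  # [2 2]
```

In the first row the largest score (0.9) is masked out, so the best allowed index is 2; in the second row index 0 is masked out, so index 2 (0.3) beats index 1 (0.2). Using -inf rather than subtracting a large constant guarantees a masked entry never wins, provided the allowed scores are finite.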