tf_agents.policies.utils.masked_argmax

Computes the argmax where the allowed elements are given by a mask.

If a row of mask is all zeros, so that no element in that row is allowed, the method returns -1 for that row instead of an index into input_tensor.
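To make the row-wise semantics concrete, here is a minimal pure-Python reference sketch. It is not the TF-Agents implementation. The function name masked_argmax_reference is hypothetical, the inputs are plain nested lists rather than Tensors, and ties are broken by taking the lowest index. That tie-breaking rule is an assumption of this sketch and not a documented guarantee of the library.

```python
from typing import List, Sequence


def masked_argmax_reference(
    input_rows: Sequence[Sequence[float]],
    mask_rows: Sequence[Sequence[int]],
) -> List[int]:
    """Hypothetical reference: row-wise argmax restricted to positions where mask is 1.

    Returns -1 for any row whose mask is all zeros.
    Ties are broken by keeping the first (lowest) allowed index.
    """
    result = []
    for values, mask in zip(input_rows, mask_rows):
        best_index = -1  # stays -1 if no position in this row is allowed
        best_value = float("-inf")
        for i, (value, allowed) in enumerate(zip(values, mask)):
            # Skip disallowed positions. The strict ">" keeps the first index on ties,
            # and the best_index == -1 check accepts the first allowed value even if it is -inf.
            if allowed and (best_index == -1 or value > best_value):
                best_index = i
                best_value = value
        result.append(best_index)
    return result


logits = [[1.0, 5.0, 3.0], [2.0, 0.5, 4.0], [7.0, 8.0, 9.0]]
mask = [[1, 0, 1], [1, 1, 0], [0, 0, 0]]
print(masked_argmax_reference(logits, mask))  # [2, 0, -1]
```

In the first row the largest value (5.0) is masked out, so index 2 wins. The last row has no allowed elements and yields -1.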

Args:
  input_tensor: Rank-2 Tensor of floats.
  mask: 0-1 valued Tensor with the same shape as input_tensor.
  output_type: Integer dtype of the output.

Returns:
  A rank-1 Tensor of type output_type containing the masked argmax of each row of input_tensor.
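A common vectorized way to compute a masked argmax is to replace disallowed entries with negative infinity, take an ordinary row-wise argmax, and then overwrite rows whose mask is all zeros with -1. The sketch below shows that technique in NumPy. It is an illustration under stated assumptions, not the library's code. The function name masked_argmax_np is hypothetical, and output_dtype only mimics the role of output_type. The sketch also assumes finite input values. If every allowed entry in a row is -inf, np.argmax may return a disallowed index.

```python
import numpy as np


def masked_argmax_np(
    input_array: np.ndarray, mask: np.ndarray, output_dtype=np.int32
) -> np.ndarray:
    """Hypothetical NumPy sketch of a masked row-wise argmax (assumes finite inputs)."""
    if input_array.ndim != 2 or input_array.shape != mask.shape:
        raise ValueError("input_array must be rank 2 and have the same shape as mask")
    allowed = mask.astype(bool)
    # Disallowed positions become -inf so they cannot win the argmax.
    masked = np.where(allowed, input_array, -np.inf)
    argmax = np.argmax(masked, axis=1)
    # A row with no allowed position would otherwise report index 0; map it to -1.
    has_any_allowed = allowed.any(axis=1)
    return np.where(has_any_allowed, argmax, -1).astype(output_dtype)


logits = np.array([[1.0, 5.0, 3.0], [2.0, 0.5, 4.0], [7.0, 8.0, 9.0]])
mask = np.array([[1, 0, 1], [1, 1, 0], [0, 0, 0]])
print(masked_argmax_np(logits, mask))  # [ 2  0 -1]
```

The final np.where step is needed because an all -inf row still has an argmax, namely index 0. Without that step, the "no allowed element" case would silently look like a valid choice.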