tf_agents.bandits.policies.policy_utilities.masked_argmax

Computes the argmax where the allowed elements are given by a mask.

tf_agents.bandits.policies.policy_utilities.masked_argmax(
    input_tensor, mask, output_type=tf.int32
)

Args:

  • input_tensor: Rank-2 Tensor of floats.
  • mask: 0-1 valued Tensor of the same shape as input_tensor.
  • output_type: Integer type of the output.

Returns:

A Tensor of rank 1 and type output_type, with the masked argmax of every row of input_tensor.
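
The row-wise behavior described above can be illustrated with a small NumPy sketch. This is not the library's implementation: `masked_argmax_sketch` is a hypothetical helper, and replacing disallowed entries with negative infinity before taking the argmax is just one straightforward way to get the described result. The sample `rewards` and `mask` values are made up for illustration.

```python
import numpy as np


def masked_argmax_sketch(values, mask):
    """Row-wise argmax restricted to entries where mask is nonzero.

    Hypothetical NumPy helper for illustration only; it is not part of
    tf_agents and does not reproduce its source code.
    """
    values = np.asarray(values, dtype=np.float64)
    allowed = np.asarray(mask).astype(bool)
    if values.shape != allowed.shape:
        raise ValueError("values and mask must have the same shape")
    # Disallowed entries become -inf so they can never win the argmax.
    masked = np.where(allowed, values, -np.inf)
    return np.argmax(masked, axis=-1).astype(np.int32)


# Two rows of three candidate actions each (made-up numbers).
rewards = [[1.0, 5.0, 3.0],
           [4.0, 2.0, 9.0]]
# A 1 marks an allowed action, a 0 marks a disallowed one.
mask = [[1, 0, 1],
        [1, 1, 0]]

result = masked_argmax_sketch(rewards, mask)
print(result)  # [2 0]
```

In the first row the largest value (5.0) is masked out, so the result is index 2; in the second row 9.0 is masked out, so the result is index 0. The sketch does nothing special for a row whose mask is all zeros (NumPy returns index 0 there); the description above does not say how masked_argmax treats that case, so check the source before relying on it.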