Module: tf_agents.policies.utils

Utilities for policies.


class BanditPolicyType: Enumeration of bandit policy types.

class InfoFields: Strings which can be used in the policy info fields.

class PerArmPolicyInfo: PerArmPolicyInfo(log_probability, predicted_rewards_mean, predicted_rewards_optimistic, predicted_rewards_sampled, bandit_policy_type, chosen_arm_features)

class PolicyInfo: PolicyInfo(log_probability, predicted_rewards_mean, predicted_rewards_optimistic, predicted_rewards_sampled, bandit_policy_type)
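The two signatures above list plain field names, and PerArmPolicyInfo differs from PolicyInfo only by the extra chosen_arm_features field. As a rough illustration, the sketch below builds stand-in named tuples with the same field names using only the standard library. The `PolicyInfoSketch` and `PerArmPolicyInfoSketch` names and the empty-tuple defaults are assumptions made for this example, not the library's own definitions.

```python
import collections

# Field names copied from the PolicyInfo signature listed above.
POLICY_INFO_FIELDS = (
    "log_probability",
    "predicted_rewards_mean",
    "predicted_rewards_optimistic",
    "predicted_rewards_sampled",
    "bandit_policy_type",
)

# Hypothetical stand-ins. Defaulting every field to () is an assumption
# of this sketch so that callers can fill in only the fields they need.
PolicyInfoSketch = collections.namedtuple(
    "PolicyInfoSketch",
    POLICY_INFO_FIELDS,
    defaults=((),) * len(POLICY_INFO_FIELDS),
)
PerArmPolicyInfoSketch = collections.namedtuple(
    "PerArmPolicyInfoSketch",
    POLICY_INFO_FIELDS + ("chosen_arm_features",),
    defaults=((),) * (len(POLICY_INFO_FIELDS) + 1),
)

# Populate only the fields that are available; the rest stay empty.
info = PolicyInfoSketch(predicted_rewards_mean=[0.1, 0.7, 0.2])
per_arm_info = PerArmPolicyInfoSketch(chosen_arm_features=[1.0, 0.0])
```

In this sketch an unset field is simply an empty tuple, so a caller can tell "not provided" apart from a real value without a separate flag.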


bandit_policy_uniform_mask(...): Sets the bandit policy type tensor to BanditPolicyType.UNIFORM based on a mask.
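The masking idea can be pictured with a small plain-Python sketch that works on lists instead of tensors. It assumes that a true mask entry means "overwrite with UNIFORM" and that other entries are left unchanged. The `BanditPolicyTypeSketch` enum, its members and values, and the `uniform_mask_sketch` helper are hypothetical names for illustration, not the library's API.

```python
import enum


class BanditPolicyTypeSketch(enum.IntEnum):
    """Hypothetical stand-in enum; members and values are assumptions."""
    GREEDY = 1
    UNIFORM = 2


def uniform_mask_sketch(policy_types, mask):
    """Returns a copy of policy_types with UNIFORM wherever mask is truthy."""
    if len(policy_types) != len(mask):
        raise ValueError("policy_types and mask must have the same length")
    return [
        BanditPolicyTypeSketch.UNIFORM if masked else current
        for current, masked in zip(policy_types, mask)
    ]


types = [BanditPolicyTypeSketch.GREEDY] * 3
# Only the second entry is marked as having been chosen uniformly.
updated = uniform_mask_sketch(types, [False, True, False])
```

The sketch returns a new list and leaves its input untouched, which mirrors how tensor operations produce new tensors and do not mutate their inputs.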


create_bandit_policy_type_tensor_spec(...): Creates a tensor spec for the bandit policy type.

create_chosen_arm_features_info_spec(...): Creates the chosen arm features info spec from the arm observation spec.

get_model_index(...): Returns the model index for a specific arm.

get_num_actions_from_tensor_spec(...): Validates action_spec and returns number of actions.

has_bandit_policy_type(...): Checks whether the policy info has a bandit_policy_type field/tensor.

has_chosen_arm_features(...): Checks whether the policy info has a chosen_arm_features field/tensor.

masked_argmax(...): Computes the argmax where the allowed elements are given by a mask.
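A masked argmax picks the position of the largest value while ignoring positions that the mask disallows. The plain-Python sketch below illustrates the idea on a single list; it is not the library's implementation. The `masked_argmax_sketch` name is hypothetical, and raising an error when nothing is allowed is a choice made for this example, since the summary above does not say how that case is handled.

```python
def masked_argmax_sketch(values, mask):
    """Returns the index of the largest value among positions where mask is truthy."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    # Keep only the indices that the mask allows.
    allowed = [index for index, ok in enumerate(mask) if ok]
    if not allowed:
        raise ValueError("mask must allow at least one element")
    # Among the allowed indices, pick the one with the largest value.
    return max(allowed, key=lambda index: values[index])


# Index 0 holds the largest value but is masked out, so index 2 wins.
best = masked_argmax_sketch([0.9, 0.4, 0.7], [0, 1, 1])
```

Filtering indices first, instead of overwriting masked values with a very small number, avoids accidentally selecting a masked position when every allowed value is itself very small.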

populate_policy_info(...): Populates policy info given all needed input.

set_bandit_policy_type(...): Sets the InfoFields.BANDIT_POLICY_TYPE on info to bandit_policy_type.