Module: tf_agents.bandits.environments.bernoulli_action_mask_tf_environment


Environment wrapper that adds action masks to a bandit environment.

This environment wrapper takes a BanditTFEnvironment as input and produces a new environment whose observations are joined with boolean action masks. These masks describe which actions are allowed in a given time step. If a disallowed action is chosen in a time step, the environment raises an error. The masks are drawn independently from Bernoulli-distributed random variables with parameter action_probability.
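The masking rule can be illustrated with a small, framework-free Python sketch. ToyBernoulliMaskEnv is a hypothetical class, not part of TF-Agents. It assumes a fixed number of actions, uses Python's random module instead of TensorFlow, returns a dummy reward, and draws a fresh mask after every step. That last choice is an assumption about when masks are resampled, based on the per-time-step wording above.

```python
import random


class ToyBernoulliMaskEnv:
  """Hypothetical sketch of per-step Bernoulli action masking.

  Not the TF-Agents implementation. It only mirrors the described behaviour:
  mask entries are independent Bernoulli(action_probability) draws, and
  choosing a disallowed action raises an error.
  """

  def __init__(self, num_actions, action_probability, seed=None):
    if not 0.0 <= action_probability <= 1.0:
      raise ValueError("action_probability must be in [0, 1]")
    self._num_actions = num_actions
    self._p = action_probability
    self._rng = random.Random(seed)
    self.mask = self._sample_mask()

  def _sample_mask(self):
    # random() is uniform on [0, 1), so each entry is True with probability p.
    return [self._rng.random() < self._p for _ in range(self._num_actions)]

  def step(self, action):
    if not self.mask[action]:
      raise ValueError(f"action {action} is not allowed in this time step")
    # A real bandit environment would compute a reward here; 0.0 is a stand-in.
    reward = 0.0
    # Assumption: a new mask is drawn for the next time step.
    self.mask = self._sample_mask()
    return reward
```

With a small action_probability, this sketch can draw an all-False mask, which leaves no legal action. Any real wrapper needs some way to handle that case.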

The observation from the original environment and the mask are combined by the user-supplied join_fn, and the value it returns becomes the observation of the new environment.


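```python
from tf_agents.bandits.environments import bernoulli_action_mask_tf_environment

env = MyFavoriteBanditEnvironment(...)  # placeholder: any BanditTFEnvironment

def join_fn(context, mask):
  # The new observation is the (context, mask) pair.
  return (context, mask)

masked_env = bernoulli_action_mask_tf_environment.BernoulliActionMaskTFEnvironment(
    env, join_fn, 0.5)  # 0.5 is action_probability
```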
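The joined value replaces the original observation, so anything that consumes the masked environment has to undo the join to reach the context and the mask. Many TF-Agents bandit agents and policies accept an observation_and_action_constraint_splitter argument for this purpose. The pure-Python sketch below pairs the tuple-style join_fn from the example with a matching splitter. It then picks the greedy action among the allowed ones. split_fn and best_allowed_action are hypothetical helpers, not TF-Agents APIs.

```python
def join_fn(context, mask):
  # Same join as in the example: the new observation is a (context, mask) pair.
  return (context, mask)


def split_fn(observation):
  # Hypothetical inverse of join_fn: recovers the original context and mask.
  context, mask = observation
  return context, mask


def best_allowed_action(estimated_rewards, mask):
  """Greedy choice restricted to actions whose mask entry is True."""
  allowed = [a for a, ok in enumerate(mask) if ok]
  if not allowed:
    raise ValueError("no action is allowed in this time step")
  return max(allowed, key=lambda a: estimated_rewards[a])


observation = join_fn([0.3, 1.2], [False, True, True])
context, mask = split_fn(observation)
# Action 0 has the highest estimate but is masked out, so action 2 is chosen.
print(best_allowed_action([0.9, 0.1, 0.5], mask))  # prints 2
```

Any join_fn works as long as the consumer applies the matching split. A tuple is the simplest choice because the splitter is a plain unpacking.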


class BernoulliActionMaskTFEnvironment: An environment wrapper that adds action masks to observations.