Environment wrapper that adds action masks to a bandit environment.
This environment wrapper takes a BanditTFEnvironment as input and generates
a new environment whose observations are joined with boolean action masks.
These masks describe which actions are allowed at a given time step. If a
disallowed action is chosen at a time step, the environment raises an error.
Each mask entry is drawn independently from a Bernoulli distribution whose
parameter, the probability that an action is allowed, is passed to the
constructor (0.5 in the usage example).
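The masking rule can be sketched in plain Python, using lists of booleans instead of TensorFlow tensors and one mask per batch entry. This is not the TF-Agents implementation. The helper names sample_action_masks and check_actions are hypothetical, p stands in for the Bernoulli parameter, and raising ValueError is this sketch's own choice, since the wrapper's actual error type is not given here. Independent draws can also produce a row with no allowed action. As an assumption of its own, the sketch re-enables one random action in that case, which departs slightly from purely independent draws.

```python
import random
from typing import List, Sequence

Mask = List[bool]


def sample_action_masks(batch_size: int, num_actions: int, p: float,
                        rng: random.Random) -> List[Mask]:
  """Draws one mask per batch entry; each action is allowed w.p. p."""
  masks = []
  for _ in range(batch_size):
    # rng.random() is uniform on [0, 1), so each entry is True with prob. p.
    row = [rng.random() < p for _ in range(num_actions)]
    if not any(row):
      # Assumption of this sketch: never emit an all-False mask; re-enable
      # one uniformly chosen action instead.
      row[rng.randrange(num_actions)] = True
    masks.append(row)
  return masks


def check_actions(actions: Sequence[int], masks: Sequence[Mask]) -> None:
  """Raises ValueError if any chosen action is disallowed by its mask."""
  for i, (action, row) in enumerate(zip(actions, masks)):
    if not (0 <= action < len(row) and row[action]):
      raise ValueError(f"batch entry {i}: action {action} not allowed by {row}")


rng = random.Random(0)
masks = sample_action_masks(batch_size=3, num_actions=4, p=0.5, rng=rng)
# Choosing the first allowed action in each row always passes the check.
check_actions([row.index(True) for row in masks], masks)
```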
The observation from the original environment and the mask are combined by
the user-supplied join_fn, and the value it returns becomes the observation of
the new environment.
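Usage:

```python
from tf_agents.bandits.environments.bernoulli_action_mask_tf_environment import (
    BernoulliActionMaskTFEnvironment)

env = MyFavoriteBanditEnvironment(...)  # placeholder for any BanditTFEnvironment

def join_fn(context, mask):
  # The new observation is the pair (original observation, action mask).
  return (context, mask)

masked_env = BernoulliActionMaskTFEnvironment(env, join_fn, 0.5)
```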
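The tuple returned by join_fn in the usage example is only one possible layout. Whatever layout is chosen, whatever consumes the wrapped observation has to undo the join and then restrict itself to allowed actions. The plain-Python sketch below illustrates that round trip under stated assumptions: join_fn here is a hypothetical dict-based variant, split_fn is its hypothetical inverse, masked_argmax is a hypothetical greedy choice over allowed actions, and the per-action scores are made up. None of these names are TF-Agents APIs.

```python
from typing import Any, Dict, List, Sequence, Tuple


def join_fn(context: Any, mask: List[bool]) -> Dict[str, Any]:
  """Hypothetical dict-based join: keeps context and mask under named keys."""
  return {"context": context, "mask": mask}


def split_fn(observation: Dict[str, Any]) -> Tuple[Any, List[bool]]:
  """Hypothetical inverse of join_fn, recovering (context, mask)."""
  return observation["context"], observation["mask"]


def masked_argmax(scores: Sequence[float], mask: Sequence[bool]) -> int:
  """Returns the highest-scoring action among those the mask allows."""
  allowed = [a for a, ok in enumerate(mask) if ok]
  if not allowed:
    raise ValueError("mask allows no actions")
  return max(allowed, key=lambda a: scores[a])


observation = join_fn(context=[0.2, 1.3], mask=[False, True, True, False])
context, mask = split_fn(observation)
# Made-up per-action scores; a real policy would compute them from context.
scores = [9.0, 1.0, 4.0, 7.0]
action = masked_argmax(scores, mask)  # 2: the best score among actions 1 and 2
```

Defining the join and its inverse side by side keeps the two from drifting apart. Note that the unmasked best action (0) is never chosen, which is exactly what the wrapper's error check enforces.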
class BernoulliActionMaskTFEnvironment: An environment wrapper that adds action masks to observations.