In some environments, different sets of actions are available for different
observations. To represent this, env.observation contains both the raw
observation and an action mask for that particular observation, so the network
needs to know how to split env.observation into these two parts. The raw
observation is fed into the wrapped network, and the action mask can optionally
be passed into the wrapped network as well, so that the network only outputs
actions that are valid for the current observation.
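As a rough illustration of how a mask keeps a network from choosing impossible actions, the plain-Python sketch below applies one common technique: it sets the scores (logits) of disallowed actions to negative infinity before taking the highest-scoring action. The helper names mask_logits and greedy_action, and the encoding of the mask as 1 (allowed) and 0 (disallowed), are assumptions made for this example. They are not TF-Agents APIs.

```python
import math
from typing import List, Sequence


def mask_logits(logits: Sequence[float], mask: Sequence[int]) -> List[float]:
    """Replace the logit of every disallowed action with -inf.

    Assumed encoding: mask[i] == 1 means action i is allowed for the current
    observation; mask[i] == 0 means it is not.
    """
    if len(logits) != len(mask):
        raise ValueError("logits and mask must have the same length")
    return [x if m else -math.inf for x, m in zip(logits, mask)]


def greedy_action(logits: Sequence[float], mask: Sequence[int]) -> int:
    """Return the index of the highest-scoring action among the allowed ones."""
    masked = mask_logits(logits, mask)
    if all(x == -math.inf for x in masked):
        raise ValueError("mask allows no actions")
    return max(range(len(masked)), key=lambda i: masked[i])


# Action 1 has the highest raw score but is masked out, so action 2 is chosen.
print(greedy_action([0.2, 3.0, 1.5], [1, 0, 1]))  # prints 2
```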
The network uses splitter_fn to separate the observation from the action
mask (i.e. observation, mask = splitter_fn(inputs)). Depending on the value
of pass_mask_to_wrapped_network, the mask is either passed into the wrapped
network or dropped:

    obs, mask = splitter_fn(inputs)
    if pass_mask_to_wrapped_network:
        wrapped_network(obs, ..., mask=mask)
    else:
        wrapped_network(obs, ...)

In both cases, the observation part is fed into wrapped_network, so the input
spec of the wrapped network is expected to be compatible with the observation
part of the MaskSplitterNetwork's input.
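The dispatch described above can be sketched in plain Python to show the control flow. ToyMaskSplitter, ToyWrappedNetwork and the dict layout with "observation" and "mask" keys are hypothetical stand-ins chosen for illustration; the real class is a TF-Agents network.Network operating on tensors, and this sketch does not reproduce its implementation.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def splitter_fn(inputs: Dict[str, Any]) -> Tuple[Any, Any]:
    """Hypothetical splitter: assumes the observation is stored as a dict."""
    return inputs["observation"], inputs["mask"]


class ToyWrappedNetwork:
    """Stand-in for the wrapped network; it just echoes what it received."""

    def __call__(self, obs: Any, mask: Optional[Any] = None) -> Dict[str, Any]:
        return {"obs": obs, "mask": mask}


class ToyMaskSplitter:
    """Mimics the documented behavior: split the input, then pass or drop the mask."""

    def __init__(self,
                 splitter_fn: Callable[[Any], Tuple[Any, Any]],
                 wrapped_network: Callable[..., Any],
                 pass_mask_to_wrapped_network: bool):
        self._splitter_fn = splitter_fn
        self._wrapped_network = wrapped_network
        self._pass_mask = pass_mask_to_wrapped_network

    def __call__(self, inputs: Any, **kwargs: Any) -> Any:
        obs, mask = self._splitter_fn(inputs)
        if self._pass_mask:
            return self._wrapped_network(obs, mask=mask, **kwargs)
        # The mask is dropped: only the observation reaches the wrapped network.
        return self._wrapped_network(obs, **kwargs)


inputs = {"observation": [0.1, 0.7], "mask": [1, 0, 1]}
dropping = ToyMaskSplitter(splitter_fn, ToyWrappedNetwork(),
                           pass_mask_to_wrapped_network=False)
passing = ToyMaskSplitter(splitter_fn, ToyWrappedNetwork(),
                          pass_mask_to_wrapped_network=True)
print(dropping(inputs))  # prints {'obs': [0.1, 0.7], 'mask': None}
print(passing(inputs))   # prints {'obs': [0.1, 0.7], 'mask': [1, 0, 1]}
```

Keeping the split in a user-supplied function means the wrapper itself does not need to know the observation layout; only splitter_fn knows where the mask lives.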
splitter_fn: A function used to process observations with action constraints
(i.e. the mask) by splitting them into the observation and the mask.
Note: The input spec of the wrapped network must be compatible with the
network-specific half of the output of splitter_fn on the input spec.
wrapped_network: A network.Network used to process the network-specific part
of the observation and, when pass_mask_to_wrapped_network is True, the mask,
which is passed as the mask parameter of the wrapped network's call method.
pass_mask_to_wrapped_network: If True, the mask is fed into the wrapped
network. If False, the mask portion of the input is dropped and not fed into
the wrapped network.
input_tensor_spec: A tensor_spec.TensorSpec or a tuple of specs representing
the input observations, including the specs of the action constraints.
A string giving the name of the network.
An error is raised if input_tensor_spec is not an instance of
network.InputSpec.
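The note on splitter_fn can be checked ahead of time because splitter_fn can be applied to the input spec in the same way it is applied to the actual inputs. The sketch below shows that check under two simplifying assumptions: specs are plain shape tuples, and "compatible" means exact equality. check_splitter_compatibility is a hypothetical helper, not part of TF-Agents.

```python
from typing import Any, Callable, Dict, Tuple

# Simplified spec representation for this sketch: a spec is just a shape tuple.
Spec = Tuple[int, ...]


def check_splitter_compatibility(
        splitter_fn: Callable[[Any], Tuple[Any, Any]],
        input_spec: Any,
        wrapped_input_spec: Any) -> Any:
    """Apply splitter_fn to the full input spec and verify that its
    network-specific half matches the wrapped network's expected input spec.

    Returns the mask half of the spec if the check passes.
    """
    obs_spec, mask_spec = splitter_fn(input_spec)
    if obs_spec != wrapped_input_spec:
        raise ValueError(
            f"wrapped network expects {wrapped_input_spec}, "
            f"but splitter_fn yields {obs_spec}")
    return mask_spec


def splitter_fn(x: Dict[str, Any]) -> Tuple[Any, Any]:
    """Hypothetical splitter for a dict with "observation" and "mask" keys."""
    return x["observation"], x["mask"]


full_spec: Dict[str, Spec] = {"observation": (4,), "mask": (3,)}
print(check_splitter_compatibility(splitter_fn, full_spec, (4,)))  # prints (3,)
```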
Returns the spec of the input to the network, of type InputSpec.
Gets the list of all (nested) sub-layers used in this Network.
input_tensor_spec (optional): Overrides or provides an input tensor spec when
creating variables.
Other arguments to network.call(), e.g. training=True.
Returns the output specs: a nested spec calculated from the outputs (excluding
any batch dimensions). If any of the output elements is a tfp Distribution,
the associated spec entry is a DistributionSpec.
An error is raised if no input_tensor_spec is provided and the network did not
provide one during construction.