tf_agents.networks.mask_splitter_network.MaskSplitterNetwork

Separates and passes the observation and mask to the wrapped network.

Inherits From: DistributionNetwork, Network

In some environments, different sets of actions are available given different observations. To represent this, env.observation actually contains both the raw observation and an action mask for that particular observation. The network needs to know how to split env.observation into these two parts. The raw observation is fed into the wrapped network, and the action mask is optionally passed into the wrapped network so that it only outputs valid actions.

The network uses splitter_fn to separate the observation from the action mask (i.e. observation, mask = splitter_fn(inputs)). Depending on the value of passthrough_mask, the mask is either passed into the wrapped network or dropped:

obs, mask = splitter_fn(inputs)

wrapped_network(obs, ...)  # If passthrough_mask is `False`

wrapped_network(obs, ..., mask=mask)  # If passthrough_mask is `True`

In both cases the observation part is fed into wrapped_network. The input spec of the wrapped network is expected to be compatible with the observation part of the input of the MaskSplitterNetwork.
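The split-and-dispatch behaviour described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the TF-Agents implementation: the dict layout of the inputs, toy_wrapped_network, and mask_splitter_call are all hypothetical stand-ins, and the way the toy network applies the mask (setting disallowed entries to negative infinity) is just one possible choice.

```python
def splitter_fn(inputs):
    # Hypothetical layout: the environment packs both parts into one dict.
    return inputs["observation"], inputs["mask"]


def toy_wrapped_network(obs, mask=None):
    # Stand-in for a real network: one "logit" per action.
    logits = [float(x) for x in obs]
    if mask is not None:
        # Disallowed actions get -inf so they can never be selected.
        logits = [l if m else float("-inf") for l, m in zip(logits, mask)]
    return logits


def mask_splitter_call(inputs, passthrough_mask=False):
    obs, mask = splitter_fn(inputs)
    if passthrough_mask:
        return toy_wrapped_network(obs, mask=mask)
    # The mask is dropped; only the observation reaches the wrapped network.
    return toy_wrapped_network(obs)


inputs = {"observation": [0.5, 2.0, 1.0], "mask": [1, 0, 1]}
dropped = mask_splitter_call(inputs, passthrough_mask=False)
passed = mask_splitter_call(inputs, passthrough_mask=True)
```

In this sketch the wrapped network never sees the combined input; it only receives the observation part, plus the mask when passthrough_mask is True.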

Args
splitter_fn A function used to process observations with action constraints (i.e. mask). Note: The input spec of the wrapped network must be compatible with the network-specific half of the output of the splitter_fn on the input spec.
wrapped_network A network.Network used to process the network-specific part of the observation. When the mask is passed through, it is supplied as the mask parameter of the wrapped network's call.
passthrough_mask If set to True, the mask is fed into the wrapped network. If set to False, the mask portion of the input is dropped and not fed into the wrapped network.
input_tensor_spec A tensor_spec.TensorSpec or a tuple of specs representing the input observations, including the specs of the action constraints.
name A string representing the name of the network.

Raises
ValueError If input_tensor_spec is not an instance of network.InputSpec.

Attributes
input_tensor_spec Returns the spec of the input to the network of type InputSpec.
layers Get the list of all (nested) sub-layers used in this Network.
output_spec

state_spec

Methods

copy

Create a shallow copy of this network.

Args
**kwargs Args to override when recreating this network. Commonly overridden args include 'name'.

Returns
A shallow copy of this network.
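
One way to picture a shallow copy with overridable arguments is the "recreate from saved constructor arguments" pattern. The sketch below is a minimal stand-in, not the TF-Agents implementation: TinyNetwork and its fields are hypothetical, and the only assumption is that copy() rebuilds the object from its stored constructor arguments with any **kwargs overrides applied.

```python
class TinyNetwork:
    """Hypothetical stand-in for a network that remembers its constructor args."""

    def __init__(self, layers, name="TinyNetwork"):
        self.layers = layers
        self.name = name
        # Saved so that copy() can rebuild the object later.
        self._saved_kwargs = {"layers": layers, "name": name}

    def copy(self, **kwargs):
        # Overrides win over the saved arguments; everything else is reused as-is.
        return type(self)(**dict(self._saved_kwargs, **kwargs))


original = TinyNetwork(layers=["dense_1", "dense_2"])
clone = original.copy(name="TinyNetworkCopy")
```

Because the copy is shallow, clone.layers in this sketch is the same list object as original.layers; only the overridden name differs.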

create_variables

Force creation of the network's variables.

Return output specs.

Args
input_tensor_spec (Optional) Override or provide an input tensor spec when creating variables.
**kwargs Other arguments to network.call(), e.g. training=True.

Returns
Output specs - a nested spec calculated from the outputs (excluding any batch dimensions). If any of the output elements is a tfp Distribution, the associated spec entry returned is a DistributionSpec.

Raises
ValueError If no input_tensor_spec is provided, and the network did not provide one during construction.

get_initial_state

Returns an initial state usable by the network.

Args
batch_size Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added.

Returns
A nested object of type self.state_spec containing properly initialized Tensors.
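
The effect of batch_size on the returned structure can be illustrated with shapes alone. The sketch below is an illustration under assumptions, not the library code: state_spec is modelled as a hypothetical nest of shape tuples, and initial_state_shapes is a hypothetical helper that prepends the batch dimension only when batch_size is not None.

```python
def initial_state_shapes(state_spec, batch_size=None):
    """Return the shapes an initial state would have for a nest of shape tuples."""
    if isinstance(state_spec, dict):
        return {k: initial_state_shapes(v, batch_size) for k, v in state_spec.items()}
    if isinstance(state_spec, list):
        return [initial_state_shapes(s, batch_size) for s in state_spec]
    # Leaf: a shape tuple. Prepend the batch dimension unless batch_size is None.
    return state_spec if batch_size is None else (batch_size,) + state_spec


state_spec = {"lstm": [(16,), (16,)]}  # Hypothetical recurrent state layout.
unbatched = initial_state_shapes(state_spec)
batched = initial_state_shapes(state_spec, batch_size=4)
```

The nesting of the result mirrors the nesting of the spec; only the leaf shapes change when a batch size is given.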

get_layer

Retrieves a layer based on either its name (unique) or index.

If name and index are both provided, index will take precedence. Indices are based on order of horizontal graph traversal (bottom-up).

Args
name String, name of layer.
index Integer, index of layer.

Returns
A layer instance.

Raises
ValueError In case of invalid layer name or index.

summary

Prints a string summary of the network.

Args
line_length Total length of printed lines (e.g. set this to adapt the display to different terminal window sizes).
positions Relative or absolute positions of log elements in each line. If not provided, defaults to [.33, .55, .67, 1.].
print_fn Print function to use. Defaults to print. It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary.

Raises
ValueError If summary() is called before the model is built.
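
Capturing the summary as a string through print_fn only requires a callable that collects each line. The sketch below uses a hypothetical fake_summary stand-in so that it runs without a built model; the assumption, taken from the print_fn description, is that the function is called once per line of the summary.

```python
def fake_summary(print_fn=print):
    """Hypothetical stand-in that emits a summary line by line via print_fn."""
    for line in ["Layer (type)    Output Shape", "dense (Dense)   (None, 3)"]:
        print_fn(line)


captured = []
fake_summary(print_fn=captured.append)  # Collect lines instead of printing them.
summary_text = "\n".join(captured)
```

With a built network, passing captured.append as print_fn to summary() would presumably collect the real summary lines in the same way.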