Creates a recurrent actor network.
Inherits From: Network
tf_agents.agents.ddpg.actor_rnn_network.ActorRnnNetwork(
    input_tensor_spec,
    output_tensor_spec,
    conv_layer_params=None,
    input_fc_layer_params=(200, 100),
    lstm_size=(40,),
    output_fc_layer_params=(200, 100),
    activation_fn=tf.keras.activations.relu,
    name='ActorRnnNetwork'
)
Args
input_tensor_spec
A nest of tensor_spec.TensorSpec representing the input observations.
output_tensor_spec
A nest of tensor_spec.BoundedTensorSpec representing the actions.
conv_layer_params
Optional list of convolution layer parameters, where
each item is a length-three tuple indicating (filters, kernel_size,
stride).
input_fc_layer_params
Optional list of fully_connected parameters, where
each item is the number of units in the layer. This is applied before
the LSTM cell.
lstm_size
An iterable of ints specifying the LSTM cell sizes to use.
output_fc_layer_params
Optional list of fully_connected parameters, where
each item is the number of units in the layer. This is applied after the
LSTM cell.
activation_fn
Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
name
A string representing the name of the network.
Raises
ValueError
If input_tensor_spec contains more than one observation.
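The argument descriptions above imply an ordering of stages: optional convolution layers, then the input fully connected layers, then the LSTM cell(s), then the output fully connected layers. The sketch below is plain Python, not TF-Agents code. `describe_stack` is a hypothetical helper that only lists the stages in that order, using the default arguments from the signature. Any final layer that maps to the action spec is left out because this page does not describe it.

```python
def describe_stack(conv_layer_params=None,
                   input_fc_layer_params=(200, 100),
                   lstm_size=(40,),
                   output_fc_layer_params=(200, 100)):
    """Hypothetical helper: list the stages in the order the docs describe."""
    stages = []
    # Each conv item is a length-three tuple: (filters, kernel_size, stride).
    for filters, kernel_size, stride in (conv_layer_params or ()):
        stages.append(
            f"conv(filters={filters}, kernel_size={kernel_size}, stride={stride})")
    # Fully connected layers applied before the LSTM cell.
    for units in (input_fc_layer_params or ()):
        stages.append(f"fc(units={units})")
    # One entry per LSTM cell size.
    for units in lstm_size:
        stages.append(f"lstm(units={units})")
    # Fully connected layers applied after the LSTM cell.
    for units in (output_fc_layer_params or ()):
        stages.append(f"fc(units={units})")
    return stages


default_stack = describe_stack()
conv_stack = describe_stack(conv_layer_params=[(32, 3, 2)])
print(default_stack)
```

With the defaults there are no convolution stages, so the list starts at the first fully connected layer.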
Attributes
input_tensor_spec
Returns the spec of the input to the network of type InputSpec.
layers
Get the list of all (nested) sub-layers used in this Network.
state_spec
Methods
copy
copy(
**kwargs
)
Create a shallow copy of this network.
Note: Network layer weights are never copied. This method recreates
the Network instance with the same arguments it was initialized with
(excepting any new kwargs).
Args
**kwargs
Args to override when recreating this network. Commonly
overridden args include 'name'.
Returns
A shallow copy of this network.
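The copy semantics described above (same constructor arguments, optional overrides, no weights carried over) can be illustrated with a small stand-in. This is a sketch in plain Python, not the TF-Agents implementation. `TinyNetwork`, its `units` argument, and its `weights` list are hypothetical names used only for illustration.

```python
class TinyNetwork:
    """Hypothetical stand-in for a network that remembers its constructor args."""

    def __init__(self, units=40, name="TinyNetwork"):
        self._saved_kwargs = dict(units=units, name=name)
        self.units = units
        self.name = name
        # Stand-in for layer weights; a fresh list is created per instance.
        self.weights = [0.0] * units

    def copy(self, **kwargs):
        # Re-run the constructor with the saved args, overridden by new kwargs.
        return type(self)(**dict(self._saved_kwargs, **kwargs))


original = TinyNetwork(units=3, name="actor")
original.weights[0] = 1.5  # pretend training changed a weight
clone = original.copy(name="target_actor")
print(clone.name, clone.units, clone.weights)
```

The clone keeps `units=3` from the original's constructor arguments and takes the overridden name, but its weights are freshly created rather than copied.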
create_variables
create_variables(
input_tensor_spec=None, **kwargs
)
Force creation of the network's variables.
Returns the output specs.
Args
input_tensor_spec
(Optional). Override or provide an input tensor spec
when creating variables.
**kwargs
Other arguments to network.call(), e.g. training=True.
Returns
Output specs - a nested spec calculated from the outputs (excluding any
batch dimensions). If any of the output elements is a tfp Distribution,
the associated spec entry returned is a DistributionSpec.
Raises
ValueError
If no input_tensor_spec is provided, and the network did
not provide one during construction.
get_initial_state
get_initial_state(
batch_size=None
)
Returns an initial state usable by the network.
Args
batch_size
Tensor or constant: size of the batch dimension. Can be None,
in which case no batch dimension is added.
Returns
A nested object of type self.state_spec containing properly
initialized Tensors.
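The effect of batch_size on the returned shapes can be sketched without TensorFlow. The code below is an illustration only: `initial_state` is a hypothetical helper that builds zero-filled nested lists, and the state layout (two entries of shape (40,) for a single LSTM cell of size 40) is an assumption, since this page does not document state_spec.

```python
def initial_state(state_shapes, batch_size=None):
    """Hypothetical helper: build zero-filled nested lists for each state shape."""
    def zeros(shape):
        if not shape:
            return 0.0
        return [zeros(shape[1:]) for _ in range(shape[0])]

    # With batch_size=None no leading dimension is added.
    outer = () if batch_size is None else (batch_size,)
    return [zeros(outer + tuple(shape)) for shape in state_shapes]


# Assumed state layout for one LSTM cell of size 40: two entries of shape (40,).
state_shapes = [(40,), (40,)]
unbatched = initial_state(state_shapes)
batched = initial_state(state_shapes, batch_size=2)
print(len(unbatched[0]), len(batched[0]), len(batched[0][0]))
```

Without a batch size each entry has shape (40,). With batch_size=2 each entry gains a leading dimension and has shape (2, 40).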
get_layer
get_layer(
name=None, index=None
)
Retrieves a layer based on either its name (unique) or index.
If name and index are both provided, index will take precedence.
Indices are based on order of horizontal graph traversal (bottom-up).
Args
name
String, name of layer.
index
Integer, index of layer.
Returns
A layer instance.
Raises
ValueError
In case of invalid layer name or index.
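The precedence rule above (index wins when both arguments are given) can be shown with a small lookup function. This is a sketch under stated assumptions, not the library's code: the standalone `get_layer` function and the list of (name, layer) pairs are hypothetical stand-ins for the network's internal layer list.

```python
def get_layer(layers, name=None, index=None):
    """Hypothetical lookup over a list of (name, layer) pairs."""
    if index is not None:
        # index takes precedence, even when name is also given.
        if not 0 <= index < len(layers):
            raise ValueError(f"invalid layer index: {index}")
        return layers[index][1]
    if name is not None:
        for layer_name, layer in layers:
            if layer_name == name:
                return layer
        raise ValueError(f"no such layer: {name}")
    raise ValueError("provide either a layer name or an index")


layers = [("fc_in", "A"), ("lstm", "B"), ("fc_out", "C")]
by_name = get_layer(layers, name="lstm")
by_both = get_layer(layers, name="lstm", index=2)
print(by_name, by_both)
```

When both arguments are passed, the name is ignored and the layer at the given index is returned.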
summary
summary(
line_length=None, positions=None, print_fn=None
)
Prints a string summary of the network.
Args
line_length
Total length of printed lines (e.g. set this to adapt the
display to different terminal window sizes).
positions
Relative or absolute positions of log elements in each line.
If not provided, defaults to [.33, .55, .67, 1.].
print_fn
Print function to use. Defaults to print. It will be called
on each line of the summary. You can set it to a custom function in
order to capture the string summary.
Raises
ValueError
If summary() is called before the model is built.
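Capturing the summary as a string, as the print_fn description suggests, comes down to passing a callable that collects lines. The sketch below uses a hypothetical stand-in `summary` function with made-up output lines so that it runs without TensorFlow. Only the print_fn pattern is the point here.

```python
def summary(print_fn=None):
    """Hypothetical stand-in that emits summary lines through print_fn."""
    if print_fn is None:
        print_fn = print
    # Made-up lines for illustration; real output has a different layout.
    for line in ("Layer (type)    Output Shape", "dense (Dense)   (None, 200)"):
        print_fn(line)


captured = []
summary(print_fn=captured.append)  # collect lines instead of printing them
summary_text = "\n".join(captured)
print(summary_text)
```

With a real, already built network the same pattern would presumably be `network.summary(print_fn=captured.append)`, assuming print_fn is invoked once per line as described above.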