Squash the outer dimensions of input tensors; unsquash outputs.
tf_agents.keras_layers.SquashedOuterWrapper(
    wrapped: tf.keras.layers.Layer,
    inner_rank: int,
    **kwargs
)
This layer wraps a Keras layer, wrapped, that cannot handle more than one
batch dimension. It squashes the inputs' outer dimensions into a single larger
batch, then unsquashes the outputs of wrapped.
The outer dimensions are the leftmost rank(inputs) - inner_rank
dimensions.
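The shape arithmetic behind this rule can be sketched in plain Python. This is only an illustration of the rule stated above, not the layer's actual implementation: the helper names squash_outer_shape and unsquash_outer_shape are hypothetical, and the sketch assumes static shapes with at least one outer dimension.

```python
import math


def squash_outer_shape(shape, inner_rank):
    """Return (outer_dims, squashed_shape) for a static input shape.

    Hypothetical helper: the outer dimensions are the leftmost
    len(shape) - inner_rank entries, merged into one batch dimension.
    """
    outer_rank = len(shape) - inner_rank
    if outer_rank < 1:
        # Assumption of this sketch: at least one outer (batch) dimension.
        raise ValueError("input rank must exceed inner_rank")
    outer_dims = tuple(shape[:outer_rank])
    inner_dims = tuple(shape[outer_rank:])
    return outer_dims, (math.prod(outer_dims),) + inner_dims


def unsquash_outer_shape(outer_dims, squashed_shape):
    """Restore the original outer dimensions in front of the inner ones."""
    return tuple(outer_dims) + tuple(squashed_shape[1:])


outer, squashed = squash_outer_shape((2, 5, 8, 8, 3), inner_rank=3)
print(outer, squashed)                        # (2, 5) (10, 8, 8, 3)
print(unsquash_outer_shape(outer, squashed))  # (2, 5, 8, 8, 3)
```

With a single outer dimension the squashed shape equals the input shape, which matches the first case in the examples that follow.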
Examples:
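import tensorflow as tf
from tf_agents.keras_layers import SquashedOuterWrapper

# Example dimension sizes (arbitrary values, chosen for illustration).
B, B1, B2, T, H, W, C = 2, 2, 3, 5, 8, 8, 3

batch_norm = tf.keras.layers.BatchNormalization(axis=-1)
layer = SquashedOuterWrapper(wrapped=batch_norm, inner_rank=3)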
inputs_0 = tf.random.normal((B, H, W, C))
# batch_norm sees a tensor of shape [B, H, W, C]
# outputs_0 shape is [B, H, W, C]
outputs_0 = layer(inputs_0)
inputs_1 = tf.random.normal((B, T, H, W, C))
# batch_norm sees a tensor of shape [B * T, H, W, C]
# outputs_1 shape is [B, T, H, W, C]
outputs_1 = layer(inputs_1)
inputs_2 = tf.random.normal((B1, B2, T, H, W, C))
# batch_norm sees a tensor of shape [B1 * B2 * T, H, W, C]
# outputs_2 shape is [B1, B2, T, H, W, C]
outputs_2 = layer(inputs_2)
Args:
  wrapped: The Keras layer to wrap.
  inner_rank: The inner rank of inputs that will be passed to the layer.
    This value allows us to infer the outer batch dimension regardless of
    the input shape to build or call.
  **kwargs: Additional arguments for Keras layer construction.
Raises:
  ValueError: If wrapped has a get_initial_state method, because
    we do not know how to handle the case of multiple inputs, and
    the presence of this method typically means an RNN or RNN-like
    layer that accepts separate state tensors.
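A guard of the kind described under Raises could look roughly like the sketch below. It is an assumption about how such a check might be written, not the library's code: check_wrappable, PlainLayer, and RecurrentLikeLayer are hypothetical stand-ins used so the example runs without TensorFlow.

```python
class PlainLayer:
    """Hypothetical stand-in for a stateless Keras layer."""

    def __call__(self, inputs):
        return inputs


class RecurrentLikeLayer(PlainLayer):
    """Hypothetical stand-in for an RNN-like layer with separate state."""

    def get_initial_state(self, inputs=None):
        return None


def check_wrappable(layer):
    """Reject layers that expose get_initial_state (hypothetical helper)."""
    if hasattr(layer, "get_initial_state"):
        raise ValueError(
            "Layers with a get_initial_state method are not supported: "
            f"{type(layer).__name__}"
        )
    return layer


check_wrappable(PlainLayer())  # accepted

try:
    check_wrappable(RecurrentLikeLayer())
except ValueError as err:
    print("rejected:", err)
```

Checking for the method by name, rather than for a specific RNN base class, also catches RNN-like layers that do not inherit from a common recurrent type.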
Attributes:
  inner_rank
  wrapped