

Squash the outer dimensions of input tensors; unsquash outputs.

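This layer wraps a Keras layer (the wrapped argument) that cannot handle more than one batch dimension. It squashes the inputs' outer dimensions into a single, larger batch dimension, calls wrapped on the result, and then unsquashes the outputs of wrapped back to the original outer dimensions.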

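The outer dimensions are the leftmost rank(inputs) - inner_rank dimensions. For example, with inner_rank=3, an input of shape [B, T, H, W, C] has outer dimensions [B, T] and inner dimensions [H, W, C].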
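The squash/unsquash step can be sketched framework-agnostically with NumPy reshapes. This is a minimal sketch, assuming the wrapped computation is any function that takes an array with a single leading batch dimension; squashed_outer_apply is a hypothetical helper, not part of TF-Agents, and the real layer operates on TensorFlow tensors rather than NumPy arrays.

```python
import numpy as np


def squashed_outer_apply(fn, x, inner_rank):
    """Hypothetical helper: apply fn after merging all outer dims into one.

    fn is assumed to accept an array of rank inner_rank + 1 (one batch dim
    followed by the inner dims) and to keep that leading batch size.
    """
    x = np.asarray(x)
    outer_rank = x.ndim - inner_rank
    if outer_rank < 1:
        raise ValueError("inputs need at least one outer dimension")
    outer_shape = x.shape[:outer_rank]
    inner_shape = x.shape[outer_rank:]
    # Squash: [d0, ..., dk, *inner] -> [d0 * ... * dk, *inner]
    squashed = x.reshape((int(np.prod(outer_shape)),) + inner_shape)
    y = np.asarray(fn(squashed))
    # Unsquash: put the original outer dims back in front of y's inner dims.
    return y.reshape(outer_shape + y.shape[1:])


# Usage: fn subtracts the mean over the squashed batch axis, so every outer
# position (b, t) contributes to the same batch statistic, much as a batch
# normalization layer would see it.
x = np.random.default_rng(0).normal(size=(2, 3, 8, 8, 4))
y = squashed_outer_apply(lambda t: t - t.mean(axis=0), x, inner_rank=3)
print(y.shape)  # (2, 3, 8, 8, 4)
```

Because the reshape uses row-major order, the squashed batch axis enumerates the outer positions in the same order in which the unsquash restores them, so each output lands back at its original outer index.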


```python
import tensorflow as tf
from tf_agents.keras_layers import SquashedOuterWrapper

# Example sizes; any positive integers work.
B, T, H, W, C = 2, 3, 8, 8, 4
B1, B2 = 2, 5

batch_norm = tf.keras.layers.BatchNormalization(axis=-1)
layer = SquashedOuterWrapper(wrapped=batch_norm, inner_rank=3)

inputs_0 = tf.random.normal((B, H, W, C))
# batch_norm sees a tensor of shape [B, H, W, C]
# outputs_0 shape is [B, H, W, C]
outputs_0 = layer(inputs_0)

inputs_1 = tf.random.normal((B, T, H, W, C))
# batch_norm sees a tensor of shape [B * T, H, W, C]
# outputs_1 shape is [B, T, H, W, C]
outputs_1 = layer(inputs_1)

inputs_2 = tf.random.normal((B1, B2, T, H, W, C))
# batch_norm sees a tensor of shape [B1 * B2 * T, H, W, C]
# outputs_2 shape is [B1, B2, T, H, W, C]
outputs_2 = layer(inputs_2)
```

Args:
  wrapped: The Keras layer to wrap.
  inner_rank: The number of trailing (inner) dimensions of the inputs. wrapped receives these dimensions unchanged, preceded by a single squashed batch dimension. Knowing this value lets the wrapper infer the outer batch dimensions regardless of the input shape passed to build or call.
  **kwargs: Additional keyword arguments for Keras layer construction.
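The way inner_rank fixes the split between outer and inner dimensions for any input rank can be illustrated with plain shape arithmetic. This is a minimal sketch under stated assumptions: squashed_build_shape is a hypothetical helper, unknown dimensions are represented as None (as in Keras build-time shapes), and the squashed batch size is taken to be unknown whenever any outer dimension is unknown. It is not the TF-Agents implementation.

```python
from typing import Optional, Sequence, Tuple


def squashed_build_shape(
    input_shape: Sequence[Optional[int]], inner_rank: int
) -> Tuple[Optional[int], ...]:
    """Hypothetical helper: the shape the wrapped layer would be built with.

    The leftmost len(input_shape) - inner_rank dims are outer dims and are
    merged into one batch dim; the inner dims pass through unchanged.
    """
    outer_rank = len(input_shape) - inner_rank
    if outer_rank < 1:
        raise ValueError("input_shape needs at least one outer dimension")
    outer = list(input_shape[:outer_rank])
    inner = tuple(input_shape[outer_rank:])
    if any(d is None for d in outer):
        batch = None  # product of outer dims is unknown until call time
    else:
        batch = 1
        for d in outer:
            batch *= d
    return (batch,) + inner


print(squashed_build_shape((None, 3, 8, 8, 4), inner_rank=3))  # (None, 8, 8, 4)
print(squashed_build_shape((2, 5, 3, 8, 8, 4), inner_rank=3))  # (30, 8, 8, 4)
```

Because only the trailing inner_rank dimensions are fixed, the same wrapped layer can be built once and then called with inputs of any outer rank.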

Raises:
  ValueError: If wrapped has a get_initial_state method. The presence of this method typically indicates an RNN or RNN-like layer that accepts separate state tensors, and the wrapper does not know how to handle such multiple inputs.
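The documented rejection rule can be sketched as a simple attribute check. This is an illustration only, assuming "has a get_initial_state method" means a callable attribute of that name; check_wrappable and FakeRNNCell are hypothetical names, and the actual TF-Agents check may be written differently.

```python
def check_wrappable(layer) -> None:
    """Hypothetical helper: reject layers that look RNN-like."""
    if callable(getattr(layer, "get_initial_state", None)):
        raise ValueError(
            "wrapped has a get_initial_state method; RNN-like layers that "
            "take separate state tensors are not supported"
        )


class FakeRNNCell:
    """Hypothetical stand-in for an RNN-like layer."""

    def get_initial_state(self, inputs=None):
        return None


check_wrappable(object())  # accepted: no get_initial_state attribute
try:
    check_wrappable(FakeRNNCell())
except ValueError as err:
    print("rejected:", err)
```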