
tf.ensure_shape

Updates the shape of a tensor and checks at runtime that the shape holds.

When executed, this operation asserts that the input tensor x's shape is compatible with the shape argument. See tf.TensorShape.is_compatible_with for details.

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
x = tf.ensure_shape(x, [2, 3])

Use None for unknown dimensions:

x = tf.ensure_shape(x, [None, 3])
x = tf.ensure_shape(x, [2, None])

If the tensor's shape is not compatible with the shape argument, an error is raised:

x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [2,3] is not
  compatible with expected shape [5]. [Op:EnsureShape]
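The compatibility rule behind these checks can be modeled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's implementation: it assumes shapes are lists of ints, with None for an unknown dimension and None for the whole shape standing for an unknown rank, and the helper name `shapes_compatible` is hypothetical.

```python
from typing import List, Optional

# Hypothetical representation: None = unknown rank; None inside the list = unknown dim.
Shape = Optional[List[Optional[int]]]


def shapes_compatible(a: Shape, b: Shape) -> bool:
    """Simplified model of tf.TensorShape.is_compatible_with (hypothetical helper)."""
    # An unknown rank is compatible with every shape.
    if a is None or b is None:
        return True
    # Known ranks must be equal.
    if len(a) != len(b):
        return False
    # Each pair of dims must match unless at least one of them is unknown.
    return all(x is None or y is None or x == y for x, y in zip(a, b))


print(shapes_compatible([2, 3], [2, 3]))     # True
print(shapes_compatible([2, 3], [None, 3]))  # True
print(shapes_compatible([2, 3], [2, None]))  # True
print(shapes_compatible([2, 3], [5]))        # False: rank 2 vs rank 1
print(shapes_compatible([2, 3], [None]))     # False: None matches a size, not a rank
```

Note the last case: None relaxes the size of one dimension, but the number of dimensions still has to agree.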

During graph construction (typically while tracing a tf.function), tf.ensure_shape updates the static shape of the result tensor by merging the two shapes. See tf.TensorShape.merge_with for details.
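The merge step can be sketched the same way. Again, this is a simplified plain-Python model with a hypothetical helper name, not TensorFlow's code: merging keeps every dimension that is known on either side and fails when known sizes or ranks disagree. The first call mirrors the traced example in this section, where a (None, None) input meets the shape argument [None, 3].

```python
from typing import List, Optional

# Hypothetical representation: None = unknown rank; None inside the list = unknown dim.
Shape = Optional[List[Optional[int]]]


def merge_shapes(a: Shape, b: Shape) -> Shape:
    """Simplified model of tf.TensorShape.merge_with (hypothetical helper)."""
    # An unknown rank carries no information, so the other shape wins.
    if a is None:
        return b
    if b is None:
        return a
    if len(a) != len(b):
        raise ValueError(f"Shapes {a} and {b} are not compatible")
    merged = []
    for x, y in zip(a, b):
        if x is None:
            merged.append(y)  # keep whatever the other side knows
        elif y is None or x == y:
            merged.append(x)
        else:
            raise ValueError(f"Shapes {a} and {b} are not compatible")
    return merged


print(merge_shapes([None, None], [None, 3]))  # [None, 3]
print(merge_shapes([2, None], [None, 3]))     # [2, 3]
print(merge_shapes(None, [None, 3]))          # [None, 3]
```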

This is most useful when you know a shape that can't be determined statically by TensorFlow.
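One common case is the output of tf.py_function, whose static shape TensorFlow cannot infer because the wrapped Python code is opaque to tracing. The sketch below assumes a TensorFlow 2.x install; the names `_double` and `double_rows` are illustrative, not part of the API.

```python
import tensorflow as tf


def _double(x):
    # Arbitrary Python code: TensorFlow cannot see inside it while tracing.
    return x * 2


@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
def double_rows(x):
    y = tf.py_function(_double, [x], Tout=tf.float32)
    print("Static shape from py_function:", y.shape)  # rank is unknown here
    # Restore the shape we know holds; it is also verified at run time.
    y = tf.ensure_shape(y, [None, 3])
    print("Static shape after ensure_shape:", y.shape)
    return y


result = double_rows(tf.ones([2, 3]))
print(result.shape)  # (2, 3)
```

Because the check also runs at execution time, a `_double` that returned a differently shaped tensor would raise tf.errors.InvalidArgumentError instead of passing a mis-shaped tensor downstream.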

The following trivial tf.function prints the input tensor's static shape before and after ensure_shape is applied.

@tf.function
def f(tensor):
  print("Static-shape before:", tensor.shape)
  tensor = tf.ensure_shape(tensor, [None, 3])
  print("Static-shape after:", tensor.shape)
  return tensor

This lets you see the effect of tf.ensure_shape when the function is traced:

cf = f.get_concrete_function(tf.TensorSpec([None, None]))
Static-shape before: (None, None)
Static-shape after: (None, 3)
cf(tf.zeros([3, 3])) # Passes
cf(tf.constant([1., 2., 3.])) # Fails
Traceback (most recent call last):
...
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [?,3].

The last call raises tf.errors.InvalidArgumentError because the input's shape, (3,), is not compatible with the shape argument, (None, 3).

Inside a tf.function or tf.compat.v1.Graph context, it checks both the build-time and run-time shapes. This is stricter than tf.Tensor.set_shape, which only checks the build-time shape.
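A minimal sketch of that difference, assuming TensorFlow 2.x; the function names `checked` and `unchecked` are illustrative. Both functions end up with the same static shape after tracing, but only the tf.ensure_shape version adds an op that re-checks the shape on every call.

```python
import tensorflow as tf

spec = tf.TensorSpec([None], tf.float32)


@tf.function(input_signature=[spec])
def checked(t):
    # Build time: (None,) merged with [3] gives static shape (3,).
    # Run time: the EnsureShape op verifies the actual shape on every call.
    return tf.ensure_shape(t, [3])


@tf.function(input_signature=[spec])
def unchecked(t):
    # set_shape only refines the static shape; no run-time check is added.
    t.set_shape([3])
    return t


print(checked.get_concrete_function().structured_outputs.shape)    # (3,)
print(unchecked.get_concrete_function().structured_outputs.shape)  # (3,)

checked(tf.zeros([3]))  # passes
try:
    checked(tf.zeros([5]))
except tf.errors.InvalidArgumentError:
    print("ensure_shape rejected a length-5 input at run time")
```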