tf.ensure_shape


Updates the shape of a tensor and checks at runtime that the shape holds.

For example:

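import tensorflow as tf

# Placeholders and sessions need graph mode. TensorFlow 2 runs eagerly by
# default, so turn eager execution off first (harmless on TensorFlow 1.x).
tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.int32)
print(x.shape)
# ==> TensorShape(None)
y = x * 2
print(y.shape)
# ==> TensorShape(None)

y = tf.ensure_shape(y, (None, 3, 3))
print(y.shape)
# ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])

with tf.compat.v1.Session() as sess:
  # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
  # compatible with the shape (None, 3, 3).
  sess.run(y, feed_dict={x: [1, 2, 3]})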
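The example above relies on graph-mode placeholders and sessions. Under TensorFlow 2 behavior (eager execution on), a roughly equivalent sketch uses a tf.function whose input_signature leaves the shape unknown. The function name double_and_check and the static_shapes dictionary are illustrative only; the dictionary is filled by Python code that runs only while the function is being traced, so it records the static shapes before and after tf.ensure_shape.

```python
import tensorflow as tf

# Filled during tracing only (a Python side effect, not part of the graph).
static_shapes = {}


@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.int32)])
def double_and_check(x):
    y = x * 2
    # The input signature gives x an unknown rank, so y's rank is unknown too.
    static_shapes["before"] = y.shape.rank
    # ensure_shape narrows the static shape and adds a runtime check.
    y = tf.ensure_shape(y, (None, 3, 3))
    static_shapes["after"] = y.shape.as_list()
    return y


# A (2, 3, 3) input satisfies (None, 3, 3) and passes through unchanged.
result = double_and_check(tf.ones([2, 3, 3], dtype=tf.int32))
print(result.shape)

# A (3,) input has the wrong rank, so the runtime check fails.
mismatch_raised = False
try:
    double_and_check(tf.constant([1, 2, 3]))
except tf.errors.InvalidArgumentError:
    mismatch_raised = True
print("mismatch raised:", mismatch_raised)
```

Because the signature is fixed, the function is traced once and both calls reuse the same graph; only the runtime check distinguishes the two inputs.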

Args:
  x: A Tensor.
  shape: A TensorShape representing the shape of this tensor, a TensorShapeProto, a list, a tuple, or None.
  name: A name for this operation (optional). Defaults to "EnsureShape".
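A short sketch of the accepted shape forms, assuming TensorFlow 2 with eager execution (the default). Each specification below is compatible with the actual shape (2, 3), so every call succeeds and returns a tensor with the input's dtype and values. The name "CheckShape" is an arbitrary example value for the optional name argument.

```python
import tensorflow as tf

x = tf.ones([2, 3])

# Each spec is compatible with x's actual shape (2, 3).
accepted_specs = [
    [2, 3],                     # a list
    (None, 3),                  # a tuple with an unknown leading dimension
    tf.TensorShape([2, None]),  # a TensorShape
    None,                       # unknown rank: any shape passes
]

checks = []
for spec in accepted_specs:
    y = tf.ensure_shape(x, spec, name="CheckShape")
    # The result has the same dtype and the same values as the input.
    same = y.dtype == x.dtype and bool(tf.reduce_all(tf.equal(y, x)))
    checks.append(same)

print(checks)
```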

Returns:
  A Tensor with the same type and contents as x. At runtime, raises a tf.errors.InvalidArgumentError if shape is incompatible with the shape of x.
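Here "incompatible" follows the usual TensorShape compatibility rule: a shape of None (unknown rank) accepts any tensor; otherwise the ranks must match, and each dimension given as None accepts any size while each known dimension must match exactly. The plain-Python sketch below restates that rule. shapes_compatible is a hypothetical helper written for illustration, not a TensorFlow function, and it assumes the semantics of TensorShape.is_compatible_with.

```python
def shapes_compatible(expected, actual):
    """Hypothetical helper: does the concrete shape `actual` satisfy `expected`?

    `expected` is None (unknown rank) or a sequence of ints and Nones,
    where None marks an unknown dimension. `actual` is a sequence of ints.
    """
    if expected is None:
        # Unknown rank: every shape is accepted.
        return True
    if len(expected) != len(actual):
        # The ranks must agree exactly.
        return False
    # A known dimension must equal the actual size; None matches any size.
    return all(e is None or e == a for e, a in zip(expected, actual))


print(shapes_compatible((None, 3, 3), (5, 3, 3)))  # True
print(shapes_compatible((None, 3, 3), (3,)))       # False: rank 1 vs rank 3
```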