Returns a tensor with a length-1 axis inserted at index axis.

Given a tensor input, this operation inserts a dimension of length 1 at the dimension index axis of input's shape. The dimension index follows Python indexing rules: it is zero-based, and a negative index is counted backward from the end.
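The indexing rule can be sketched in plain Python. This is a minimal model of the shape arithmetic only, not TensorFlow code: expanded_shape is a hypothetical helper, and it assumes the valid range for axis is [-(rank + 1), rank], the convention under which -1 denotes a new innermost axis.

```python
# Hypothetical helper (not a TensorFlow API): computes the shape produced by
# inserting a length-1 axis at `axis`, assuming axis lies in [-(rank + 1), rank].
def expanded_shape(shape, axis):
    rank = len(shape)
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    if axis < 0:
        # Negative indices count back from the end of the *output* shape,
        # so -1 appends a new innermost axis.
        axis += rank + 1
    new_shape = list(shape)
    new_shape.insert(axis, 1)
    return new_shape


print(expanded_shape([10, 10, 3], 0))   # [1, 10, 10, 3]
print(expanded_shape([10, 10, 3], 1))   # [10, 1, 10, 3]
print(expanded_shape([10, 10, 3], -1))  # [10, 10, 3, 1]
```

Normalizing the negative index against rank + 1 (rather than rank) is what lets every output position, including the one after the last existing axis, be addressed from both ends.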

This operation is useful to:

  • Add an outer "batch" dimension to a single element.
  • Align axes for broadcasting.
  • Add an inner vector-length axis to a tensor of scalars.
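The three use cases above can be illustrated on shapes alone. In this sketch both helpers are hypothetical (not TensorFlow APIs); broadcast_shape models the NumPy-style broadcasting rule (align trailing axes; sizes must match or one must be 1), and the example concerns shapes, not any particular sparse op.

```python
# Hypothetical shape-only helpers; they model semantics, not TensorFlow code.
def expanded_shape(shape, axis):
    """Insert a length-1 axis; a negative axis counts from the end of the result."""
    rank = len(shape)
    if axis < 0:
        axis += rank + 1
    new_shape = list(shape)
    new_shape.insert(axis, 1)
    return new_shape


def broadcast_shape(a, b):
    """NumPy-style broadcasting: align trailing axes; sizes must match or be 1."""
    rank = max(len(a), len(b))
    # Pad the shorter shape with leading 1s so both have the same rank.
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    result = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible shapes: {a} vs {b}")
        # A size-1 axis stretches to the other size.
        result.append(y if x == 1 else x)
    return result


# 1. Outer "batch" axis for a single [height, width, depth] element.
print(expanded_shape([10, 10, 3], 0))                    # [1, 10, 10, 3]

# 2. Align axes for broadcasting: one value per row of a [4, 3] matrix.
try:
    broadcast_shape([4], [4, 3])  # trailing sizes 4 and 3 clash
except ValueError as err:
    print("error:", err)
print(broadcast_shape(expanded_shape([4], 1), [4, 3]))  # [4, 3]

# 3. Inner vector-length axis for a tensor of 5 scalars.
print(expanded_shape([5], -1))                           # [5, 1]
```

Case 2 shows why the position of the new axis matters: a shape of [4] lines up with the trailing axis of [4, 3] and fails, while [4, 1] lines up with the leading axis and stretches across the columns.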

For example:

If you have a sparse tensor with shape [height, width, depth]:

import tensorflow as tf
sp = tf.sparse.SparseTensor(indices=[[3,4,1]], values=[7,], dense_shape=[10,10,3])

You can add an outer batch axis by passing axis=0:

tf.sparse.expand_dims(sp, axis=0).shape.as_list()
[1, 10, 10, 3]
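To see what happens to the stored entries, the operation can be modeled on a coordinate-list (COO) representation: a list of indices, a parallel list of values, and a dense shape. This is a hypothetical pure-Python model, not the TensorFlow implementation; sparse_expand_dims is an illustrative name, and the model only captures the bookkeeping the result must satisfy (every entry sits at coordinate 0 of the new length-1 axis).

```python
# Hypothetical pure-Python model of expand_dims on a COO sparse tensor.
# It is NOT TensorFlow code; it shows how indices and shape must change.
def sparse_expand_dims(indices, values, dense_shape, axis):
    rank = len(dense_shape)
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    if axis < 0:
        axis += rank + 1
    # A length-1 axis has only coordinate 0, so each index gains a 0 there.
    new_indices = [list(idx[:axis]) + [0] + list(idx[axis:]) for idx in indices]
    new_shape = list(dense_shape[:axis]) + [1] + list(dense_shape[axis:])
    # Values are unchanged; only their coordinates gain one component.
    return new_indices, list(values), new_shape


indices, values, shape = sparse_expand_dims([[3, 4, 1]], [7], [10, 10, 3], axis=0)
print(indices)  # [[0, 3, 4, 1]]
print(values)   # [7]
print(shape)    # [1, 10, 10, 3]
```

With TensorFlow installed, the same change should be visible on the indices and dense_shape attributes of the SparseTensor returned by tf.sparse.expand_dims.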

The new axis location matches Python list.insert(axis, 1):

tf.sparse.expand_dims(sp, axis=1).shape.as_list()
[10, 1, 10, 3]
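The list.insert correspondence can be checked directly in plain Python. The sketch below confirms it for every non-negative axis of a rank-3 shape, and also shows where it stops holding: assuming the [-(rank + 1), rank] convention, a negative axis must first be shifted by rank + 1, because list.insert(-1, 1) places the 1 before the last element rather than after it.

```python
shape = [10, 10, 3]
rank = len(shape)

# For non-negative axes, the expanded shape is exactly list.insert(axis, 1).
results = {}
for axis in range(rank + 1):
    via_insert = list(shape)
    via_insert.insert(axis, 1)
    results[axis] = via_insert
print(results[0])  # [1, 10, 10, 3]
print(results[1])  # [10, 1, 10, 3]

# Negative axes differ: list.insert(-1, 1) puts the 1 *before* the last item.
naive = list(shape)
naive.insert(-1, 1)
print(naive)  # [10, 10, 1, 3]

# Shifting by rank + 1 maps -1 onto the slot after the last axis.
fixed = list(shape)
fixed.insert(-1 + rank + 1, 1)
print(fixed)  # [10, 10, 3, 1]
```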

Following standard Python indexing rules, a negat