Transforms elems by applying fn to each element unstacked on axis 0. (deprecated arguments)

See also tf.scan.

map_fn unstacks elems on axis 0 to obtain a sequence of elements, calls fn to transform each element, and then stacks the transformed values back together.
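To make the unstack, apply, stack pipeline concrete, here is a conceptual sketch in plain Python that uses nested lists instead of tensors. It is not TensorFlow code. The function name map_fn_sketch is hypothetical. Iterating over the outer list stands in for unstacking along axis 0, and collecting the results into a new list stands in for stacking.

```python
def map_fn_sketch(fn, elems):
    """Hypothetical, list-based model of map_fn's semantics (not TensorFlow).

    Iterating over `elems` plays the role of unstacking along axis 0;
    building a new list from the results plays the role of stacking.
    """
    unstacked = list(elems)                    # "unstack" along axis 0
    transformed = [fn(e) for e in unstacked]   # apply fn to each element
    return transformed                         # "stack" results along a new axis 0


result = map_fn_sketch(lambda t: list(range(t, t + 3)), [3, 5, 2])
print(result)  # [[3, 4, 5], [5, 6, 7], [2, 3, 4]]
```

The real map_fn also handles dtypes, nested structures, and graph execution. This sketch shows only the element-wise data flow.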

Mapping functions with single-Tensor inputs and outputs

If elems is a single tensor and fn's signature is tf.Tensor->tf.Tensor, then map_fn(fn, elems) is equivalent to tf.stack([fn(elem) for elem in tf.unstack(elems)]). E.g.:

>>> import tensorflow as tf
>>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
  array([[3, 4, 5],
         [5, 6, 7],
         [2, 3, 4]], dtype=int32)>

map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape.
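The shape rule can be checked on the example above with a small plain-Python sketch that models tensors as rectangular nested lists. The helper shape_of is hypothetical and is not a TensorFlow API. It returns [] for a scalar, just as a scalar tensor has an empty shape.

```python
def shape_of(x):
    """Hypothetical helper: shape of a rectangular nested list ([] for a scalar)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return shape


elems = [3, 5, 2]
fn = lambda t: list(range(t, t + 3))
result = [fn(e) for e in elems]  # list-based stand-in for map_fn(fn, elems)

# [elems.shape[0]] + fn(elems[0]).shape  ==  [3] + [3]  ==  [3, 3]
assert shape_of(result) == [shape_of(elems)[0]] + shape_of(fn(elems[0]))
print(shape_of(result))  # [3, 3]
```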

Mapping functions with multi-arity inputs and outputs

map_fn also supports functions with multi-arity inputs and outputs:

  • If