tf.map_fn

Transforms elems by applying fn to each element unstacked on axis 0. (deprecated arguments)

See also tf.scan.

map_fn unstacks elems on axis 0 to obtain a sequence of elements; calls fn to transform each element; and then stacks the transformed values back together.

Mapping functions with single-Tensor inputs and outputs

If elems is a single tensor and fn's signature is tf.Tensor->tf.Tensor, then map_fn(fn, elems) is equivalent to tf.stack([fn(elem) for elem in tf.unstack(elems)]). For example:

tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
  array([[3, 4, 5],
         [5, 6, 7],
         [2, 3, 4]], dtype=int32)>

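The result's leading dimension matches that of elems, and its remaining dimensions are those of a single transformed element: map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape.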
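The equivalence and the shape rule above can be checked directly. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the helper name fn and the variables mapped, stacked, and expected_shape are illustrative and not part of the API.

```python
import tensorflow as tf

elems = tf.constant([3, 5, 2])

def fn(t):
    # Illustrative mapping function: scalar int32 -> int32 vector of length 3.
    return tf.range(t, t + 3)

# Result of map_fn over axis 0 of elems.
mapped = tf.map_fn(fn=fn, elems=elems)

# The unstack / apply / stack formulation described above.
stacked = tf.stack([fn(elem) for elem in tf.unstack(elems)])

# Shape rule: [elems.shape[0]] + fn(elems[0]).shape
expected_shape = [elems.shape[0]] + fn(elems[0]).shape.as_list()

print(mapped.shape.as_list() == expected_shape)
print(bool(tf.reduce_all(mapped == stacked)))
```

Both checks should print True for this input. The comparison only shows that the two formulations agree on values and shape here; it says nothing about how either is executed internally.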

Mapping functions with multi-arity inputs and outputs

map_fn also supports functions with multi-arity inputs and outputs:

  • If