|View source on GitHub|
Parallel map on the list of tensors unpacked from
elems on dimension 0.
tf.vectorized_map( fn, elems )
This method works similar to tf.map_fn but is optimized to run much faster,
possibly with a much larger memory footprint. The speedups are obtained by
vectorization (see https://arxiv.org/pdf/1903.04243.pdf). The idea behind
vectorization is to semantically launch all the invocations of
parallel and fuse corresponding operations across all these invocations. This
fusion is done statically at graph generation time and the generated code is
often similar in performance to a manually fused version.
tf.vectorized_map fully parallelizes the batch, this method will
generally be significantly faster than using
tf.map_fn, especially in eager
mode. However this is an experimental feature and currently has a lot of
- There should be no data dependency between the different semantic
fn, i.e. it should be safe to map the elements of the
inputs in any order.
- Stateful kernels may mostly not be supported since these often imply a
data dependency. We do support a limited set of such stateful kernels
though (like RandomFoo, Variable operations like reads, etc).
fn has limited support for control flow operations.
particular is not supported.
fn should return nested structure of Tensors or Operations. However
if an Operation is returned, it should have zero outputs.
- The shape and dtype of any intermediate or output tensors in the
fn should not depend on the input to
fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as
elems, and returns a possibly nested structure of Tensors and Operations, which may be different than the structure of
elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along their first dimension. The nested sequence of the resulting slices will be mapped over by
A tensor or (possibly nested) sequence of tensors. Each tensor packs the results of applying fn to tensors unpacked from elems along the first dimension, from first to last.
def outer_product(a): return tf.tensordot(a, a, 0) batch_size = 100 a = tf.ones((batch_size, 32, 32)) c = tf.vectorized_map(outer_product, a) assert c.shape == (batch_size, 32, 32, 32, 32)
# Computing per-example gradients batch_size = 10 num_features = 32 layer = tf.keras.layers.Dense(1) def model_fn(arg): with tf.GradientTape() as g: inp, label = arg inp = tf.expand_dims(inp, 0) label = tf.expand_dims(label, 0) prediction = layer(inp) loss = tf.nn.l2_loss(label - prediction) return g.gradient(loss, (layer.kernel, layer.bias)) inputs = tf.random_uniform([batch_size, num_features]) labels = tf.random_uniform([batch_size, 1]) per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels)) assert per_example_gradients.shape == (batch_size, num_features, 1) assert per_example_gradients.shape == (batch_size, 1)