tf.compat.v1.tpu.replicate

Builds a graph operator that runs a replicated TPU computation.

Basic usage example, where the inputs have static shapes:


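import tensorflow as tf

def computation(x):
  x = x + 1
  return tf.math.reduce_mean(x)

x = tf.convert_to_tensor([1., 2., 3.])
y = tf.convert_to_tensor([4., 5., 6.])
# Two replicas, each given a single input: replica 0 gets `x`, replica 1 gets `y`.
tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])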
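
Executing the replicated graph requires TPU hardware, but the [replica_num][input_num] indexing of `inputs` can be illustrated on a CPU. The sketch below is a stand-in, not TensorFlow code: `simulate_replicate` is a hypothetical helper that calls the computation once per replica, plain Python lists replace tensors, and it assumes the result is a list with one entry per replica.

```python
# CPU-only sketch: plain Python stands in for TensorFlow ops, no TPU involved.
from typing import Callable, List, Optional, Sequence


def computation(x: Sequence[float]) -> float:
    # Same math as the example above: add 1 elementwise, then take the mean.
    shifted = [v + 1 for v in x]
    return sum(shifted) / len(shifted)


def simulate_replicate(fn: Callable[..., float],
                       inputs: Optional[List[List[Sequence[float]]]] = None
                       ) -> List[float]:
    # Hypothetical helper, NOT a TensorFlow API.
    # inputs[replica_num][input_num]; None behaves like [[]] (one replica, no inputs).
    if inputs is None:
        inputs = [[]]
    if len({len(replica_inputs) for replica_inputs in inputs}) > 1:
        raise ValueError("All replicas must have the same number of inputs.")
    # One call per replica; the result is indexed by replica_num.
    return [fn(*replica_inputs) for replica_inputs in inputs]


x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]
outputs = simulate_replicate(computation, inputs=[[x], [y]])
print(outputs)  # [3.0, 6.0]
```

Replica 0 computes mean([2, 3, 4]) = 3.0 and replica 1 computes mean([5, 6, 7]) = 6.0, which is why each replica's inputs are wrapped in their own inner list.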

If the inputs have dynamic shapes and you would like to automatically bucketize them to avoid XLA recompilation, see the advanced example below:


def computation(x):
  x = x + 1
  return tf.math.reduce_mean(x)

# Assume the input tensors `x` and `y` for the two replicas both have
# dynamic shape [None, 2].
tf.compat.v1.tpu.replicate(
  computation,
  inputs=[[x], [y]],
  maximum_shapes=[tf.TensorShape([None, None])],  # one shape per input
  padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)  # bucketize dynamic dims
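
To build intuition for `POWER_OF_TWO` bucketing, here is a pure-Python sketch of how a padded shape might be chosen. It assumes that each unknown (`None`) dimension in `maximum_shapes` is padded to the largest size seen across replicas and then rounded up to the next power of two, while known dimensions are padded to the given size. `next_power_of_two` and `bucketed_shape` are hypothetical helpers, not TensorFlow APIs, and the library's actual padding logic may differ in detail.

```python
# Pure-Python sketch of power-of-two bucketing (not TensorFlow's implementation).
from typing import List, Optional, Sequence, Tuple


def next_power_of_two(n: int) -> int:
    # Smallest power of two that is >= n; sizes of 0 or 1 map to 1.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def bucketed_shape(replica_shapes: Sequence[Sequence[int]],
                   maximum_shape: Sequence[Optional[int]]) -> Tuple[int, ...]:
    # Hypothetical helper. Assumption: a None dimension is padded to the largest
    # size across replicas, rounded up to a power of two; a known dimension is
    # padded to exactly that size.
    padded: List[int] = []
    for dim, max_dim in enumerate(maximum_shape):
        largest = max(shape[dim] for shape in replica_shapes)
        if max_dim is None:
            padded.append(next_power_of_two(largest))
        elif largest > max_dim:
            # This sketch rejects inputs that exceed a known maximum.
            raise ValueError(f"dimension {dim}: size {largest} exceeds {max_dim}")
        else:
            padded.append(max_dim)
    return tuple(padded)


# For example, two replicas whose inputs have shapes [3, 2] and [5, 2].
print(bucketed_shape([(3, 2), (5, 2)], (None, None)))  # (8, 2)

# Row counts 5 through 8 all land in the same bucket.
print({bucketed_shape([(n, 2)], (None, 2)) for n in range(5, 9)})  # {(8, 2)}
```

Under these assumptions, row counts from 5 through 8 share one padded shape, so XLA would compile one program for them rather than four; the cost is wasted computation on the padding.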

computation: A Python function that builds the computation to replicate.
inputs: A list of lists of input tensors, or None (equivalent to [[]]), indexed by [replica_num][input_num]. All replicas must have the same number of inputs. Each input can be a nested structure containing v