Module: tf.contrib.tpu

Defined in tensorflow/contrib/tpu/

Ops related to Tensor Processing Units.

Modules

profiler module: Classes for TPU trace events.

Classes

class CrossShardOptimizer: An optimizer that averages gradients across TPU shards.

class DeviceAssignment: Mapping from logical cores in a computation to the physical TPU topology.

class InfeedQueue: A helper object to build a device infeed queue.

class InputPipelineConfig: Input pipeline configuration values used by TPUConfig; see TPUConfig for their definitions.

class RunConfig: RunConfig with TPU support.

class TPUConfig: TPU related configuration required by TPUEstimator.

class TPUEstimator: Estimator with TPU support.

class TPUEstimatorSpec: Ops and objects returned from a model_fn and passed to TPUEstimator.

class Topology: Describes a set of TPU devices.
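Conceptually, CrossShardOptimizer wraps another optimizer so that every shard applies the same update, computed from the average of all shards' gradients. The pure-Python sketch below simulates only that arithmetic for plain SGD on per-shard gradient vectors. The helpers `average_gradients` and `sgd_step` are hypothetical illustrations, not part of `tf.contrib.tpu`; the real class performs the reduction on the TPU with a cross-replica sum, and this sketch does not call TensorFlow.

```python
from typing import List

Vector = List[float]


def average_gradients(shard_grads: List[Vector]) -> Vector:
    """Average per-shard gradients element-wise (hypothetical helper).

    Simulates the reduction CrossShardOptimizer performs across TPU shards:
    sum every shard's gradient, then divide by the number of shards.
    """
    num_shards = len(shard_grads)
    if num_shards == 0:
        raise ValueError("need at least one shard")
    dim = len(shard_grads[0])
    if any(len(g) != dim for g in shard_grads):
        raise ValueError("all shard gradients must have the same length")
    summed = [sum(g[i] for g in shard_grads) for i in range(dim)]
    return [s / num_shards for s in summed]


def sgd_step(weights: Vector, grad: Vector, lr: float) -> Vector:
    """Plain SGD update, applied identically on every shard (hypothetical helper)."""
    return [w - lr * g for w, g in zip(weights, grad)]


if __name__ == "__main__":
    # Two shards, each of which computed a gradient on its own slice of the batch.
    shard_grads = [[1.0, 2.0], [3.0, 4.0]]
    avg = average_gradients(shard_grads)
    print(avg)  # [2.0, 3.0]
    # Every shard starts from the same weights and applies the same averaged update.
    print(sgd_step([0.5, 0.5], avg, lr=0.1))
```

Because every shard receives the identical averaged gradient, the replicated weights stay in sync without any further communication.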

Functions

batch_parallel(...): Shards computation along the batch dimension for parallel execution.

bfloat16_scope(...): Scope for bfloat16 variables that makes the model use a custom variable getter.

core(...): Returns the device name for a core in a replicated TPU computation.

cross_replica_sum(...): An Op to sum inputs across replicated TPU instances. Each instance supplies its own input, and the output of each is the sum of all the inputs.

device_assignment(...): Computes a device_assignment of a computation across a TPU topology.

infeed_dequeue(...): A placeholder op for a value that will be fed into the computation.

infeed_dequeue_tuple(...): A placeholder op for values fed into the TPU simultaneously as a tuple.

initialize_system(...): Initializes a distributed TPU system for use with TensorFlow.

keras_to_tpu_model(...): Copies a model along with its weights to the TPU and returns a TPU model. (experimental)

outfeed_enqueue(...): An op which emits a single Tensor value from an XLA computation.

outfeed_enqueue_tuple(...): An op which emits multiple Tensor values from an XLA computation.

repeat(...): Builds a training loop that executes a fixed number of iterations.

replicate(...): Builds a graph operator that runs a replicated TPU computation.

rewrite(...): Rewrites computation for execution on a TPU system.

shard(...): Shards computation for parallel execution.

shutdown_system(...): Shuts down a running distributed TPU system.

while_loop(...): Builds a training loop for TPUs.
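To make the batch_parallel and cross_replica_sum descriptions concrete, the sketch below simulates their semantics on the host in pure Python: the batch is split into equal contiguous shards along the batch dimension, each simulated replica computes a partial result on its own shard, and a cross-replica sum gives every replica the total. The helpers `split_batch` and `simulated_cross_replica_sum` are hypothetical; the real functions build TensorFlow graph operations that run on TPU cores, and this sketch assumes the batch size is divisible by the number of shards.

```python
from typing import List, Sequence


def split_batch(batch: Sequence[float], num_shards: int) -> List[List[float]]:
    """Split a batch into equal contiguous shards along the batch dimension.

    Hypothetical helper; assumes len(batch) is a multiple of num_shards.
    """
    if num_shards <= 0 or len(batch) % num_shards != 0:
        raise ValueError("batch size must be a multiple of a positive num_shards")
    size = len(batch) // num_shards
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_shards)]


def simulated_cross_replica_sum(per_replica: List[float]) -> List[float]:
    """Give every replica the sum of all replicas' inputs (hypothetical helper)."""
    total = sum(per_replica)
    return [total] * len(per_replica)


if __name__ == "__main__":
    batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    shards = split_batch(batch, num_shards=4)
    # Each simulated replica computes a partial sum of squares on its own shard.
    partials = [sum(x * x for x in shard) for shard in shards]
    # Every replica ends up holding the global sum of squares.
    print(simulated_cross_replica_sum(partials))  # [204.0, 204.0, 204.0, 204.0]
```

The per-shard work needs no communication; only the final reduction crosses replicas, which is why a sum over replicas is the natural way to combine shard results.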