An optimizer that averages gradients across TPU shards.

Inherits From: Optimizer

Args:
  opt: An existing Optimizer to encapsulate.
  reduction: The reduction to apply to the shard losses.
  name: Optional name prefix for the operations created when applying
    gradients. Defaults to "CrossShardOptimizer".
  group_assignment: Optional 2d int32 lists with shape
    [num_groups, num_replicas_per_group] which describes how to apply the
    optimizer to subgroups.

Raises:
  ValueError: If reduction is not a valid cross-shard reduction.
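To make the group_assignment argument concrete, here is a minimal pure-Python sketch of grouped cross-replica summation. This is only an illustration of the semantics; the real implementation runs on TPU via tpu_ops.cross_replica_sum, and the function name below is hypothetical.

```python
def grouped_cross_replica_sum(per_replica_grads, group_assignment):
    """Sum gradients within each group of replicas.

    per_replica_grads: one scalar gradient per replica.
    group_assignment: list of groups, each a list of replica indices,
      matching the [num_groups, num_replicas_per_group] shape above.
    Every replica in a group receives the same summed value.
    """
    summed = list(per_replica_grads)
    for group in group_assignment:
        total = sum(per_replica_grads[r] for r in group)
        for r in group:
            summed[r] = total
    return summed

# Example: 4 replicas split into two groups of two. Gradients are
# summed only within each group, not globally.
grads = [1.0, 2.0, 3.0, 4.0]
groups = [[0, 1], [2, 3]]
print(grouped_cross_replica_sum(grads, groups))  # [3.0, 3.0, 7.0, 7.0]
```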



apply_gradients

Apply gradients to variables.

Calls tpu_ops.cross_replica_sum() to sum gradient contributions across replicas, and then applies the real optimizer.
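The two-step flow (reduce across replicas, then apply the wrapped optimizer) can be sketched in plain Python. This is a toy illustration under assumptions: scalar gradients, a vanilla SGD step standing in for the wrapped optimizer, and hypothetical function names; the real version calls tpu_ops.cross_replica_sum on TPU.

```python
def cross_replica_sum(per_replica_values):
    """Toy stand-in for tpu_ops.cross_replica_sum: every replica
    ends up with the same summed value."""
    total = sum(per_replica_values)
    return [total] * len(per_replica_values)

def apply_gradients_across_shards(per_replica_grads, variables, lr=0.5):
    """Sum each gradient across replicas, then hand the reduced
    (gradient, variable) pairs to the 'real' optimizer step.

    per_replica_grads: one list of gradients per replica, aligned
      with `variables`.
    """
    reduced = []
    for i in range(len(variables)):
        # Sum the i-th gradient contribution across all replicas.
        summed = cross_replica_sum([g[i] for g in per_replica_grads])[0]
        reduced.append(summed)
    # Stand-in for the wrapped optimizer: a plain SGD update.
    return [v - lr * g for v, g in zip(variables, reduced)]

# Two replicas, two variables. Summed grads are 2.0 and 4.0.
updated = apply_gradients_across_shards(
    per_replica_grads=[[0.5, 1.0], [1.5, 3.0]],
    variables=[10.0, 20.0])
print(updated)  # [9.0, 18.0]
```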

Args:
  grads_and_vars: List of (gradient, variable) pairs as returned by
    compute_gradients().
  global_step: Optional Variable to increment by one after the variables
    have been updated.
  name: Optional name for the returned operation. Defaults to the name
    passed to the Optimizer constructor.

Returns:
  An Operation that applies the gradients. If global_step was not None,
  that operation also increments global_step.

Raises:
  ValueError: If grads_and_vars is malformed.


compute_gradients

Compute gradients of "loss" for the variables in "var_list".

This simply wraps compute_gradients() from the real optimizer.