tf.contrib.distribute.TowerContext

Class TowerContext

Defined in tensorflow/python/training/distribute.py.

Provides the DistributionStrategy API inside a call_for_each_tower() call.

__init__

__init__(
    distribution_strategy,
    tower_id
)

Properties

device

The device this tower is to be executed on, as a string.

distribution_strategy

The current DistributionStrategy object.

is_single_tower

True if there is a single tower, False if there are multiple towers.

num_towers

The number of towers, for purposes of averaging across towers.

tower_id

Which tower is being defined, a number from 0 to num_towers - 1.
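As a pure-Python sketch of how these properties relate to one another (the ToyTowerContext class below is hypothetical, used only for illustration; it is not the real TensorFlow class):

```python
class ToyTowerContext:
    """Hypothetical stand-in illustrating the property semantics."""

    def __init__(self, tower_id, num_towers, device):
        self.tower_id = tower_id        # a number from 0 to num_towers - 1
        self.num_towers = num_towers    # used when averaging across towers
        self.device = device            # e.g. "/device:GPU:0", as a string

    @property
    def is_single_tower(self):
        # Single-tower execution is just the num_towers == 1 case.
        return self.num_towers == 1


contexts = [ToyTowerContext(i, 2, "/device:GPU:%d" % i) for i in range(2)]
print([ctx.tower_id for ctx in contexts])                      # [0, 1]
print(any(ctx.is_single_tower for ctx in contexts))            # False
print(ToyTowerContext(0, 1, "/device:CPU:0").is_single_tower)  # True
```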

Methods

__enter__

__enter__()

__exit__

__exit__(
    exception_type,
    exception_value,
    traceback
)

merge_call

merge_call(
    merge_fn,
    *args,
    **kwargs
)

Merge args across towers and run merge_fn in a cross-tower context.

This allows communication and coordination when there are multiple calls to a model function triggered by a call to distribution.call_for_each_tower(model_fn, ...).

See MirroredDistribution.call_for_each_tower() for an explanation.

Otherwise, this is equivalent to the following pseudocode:

distribution = get_distribution_strategy()
with cross-tower-context(distribution):
  return merge_fn(distribution, *args, **kwargs)

Args:

  • merge_fn: function that joins arguments from threads that are given as PerDevice values. It accepts a DistributionStrategy object as its first argument.
  • *args: positional per-thread arguments for merge_fn.
  • **kwargs: keyword per-thread arguments for merge_fn.

Returns:

The return value of merge_fn, with any PerDevice values unpacked.
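The coordination pattern behind merge_call() can be sketched in pure Python: each tower thread deposits its per-tower arguments, one thread runs merge_fn exactly once over the gathered values (the "cross-tower context"), and every tower then receives the merged result. All names below (ToyCoordinator, ToyTowerContext, call_for_each_tower) are hypothetical illustrations, not the real TensorFlow implementation:

```python
import threading


class ToyCoordinator:
    """Gathers per-tower arguments, runs merge_fn once, broadcasts the result."""

    def __init__(self, num_towers):
        self.num_towers = num_towers
        self._args = [None] * num_towers
        self._result = None
        self._merged = False
        self._barrier = threading.Barrier(num_towers)
        self._lock = threading.Lock()

    def merge(self, tower_id, merge_fn, args):
        self._args[tower_id] = args
        self._barrier.wait()              # wait until every tower has arrived
        with self._lock:
            if not self._merged:          # run merge_fn exactly once,
                self._merged = True       # in "cross-tower context"
                # Gather the i-th argument of every tower into one list,
                # playing the role of a PerDevice value.
                gathered = [list(column) for column in zip(*self._args)]
                self._result = merge_fn(self, *gathered)
        self._barrier.wait()              # all towers see the merged value
        return self._result


class ToyTowerContext:
    def __init__(self, coordinator, tower_id):
        self._coordinator = coordinator
        self.tower_id = tower_id

    def merge_call(self, merge_fn, *args):
        return self._coordinator.merge(self.tower_id, merge_fn, args)


def call_for_each_tower(num_towers, fn):
    """Runs fn once per tower, each in its own thread."""
    coordinator = ToyCoordinator(num_towers)
    results = [None] * num_towers

    def run(i):
        results[i] = fn(ToyTowerContext(coordinator, i))

    threads = [threading.Thread(target=run, args=(i,)) for i in range(num_towers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


def model_fn(ctx):
    # Each tower computes a local value, then asks merge_call to average
    # the values contributed by all towers.
    local_loss = float(ctx.tower_id + 1)  # towers produce 1.0, 2.0, 3.0, 4.0
    return ctx.merge_call(lambda strategy, losses: sum(losses) / len(losses),
                          local_loss)


print(call_for_each_tower(4, model_fn))   # [2.5, 2.5, 2.5, 2.5]
```

Every tower blocks at merge_call until all towers have reached it, which is why the real API requires the model function to make the same sequence of merge_call invocations on every tower.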