tf_agents.utils.common.aggregate_losses

Aggregates and scales per-example losses and regularization losses.

tf_agents.utils.common.aggregate_losses(
    per_example_loss=None, sample_weight=None, global_batch_size=None,
    regularization_loss=None
)

If global_batch_size is given, it is used for scaling. Otherwise, the scaling factor is the batch dimension of per_example_loss multiplied by the number of replicas.
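The scaling rule above can be illustrated with a small, dependency-free sketch. This is not the TF-Agents implementation; the helper name `sketch_weighted_loss` and its `num_replicas` parameter are hypothetical, and the sketch assumes the aggregated per-example loss is the weighted sum of the losses divided by the global batch size. Regularization losses are left out.

```python
def sketch_weighted_loss(per_example_loss, sample_weight=None,
                         global_batch_size=None, num_replicas=1):
    """Illustrative arithmetic only (hypothetical helper, not the library code).

    per_example_loss: list of floats, one per example, shape [B].
    sample_weight: optional list of floats, shape [B].
    global_batch_size: optional explicit divisor.
    num_replicas: assumed replica count, used only for the default divisor.
    """
    batch_dim = len(per_example_loss)
    if sample_weight is None:
        # No weighting: every example counts equally.
        sample_weight = [1.0] * batch_dim
    if global_batch_size is None:
        # Default stated in the docs: batch dimension * number of replicas.
        global_batch_size = batch_dim * num_replicas
    weighted_sum = sum(l * w for l, w in zip(per_example_loss, sample_weight))
    return weighted_sum / global_batch_size


losses = [1.0, 2.0, 3.0, 4.0]
single_replica = sketch_weighted_loss(losses)                      # divisor 4
two_replicas = sketch_weighted_loss(losses, num_replicas=2)        # divisor 8
explicit_size = sketch_weighted_loss(losses, global_batch_size=16) # divisor 16
masked = sketch_weighted_loss(losses, sample_weight=[1.0, 0.0, 1.0, 0.0])
```

Dividing by the global batch size rather than the local one means that summing the per-replica results gives the mean over the whole global batch.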

Args:

  • per_example_loss: Per-example loss [B].
  • sample_weight: Optional weighting for each example [B].
  • global_batch_size: Optional global batch size value. Defaults to (size of first dimension of per_example_loss) * (number of replicas).
  • regularization_loss: Regularization loss.

Returns:

An AggregatedLosses named tuple with scalar losses to optimize.