tf_agents.utils.common.aggregate_losses

Aggregates and scales per-example loss and regularization losses.

If global_batch_size is given, it is used for scaling; otherwise the scale is derived from the batch dimension of per_example_loss and the number of replicas.

Args:
  per_example_loss: Per-example loss, shaped [B] or [B, T, ...].
  sample_weight: Optional weighting for each example; a Tensor shaped [B] or [B, T, ...], or a scalar float.
  global_batch_size: Optional global batch size. Defaults to (size of the first dimension of per_example_loss) * (number of replicas).
  regularization_loss: Regularization loss.

Returns:
  An AggregatedLosses named tuple with scalar losses to optimize.
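
The scaling rule described above can be illustrated with a small pure-Python sketch. This is not the TF-Agents implementation: `sketch_scaled_loss` and its `num_replicas` argument are hypothetical names introduced here, the sketch handles only losses shaped [B], and it assumes the weighted per-example losses are summed and then divided by the global batch size.

```python
from typing import Optional, Sequence, Union


def sketch_scaled_loss(
    per_example_loss: Sequence[float],
    sample_weight: Optional[Union[float, Sequence[float]]] = None,
    global_batch_size: Optional[int] = None,
    num_replicas: int = 1,
) -> float:
    """Hypothetical illustration of the scaling arithmetic for [B] losses."""
    batch = len(per_example_loss)

    # Broadcast a missing or scalar weight to one weight per example.
    if sample_weight is None:
        weights = [1.0] * batch
    elif isinstance(sample_weight, (int, float)):
        weights = [float(sample_weight)] * batch
    else:
        weights = list(sample_weight)
    if len(weights) != batch:
        raise ValueError("sample_weight must be a scalar or have length B")

    # Assumed default: (size of first dimension) * (number of replicas).
    if global_batch_size is None:
        global_batch_size = batch * num_replicas

    # Sum the weighted per-example losses and scale by the global batch size.
    weighted_sum = sum(l * w for l, w in zip(per_example_loss, weights))
    return weighted_sum / global_batch_size


losses = [1.0, 2.0, 3.0, 4.0]
print(sketch_scaled_loss(losses, num_replicas=2))        # 10 / (4 * 2) = 1.25
print(sketch_scaled_loss(losses, global_batch_size=4))   # 10 / 4 = 2.5
```

Dividing by the global batch size rather than the local one means that summing the per-replica results gives the mean loss over the whole global batch, which is presumably why the default accounts for the number of replicas.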