tf_agents.utils.common.aggregate_losses

Aggregates and scales the per-example loss and regularization losses.

If global_batch_size is given, it is used for scaling. Otherwise, the scaling factor is the batch dimension of per_example_loss multiplied by the number of replicas.
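
The scaling rule can be illustrated with a minimal plain-Python sketch. It assumes the weighted term follows the semantics of tf.nn.compute_average_loss: weighted per-example losses are summed and then divided by the global batch size, not by the sum of the weights. The function scaled_weighted_loss and its num_replicas parameter are hypothetical. In TensorFlow the replica count would come from the active tf.distribute strategy rather than from an argument.

```python
from typing import Optional, Sequence


def scaled_weighted_loss(
    per_example_loss: Sequence[float],
    sample_weight: Optional[Sequence[float]] = None,
    global_batch_size: Optional[int] = None,
    num_replicas: int = 1,  # hypothetical stand-in for the strategy's replica count
) -> float:
    """Sums weighted per-example losses and divides by the global batch size."""
    if sample_weight is None:
        # No weights: every example counts once.
        sample_weight = [1.0] * len(per_example_loss)
    if len(sample_weight) != len(per_example_loss):
        raise ValueError("sample_weight must have one entry per example")
    # Fallback described above: (local batch dimension) * (number of replicas).
    if global_batch_size is None:
        global_batch_size = len(per_example_loss) * num_replicas
    weighted_sum = sum(l * w for l, w in zip(per_example_loss, sample_weight))
    # Divide by the global batch size, not by the sum of the weights.
    return weighted_sum / global_batch_size
```

For example, losses 1, 2, 3, 4 on a single replica give 10 / 4 = 2.5. The same local batch on each of two replicas gives 10 / 8 = 1.25 per replica. Dividing by the global rather than the local batch size means that summing the per-replica results across replicas yields the mean over the whole global batch.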

Args:
per_example_loss: Per-example loss, shape [B].
sample_weight: Optional weight for each example, shape [B].
global_batch_size: Optional global batch size value. Defaults to (size of the first dimension of per_example_loss) * (number of replicas).
regularization_loss: Regularization loss.

Returns:
An AggregatedLosses named tuple with scalar losses to optimize.
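
A plain-Python sketch of how the pieces might combine into the returned named tuple is shown below. Several details are assumptions, not taken from this page: the field names total_loss, weighted, and regularization; a total equal to the scaled weighted loss plus the regularization term; and a regularization term divided by the replica count, in the manner of tf.nn.scale_regularization_loss. The names aggregate_losses_sketch and num_replicas are hypothetical and are not part of the TF-Agents API.

```python
from typing import NamedTuple, Optional, Sequence


class AggregatedLosses(NamedTuple):
    # Field names are an assumption made for this sketch.
    total_loss: Optional[float]
    weighted: Optional[float]
    regularization: Optional[float]


def aggregate_losses_sketch(
    per_example_loss: Optional[Sequence[float]] = None,
    sample_weight: Optional[Sequence[float]] = None,
    global_batch_size: Optional[int] = None,
    regularization_loss: Optional[float] = None,
    num_replicas: int = 1,  # hypothetical stand-in for the strategy's replica count
) -> AggregatedLosses:
    weighted = None
    if per_example_loss is not None:
        weights = sample_weight if sample_weight is not None else [1.0] * len(per_example_loss)
        # Default divisor: (local batch dimension) * (number of replicas).
        batch = global_batch_size if global_batch_size is not None else len(per_example_loss) * num_replicas
        weighted = sum(l * w for l, w in zip(per_example_loss, weights)) / batch

    regularization = None
    if regularization_loss is not None:
        # Assumed: divided by the replica count, so summing across replicas
        # counts the regularization term exactly once.
        regularization = regularization_loss / num_replicas

    # The total is whichever terms are present; None if neither was given.
    parts = [x for x in (weighted, regularization) if x is not None]
    total = sum(parts) if parts else None
    return AggregatedLosses(total, weighted, regularization)
```

For example, losses 1, 2, 3, 4 with a regularization loss of 0.5 on a single replica give weighted = 2.5, regularization = 0.5, and total_loss = 3.0.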