tfp.substrates.jax.math.MinimizeTraceableQuantities

Namedtuple of quantities that may be traced from tfp.math.minimize.

These are (in order):

  • step: int Tensor index (starting from zero) of the current optimization step.
  • loss: float Tensor value returned from the user-provided loss_fn.
  • gradients: list of Tensor gradients of loss with respect to the parameters.
  • parameters: list of Tensor values of parameters being optimized. This corresponds to trainable_variables passed to minimize, or init passed to minimize_stateless.
  • has_converged: boolean Tensor of the same shape as the loss returned by loss_fn, with True values corresponding to loss entries that have converged according to the user-provided convergence criterion. If no convergence criterion was specified, this is None.
  • convergence_criterion_state: structure of Tensors containing any auxiliary state (e.g., moving averages of loss or other quantities) maintained by the user-provided convergence criterion.
  • optimizer_state: structure of Tensors containing optional state from a user-provided pure optimizer.
  • seed: PRNG seed associated with the current optimization step.

Attributes:

  • step: A namedtuple alias for field number 0
  • loss: A namedtuple alias for field number 1
  • gradients: A namedtuple alias for field number 2
  • parameters: A namedtuple alias for field number 3
  • has_converged: A namedtuple alias for field number 4
  • convergence_criterion_state: A namedtuple alias for field number 5
  • optimizer_state: A namedtuple alias for field number 6
  • seed: A namedtuple alias for field number 7
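Because each attribute is an alias for a fixed field number, named access and positional access are interchangeable. The sketch below uses a hypothetical stand-in, TraceableQuantitiesStandIn, built with collections.namedtuple using the field names and order listed above. It is not the TFP class. It only shows how the aliases line up with tuple indices:

```python
from collections import namedtuple

# Hypothetical stand-in with the same field names and order (0 through 7)
# as MinimizeTraceableQuantities; NOT the actual TFP class.
TraceableQuantitiesStandIn = namedtuple(
    'TraceableQuantitiesStandIn',
    ['step', 'loss', 'gradients', 'parameters', 'has_converged',
     'convergence_criterion_state', 'optimizer_state', 'seed'])

q = TraceableQuantitiesStandIn(
    step=3, loss=0.25, gradients=[0.1], parameters=[4.9],
    has_converged=None, convergence_criterion_state=(),
    optimizer_state=(), seed=None)

# Each attribute name is an alias for a fixed tuple index.
assert q.step == q[0]
assert q.loss == q[1]
assert q.seed is q[7]

# Positional unpacking follows the same field order.
step, loss, *rest = q
print(step, loss, len(rest))  # 3 0.25 6
```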