tfp.substrates.jax.math.minimize

Minimize a loss function using a provided optimizer.

loss_fn Python callable with signature loss = loss_fn(), where loss is a Tensor loss to be minimized. This may optionally take a seed keyword argument, used to specify a per-iteration seed for stochastic loss functions (a stateless Tensor seed will be passed; see tfp.random.sanitize_seed).
num_steps Python int maximum number of steps to run the optimizer.
optimizer Optimizer instance to use. This may be a TF1-style tf.train.Optimizer, TF2-style tf.optimizers.Optimizer, or any Python object that implements optimizer.apply_gradients(grads_and_vars).
convergence_criterion Optional instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion representing a criterion for detecting convergence. If None, the optimization runs for exactly num_steps steps; otherwise, it runs for at most num_steps steps and stops early once the provided criterion detects convergence. Default value: None.
batch_convergence_reduce_fn Python callable of signature has_converged = batch_convergence_reduce_fn(batch_has_converged) whose input is a Tensor of boolean values of the same shape as the loss returned by loss_fn, and output is a scalar boolean Tensor. This determines the behavior of batched optimization loops when loss_fn's return value is non-scalar. For example, tf.reduce_all will stop the optimization once all members of the batch have converged, tf.reduce_any once any member has converged, lambda x: tf.reduce_mean(tf.cast(x, tf.float32)) > 0.5 once more than half have converged, etc. Default value: tf.reduce_all.
trainable_variables list of tf.Variable instances to optimize with respect to. If None, defaults to the set of all variables accessed during the execution of loss_fn(). Default value: None.
trace_fn Python callable with signature traced_values = trace_fn( traceable_quantities), where the argument is an instance of tfp.math.MinimizeTraceableQuantities and the returned traced_values may be a Tensor or nested structure of Tensors. The traced values are stacked across steps and returned. The default trace_fn simply returns the loss. In general, trace functions may also examine the gradients, values of parameters, the state propagated by the specified convergence_criterion, if any (if no convergence criterion is specified, this will be None), as well as any other quantities captured in the closure of trace_fn, for example, statistics of a variational distribution. Default value: lambda traceable_quantities: traceable_quantities.loss.
return_full_length_trace Python bool indicating whether to return a trace of the full length num_steps, even if a convergence criterion stopped the optimization early, by tiling the value(s) traced at the final optimization step. This enables use in contexts such as XLA that require shapes to be known statically. Default value: True.
jit_compile If True, compiles the minimization loop using XLA. XLA performs compiler optimizations, such as fusion, and attempts to emit more efficient code, which may substantially improve performance. See the docs for tf.function. (In JAX, this applies jax.jit.) Default value: False.
seed PRNG seed for stochastic losses; see tfp.random.sanitize_seed. Default value: None.
name Python str name prefixed to ops created by this function. Default value: 'minimize'.
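
The three reducers mentioned for batch_convergence_reduce_fn can be compared on a small example. The following is a minimal sketch using NumPy analogues (np.all, np.any, np.mean) rather than the TensorFlow ops themselves; the flag values are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical per-member convergence flags for a batch of five losses.
batch_has_converged = np.array([True, True, True, False, False])

# NumPy analogues of the reducers named in the argument description.
reduce_all = np.all    # analogous to tf.reduce_all
reduce_any = np.any    # analogous to tf.reduce_any
reduce_majority = lambda x: np.mean(x.astype(np.float32)) > 0.5

stop_all = bool(reduce_all(batch_has_converged))
stop_any = bool(reduce_any(batch_has_converged))
stop_majority = bool(reduce_majority(batch_has_converged))

print(stop_all, stop_any, stop_majority)  # False True True
```

With three of five members converged, only the "all" reducer would keep the loop running.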

trace Tensor or nested structure of Tensors, according to the return type of trace_fn. Each Tensor has an added leading dimension stacking the trajectory of the traced values over the course of the optimization. The size of this dimension is equal to num_steps if a convergence criterion was not specified and/or return_full_length_trace=True, and otherwise it is equal to the number of optimization steps taken.
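
The effect of return_full_length_trace on the trace length can be illustrated in plain Python. This is only a sketch of the tiling behavior described above, not the library's implementation; pad_trace and the loss values are hypothetical.

```python
def pad_trace(trace, num_steps):
    """Tile the last traced value so the trace has length num_steps."""
    if not trace:
        raise ValueError("trace must contain at least one step")
    return trace + [trace[-1]] * (num_steps - len(trace))

# Hypothetical losses from a loop that stopped after 3 of 5 allowed steps.
early_trace = [25.0, 9.0, 1.0]

full_trace = pad_trace(early_trace, num_steps=5)
print(full_trace)  # [25.0, 9.0, 1.0, 1.0, 1.0]
```

The padded trace has a length that does not depend on when convergence occurred, which is what makes the shape statically known.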

Examples

To minimize the scalar function (x - 5)**2:

x = tf.Variable(0.)
loss_fn = lambda: (x - 5.)**2
losses = tfp.math.minimize(loss_fn,
                           num_steps=100,
                           optimizer=tf.optimizers.Adam(learning_rate=0.1))

# In TF2/eager mode, the optimization runs immediately.
print("optimized value is {} with loss {}".format(x, losses[-1]))

In graph mode (e.g., inside of tf.function wrapping), retrieving any Tensor that depends on the minimization op will trigger the optimization:

with tf.control_dependencies([losses]):
  optimized_x = tf.identity(x)  # Use a dummy op to attach the dependency.

We can attempt to automatically detect convergence and stop the optimization by passing an instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion. For example, to stop the optimization once a moving average of the per-step decrease in loss drops below 0.01:

losses = tfp.math.minimize(
  loss_fn, num_steps=1000, optimizer=tf.optimizers.Adam(learning_rate=0.1),
  convergence_criterion=(
    tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

Here num_steps=1000 defines an upper bound: the optimization will be stopped after 1000 steps even if no convergence is detected.
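
The interaction between num_steps and a convergence criterion can be sketched as a plain Python loop. This is a simplified illustration under stated assumptions, not TFP's implementation: it uses plain gradient descent instead of Adam, and it stops on the raw per-step decrease in loss rather than a moving average. The name minimize_sketch and its arguments are hypothetical.

```python
def minimize_sketch(loss_fn, grad_fn, x0, num_steps,
                    learning_rate=0.1, atol=None):
    """Gradient descent that stops early once the loss stops decreasing."""
    x = x0
    losses = []
    for _ in range(num_steps):
        loss = loss_fn(x)
        losses.append(loss)
        # Simplified criterion: stop when the decrease since the previous
        # step falls below atol (assumption: no moving average is used).
        if atol is not None and len(losses) > 1 and losses[-2] - loss < atol:
            break
        x = x - learning_rate * grad_fn(x)
    return x, losses

loss_fn = lambda x: (x - 5.) ** 2
grad_fn = lambda x: 2. * (x - 5.)

# Without a criterion, the loop always runs for num_steps steps.
_, fixed_losses = minimize_sketch(loss_fn, grad_fn, 0., num_steps=100)

# With a criterion, num_steps is only an upper bound.
x_final, early_losses = minimize_sketch(
    loss_fn, grad_fn, 0., num_steps=1000, atol=0.01)
print(len(fixed_losses), len(early_losses), x_final)
```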

In some cases, we may want to track additional context inside the optimization. We can do this by defining a custom trace_fn. Note that the trace_fn is passed the loss and gradients, as well as any auxiliary state maintained by the convergence criterion (if any), for example, moving averages of the loss or gradients, but it may also report the values of trainable parameters or other derived quantities by capturing them in its closure. For example, we can capture x and track its value over the optimization:

# `x` is the tf.Variable instance defined above.
trace_fn = lambda traceable_quantities: {
  'loss': traceable_quantities.loss, 'x': x}
trace = tfp.math.minimize(loss_fn, num_steps=100,
                          optimizer=tf.optimizers.Adam(0.1),
                          trace_fn=trace_fn)
print(trace['loss'].shape,   # => [100]
      trace['x'].shape)      # => [100]

When optimizing a batch of losses, some batch members will converge before others. The optimization will continue until the condition defined by the batch_convergence_reduce_fn becomes True. During these additional steps, converged elements will continue to be updated and may become unconverged. The convergence status of batch members can be diagnosed by tracing has_converged:

import numpy as np

batch_size = 10
x = tf.Variable([0.] * batch_size)
# Redefine the loss over the batched `x`; it now has shape [batch_size].
loss_fn = lambda: (x - 5.)**2
trace_fn = lambda traceable_quantities: {
  'loss': traceable_quantities.loss,
  'has_converged': traceable_quantities.has_converged}
trace = tfp.math.minimize(loss_fn, num_steps=100,
                          optimizer=tf.optimizers.Adam(0.1),
                          trace_fn=trace_fn,
                          convergence_criterion=(
    tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

# `trace` is a dict, so the traced flags are read with trace['has_converged'].
for i in range(batch_size):
  print('Batch element {} final state is {}converged.'
        ' It first converged at step {}.'.format(
        i, '' if trace['has_converged'][-1, i] else 'not ',
        np.argmax(trace['has_converged'][:, i])))
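
One caveat when reading the output of the loop above: np.argmax returns 0 for a column that is entirely False, so a batch member that never converged is indistinguishable from one that converged at step 0. The following minimal NumPy sketch, using a hypothetical has_converged array of shape [num_steps, batch_size], shows one way to tell the two cases apart.

```python
import numpy as np

# Hypothetical traced flags with shape [num_steps, batch_size] = [4, 3].
has_converged = np.array([
    [False, False, False],
    [True,  False, False],
    [False, False, False],   # member 0 became unconverged again
    [True,  True,  False],   # member 2 never converges
])

# Index of the first True in each column (0 if the column has no True).
first_step = np.argmax(has_converged, axis=0)

# Mask out members that never converged, marking them with -1.
ever_converged = has_converged.any(axis=0)
first_step_or_none = np.where(ever_converged, first_step, -1)

print(first_step.tolist())          # [1, 3, 0]
print(first_step_or_none.tolist())  # [1, 3, -1]
```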