Minimize a loss expressed as a pure function of its parameters.
tfp.substrates.jax.math.minimize_stateless(
loss_fn,
init,
num_steps,
optimizer,
convergence_criterion=None,
batch_convergence_reduce_fn=tf.reduce_all,
trace_fn=_trace_loss,
return_full_length_trace=True,
jit_compile=False,
seed=None,
name='minimize_stateless'
)
Args

loss_fn
  Python callable with signature loss = loss_fn(*init, seed=None). The loss
  function may optionally take a seed keyword argument, used to specify a
  per-iteration seed for stochastic loss functions (a stateless Tensor seed
  will be passed; see tfp.random.sanitize_seed). A sketch of such a loss
  appears after this argument list.

init
  Tuple of Tensor initial parameter values (or nested structures of Tensor
  values) passed to the loss function.

num_steps
  Python int maximum number of steps to run the optimizer.

optimizer
  Pure functional optimizer to use. This may be an
  optax.GradientTransformation instance (in JAX), or any similar object that
  implements the methods optimizer_state = optimizer.init(parameters) and
  updates, optimizer_state = optimizer.update(grads, optimizer_state,
  parameters). A sketch of this protocol appears after this argument list.

convergence_criterion
  Optional instance of
  tfp.optimizer.convergence_criteria.ConvergenceCriterion representing a
  criterion for detecting convergence. If None, the optimization will run for
  num_steps steps; otherwise, it will run for at most num_steps steps, as
  determined by the provided criterion.
  Default value: None.

batch_convergence_reduce_fn
  Python callable of signature
  has_converged = batch_convergence_reduce_fn(batch_has_converged), whose
  input is a Tensor of boolean values of the same shape as the loss returned
  by loss_fn, and whose output is a scalar boolean Tensor. This determines
  the behavior of batched optimization loops when loss_fn's return value is
  non-scalar. For example, tf.reduce_all will stop the optimization once all
  members of the batch have converged, tf.reduce_any once any member has
  converged, lambda x: tf.reduce_mean(tf.cast(x, tf.float32)) > 0.5 once more
  than half have converged, etc.
  Default value: tf.reduce_all.

trace_fn
  Python callable with signature
  traced_values = trace_fn(traceable_quantities), where the argument is an
  instance of tfp.math.MinimizeTraceableQuantities and the returned
  traced_values may be a Tensor or nested structure of Tensors. The traced
  values are stacked across steps and returned.
  The default trace_fn simply returns the loss. In general, trace functions
  may also examine the gradients, values of parameters, the state propagated
  by the specified convergence_criterion, if any (if no convergence criterion
  is specified, this will be None), as well as any other quantities captured
  in the closure of trace_fn, for example, statistics of a variational
  distribution.
  Default value: lambda traceable_quantities: traceable_quantities.loss.

return_full_length_trace
  Python bool indicating whether to return a trace of the full length
  num_steps, even if a convergence criterion stopped the optimization early,
  by tiling the value(s) traced at the final optimization step. This enables
  use in contexts such as XLA that require shapes to be known statically.
  Default value: True.

jit_compile
  If True, compiles the minimization loop using XLA. XLA performs compiler
  optimizations, such as fusion, and attempts to emit more efficient code.
  This may drastically improve performance. See the docs for tf.function.
  (In JAX, this will apply jax.jit.)
  Default value: False.

seed
  PRNG seed for stochastic losses; see tfp.random.sanitize_seed.
  Default value: None.

name
  Python str name prefixed to ops created by this function.
  Default value: 'minimize_stateless'.
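As a minimal sketch of the seed protocol described for loss_fn above
(assuming the JAX backend, where the stateless seed passed at each step can
typically be used like a JAX PRNG key, and intended to be paired with the
seed argument), a stochastic loss might look like:
import jax
import jax.numpy as jnp

def stochastic_loss_fn(x, seed=None):
  # Fresh noise is drawn from the per-iteration stateless seed at each step.
  noise = jax.random.normal(seed, shape=jnp.shape(x))
  return (x - 5. + 0.1 * noise)**2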
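The optimizer protocol itself is small. The following sketch of plain
gradient descent illustrates it; SimpleSGD is a hypothetical illustration,
not part of TFP or optax, and it assumes the optax convention that the
returned updates are added to the parameters:
import jax

class SimpleSGD(object):

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate

  def init(self, parameters):
    # Plain gradient descent carries no optimizer state.
    return ()

  def update(self, grads, optimizer_state, parameters):
    # Updates are added to the parameters downstream, so the scaled
    # gradients are negated here.
    updates = jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)
    return updates, optimizer_state
Any such object may be passed wherever an optax optimizer appears in the
examples below, e.g. optimizer=SimpleSGD(0.1).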
Examples
To minimize the scalar function (x - 5)**2:
import optax # Assume JAX backend.
loss_fn = lambda x: (x - 5.)**2
final_x, losses = tfp.math.minimize_stateless(
    loss_fn,
    init=0.,
    num_steps=100,
    optimizer=optax.adam(0.1))
print("optimized value is {} with loss {}".format(final_x, losses[-1]))
We can attempt to automatically detect convergence and stop the optimization
by passing an instance of
tfp.optimizer.convergence_criteria.ConvergenceCriterion. For example, to stop
the optimization once a moving average of the per-step decrease in loss drops
below 0.01:
_, losses = tfp.math.minimize_stateless(
    loss_fn,
    init=0.,
    num_steps=1000,
    optimizer=optax.adam(0.1),
    convergence_criterion=(
        tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))
Here num_steps=1000
defines an upper bound: the optimization will be
stopped after 1000 steps even if no convergence is detected.
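Because return_full_length_trace defaults to True, the trace returned above
still has the full length num_steps: if the criterion stops the loop early,
the value traced at the final step is tiled to pad the trace, which keeps
shapes static for XLA. For example:
print(losses.shape)  # => [1000], even if convergence was detected earlier.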
In some cases, we may want to track additional quantities inside the
optimization. We can do this by defining a custom trace_fn. This accepts a
tfp.math.MinimizeTraceableQuantities tuple and returns a structure of values
to trace; these may include the loss, gradients, parameter values, or any
auxiliary state maintained by the convergence criterion (if any).
trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss,
    'x': traceable_quantities.parameters}
_, trace = tfp.math.minimize_stateless(
    loss_fn,
    init=0.,
    num_steps=100,
    optimizer=optax.adam(0.1),
    trace_fn=trace_fn)
print(trace['loss'].shape,  # => [100]
      trace['x'].shape)     # => [100]
When optimizing a batch of losses, some batch members will converge before
others. The optimization will continue until the condition defined by the
batch_convergence_reduce_fn becomes True. During these additional steps,
converged elements will continue to be updated and may become unconverged.
The convergence status of batch members can be diagnosed by tracing
has_converged:
import jax.numpy as jnp

batch_size = 10
trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss,
    'has_converged': traceable_quantities.has_converged}
_, trace = tfp.math.minimize_stateless(
    loss_fn,
    init=jnp.zeros([batch_size]),
    num_steps=100,
    optimizer=optax.adam(0.1),
    trace_fn=trace_fn,
    convergence_criterion=(
        tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))
import numpy as np

has_converged = trace['has_converged']
for i in range(batch_size):
  print('Batch element {} final state is {}converged.'
        ' It first converged at step {}.'.format(
            i, '' if has_converged[-1, i] else 'not ',
            np.argmax(has_converged[:, i])))
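A different stopping condition for the batched loop can be supplied via
batch_convergence_reduce_fn. As a sketch (reusing loss_fn, trace_fn, and
batch_size from above), stopping once more than half of the batch members
have converged:
_, trace = tfp.math.minimize_stateless(
    loss_fn,
    init=jnp.zeros([batch_size]),
    num_steps=100,
    optimizer=optax.adam(0.1),
    trace_fn=trace_fn,
    convergence_criterion=(
        tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)),
    batch_convergence_reduce_fn=lambda x: jnp.mean(x) > 0.5)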