Minimize a loss function using a provided optimizer.
tfp.substrates.numpy.math.minimize(
    loss_fn,
    num_steps,
    optimizer,
    convergence_criterion=None,
    batch_convergence_reduce_fn=tf.reduce_all,
    trainable_variables=None,
    trace_fn=_trace_loss,
    return_full_length_trace=True,
    jit_compile=False,
    seed=None,
    name='minimize'
)
Args

loss_fn: Python callable with signature loss = loss_fn(), where loss is a
  Tensor loss to be minimized. This may optionally take a seed keyword
  argument, used to specify a per-iteration seed for stochastic loss
  functions (a stateless Tensor seed will be passed; see
  tfp.random.sanitize_seed).
num_steps: Python int maximum number of steps to run the optimizer.
optimizer: Optimizer instance to use. This may be a TF1-style
  tf.train.Optimizer, a TF2-style tf.optimizers.Optimizer, or any Python
  object that implements optimizer.apply_gradients(grads_and_vars).
convergence_criterion: Optional instance of
  tfp.optimizer.convergence_criteria.ConvergenceCriterion representing a
  criterion for detecting convergence. If None, the optimization will run
  for num_steps steps; otherwise, it will run for at most num_steps steps,
  as determined by the provided criterion.
  Default value: None.
batch_convergence_reduce_fn: Python callable of signature
  has_converged = batch_convergence_reduce_fn(batch_has_converged), whose
  input is a Tensor of boolean values of the same shape as the loss returned
  by loss_fn, and whose output is a scalar boolean Tensor. This determines
  the behavior of batched optimization loops when loss_fn's return value is
  non-scalar. For example, tf.reduce_all stops the optimization once all
  members of the batch have converged, tf.reduce_any once any member has
  converged, and lambda x: tf.reduce_mean(tf.cast(x, tf.float32)) > 0.5 once
  more than half have converged.
  Default value: tf.reduce_all.
trainable_variables: list of tf.Variable instances to optimize with respect
  to. If None, defaults to the set of all variables accessed during the
  execution of loss_fn().
  Default value: None.
trace_fn: Python callable with signature
  traced_values = trace_fn(traceable_quantities), where the argument is an
  instance of tfp.math.MinimizeTraceableQuantities and the returned
  traced_values may be a Tensor or nested structure of Tensors. The traced
  values are stacked across steps and returned. The default trace_fn simply
  returns the loss. In general, trace functions may also examine the
  gradients, the values of parameters, and the state propagated by the
  specified convergence_criterion, if any (if no convergence criterion is
  specified, this will be None), as well as any other quantities captured in
  the closure of trace_fn, for example, statistics of a variational
  distribution.
  Default value: lambda traceable_quantities: traceable_quantities.loss.
return_full_length_trace: Python bool indicating whether to return a trace
  of the full length num_steps, even if a convergence criterion stopped the
  optimization early, by tiling the value(s) traced at the final
  optimization step. This enables use in contexts such as XLA that require
  shapes to be known statically.
  Default value: True.
jit_compile: If True, compiles the minimization loop using XLA. XLA performs
  compiler optimizations, such as fusion, and attempts to emit more
  efficient code. This may drastically improve performance. See the docs for
  tf.function. (In JAX, this will apply jax.jit.)
  Default value: False.
seed: PRNG seed for stochastic losses; see tfp.random.sanitize_seed.
  Default value: None.
name: Python str name prefixed to ops created by this function.
  Default value: 'minimize'.
Examples
To minimize the scalar function (x - 5)**2:

x = tf.Variable(0.)
loss_fn = lambda: (x - 5.)**2
losses = tfp.math.minimize(loss_fn,
                           num_steps=100,
                           optimizer=tf.optimizers.Adam(learning_rate=0.1))

# In TF2/eager mode, the optimization runs immediately.
print("optimized value is {} with loss {}".format(x, losses[-1]))
In graph mode (e.g., inside a tf.function wrapper), retrieving any Tensor that
depends on the minimization op will trigger the optimization:

with tf.control_dependencies([losses]):
  optimized_x = tf.identity(x)  # Use a dummy op to attach the dependency.
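The minimization loop can also be compiled with XLA by passing jit_compile=True, as described above. A minimal sketch, reusing the loss_fn defined earlier:

losses = tfp.math.minimize(loss_fn,
                           num_steps=100,
                           optimizer=tf.optimizers.Adam(learning_rate=0.1),
                           jit_compile=True)
# XLA requires statically known shapes; the default
# return_full_length_trace=True already satisfies this.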
We can attempt to automatically detect convergence and stop the optimization
by passing an instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion.
For example, to stop the optimization once a moving average of the per-step
decrease in loss drops below 0.01:
losses = tfp.math.minimize(
    loss_fn, num_steps=1000, optimizer=tf.optimizers.Adam(learning_rate=0.1),
    convergence_criterion=(
        tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))
Here num_steps=1000 defines an upper bound: the optimization will be stopped
after 1000 steps even if no convergence is detected.
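If a full-length trace is not required, passing return_full_length_trace=False (described above) returns only the values traced up to the step at which the optimization actually stopped. A sketch of this variant:

losses = tfp.math.minimize(
    loss_fn, num_steps=1000, optimizer=tf.optimizers.Adam(learning_rate=0.1),
    convergence_criterion=(
        tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)),
    return_full_length_trace=False)
# The trace is no longer padded to num_steps, so its length depends on when
# (or whether) the convergence criterion triggered.
print(losses.shape)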
In some cases, we may want to track additional context inside the
optimization. We can do this by defining a custom trace_fn. Note that the
trace_fn is passed the loss and gradients, as well as any auxiliary state
maintained by the convergence criterion (if any), for example, moving
averages of the loss or gradients. It may also report the values of trainable
parameters or other derived quantities by capturing them in its closure. For
example, we can capture x and track its value over the optimization:
# `x` is the tf.Variable instance defined above.
trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss, 'x': x}
trace = tfp.math.minimize(loss_fn, num_steps=100,
                          optimizer=tf.optimizers.Adam(0.1),
                          trace_fn=trace_fn)
print(trace['loss'].shape,  # => [100]
      trace['x'].shape)     # => [100]
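Because the traceable quantities also expose the gradients (as noted above), a trace_fn can record, for example, the gradient norm at each step. A sketch, assuming the gradients field of tfp.math.MinimizeTraceableQuantities:

trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss,
    # Global norm over all traced gradients (here, just the gradient for `x`).
    'grad_norm': tf.linalg.global_norm(traceable_quantities.gradients)}
trace = tfp.math.minimize(loss_fn, num_steps=100,
                          optimizer=tf.optimizers.Adam(0.1),
                          trace_fn=trace_fn)
print(trace['grad_norm'].shape)  # => [100]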
When optimizing a batch of losses, some batch members will converge before
others. The optimization will continue until the condition defined by
batch_convergence_reduce_fn becomes True. During these additional steps,
converged elements will continue to be updated and may become unconverged.
The convergence status of batch members can be diagnosed by tracing
has_converged:
batch_size = 10
x = tf.Variable([0.] * batch_size)
trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss,
    'has_converged': traceable_quantities.has_converged}
trace = tfp.math.minimize(loss_fn, num_steps=100,
                          optimizer=tf.optimizers.Adam(0.1),
                          trace_fn=trace_fn,
                          convergence_criterion=(
                              tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

for i in range(batch_size):
  print('Batch element {} final state is {}converged.'
        ' It first converged at step {}.'.format(
            i, '' if trace['has_converged'][-1, i] else 'not ',
            np.argmax(trace['has_converged'][:, i])))
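Finally, as described under loss_fn and seed above, a stochastic loss may accept a per-iteration seed. The sketch below uses a hypothetical Monte Carlo loss built from tfp.distributions; the target value 5. and sample size 32 are illustrative:

tfd = tfp.distributions

x = tf.Variable(0.)
def stochastic_loss_fn(seed):
  # Monte Carlo estimate of E[(z - 5.)**2] for z ~ Normal(x, 1.), using the
  # per-iteration stateless seed passed in by `minimize`.
  samples = tfd.Normal(loc=x, scale=1.).sample(32, seed=seed)
  return tf.reduce_mean((samples - 5.)**2)

losses = tfp.math.minimize(stochastic_loss_fn, num_steps=200,
                           optimizer=tf.optimizers.Adam(0.1),
                           seed=(1, 2))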