Wrapper around tf-slim's train function.
tf.contrib.model_pruning.train( train_op, logdir, mask_update_op, train_step_fn=train_step, train_step_kwargs=_USE_DEFAULT, log_every_n_steps=1, graph=None, master='', is_chief=True, global_step=None, number_of_steps=None, init_op=_USE_DEFAULT, init_feed_dict=None, local_init_op=_USE_DEFAULT, init_fn=None, ready_op=_USE_DEFAULT, summary_op=_USE_DEFAULT, save_summaries_secs=600, summary_writer=_USE_DEFAULT, startup_delay_steps=0, saver=None, save_interval_secs=600, sync_optimizer=None, session_config=None, trace_every_n_steps=None )
Runs a training loop using a TensorFlow supervisor. When the sync_optimizer is supplied, gradient updates are applied synchronously. Otherwise, gradient updates are applied asynchronously.
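The following plain-Python sketch is a rough mental model of the loop described above. It assumes that the pruning wrapper calls the step function, then runs mask_update_op once after every gradient step, stops when the step function reports should_stop, and returns the last loss. It does not use TensorFlow: FakeSession, ToyModel, basic_train_step, wrap_with_mask_update, and run_loop are hypothetical stand-ins that only mimic the call shapes named on this page. The per-step ordering is an assumption, not a description of the library's source.

```python
# Plain-Python sketch: no TensorFlow required.
# All class and function names here are hypothetical stand-ins.


class FakeSession:
    """Mimics Session.run: each 'op' is a zero-argument callable."""

    def run(self, fetches):
        if isinstance(fetches, (list, tuple)):
            return [op() for op in fetches]
        return fetches()


class ToyModel:
    """Hypothetical model whose loss is 1 / step."""

    def __init__(self):
        self.step = 0
        self.mask_updates = 0

    def train_op(self):
        # One "gradient step": advance the global step, return the loss.
        self.step += 1
        return 1.0 / self.step

    def global_step(self):
        return self.step

    def mask_update_op(self):
        # Stand-in for updating the weight masks and thresholds.
        self.mask_updates += 1


def basic_train_step(sess, train_op, global_step, train_step_kwargs):
    """Four-argument step function; returns (loss, should_stop)."""
    loss, _ = sess.run([train_op, global_step])
    should_stop = sess.run(train_step_kwargs["should_stop"])
    return loss, should_stop


def wrap_with_mask_update(train_step_fn, mask_update_op):
    """Assumption: masks are refreshed once after every gradient step."""

    def step_fn(sess, train_op, global_step, train_step_kwargs):
        loss, should_stop = train_step_fn(sess, train_op, global_step,
                                          train_step_kwargs)
        sess.run(mask_update_op)
        return loss, should_stop

    return step_fn


def run_loop(sess, train_op, global_step, step_fn, train_step_kwargs):
    """Call step_fn until it reports should_stop; return the last loss."""
    while True:
        loss, should_stop = step_fn(sess, train_op, global_step,
                                    train_step_kwargs)
        if should_stop:
            return loss


model = ToyModel()
kwargs = {"should_stop": lambda: model.step >= 4}  # hypothetical stop rule
step_fn = wrap_with_mask_update(basic_train_step, model.mask_update_op)
final_loss = run_loop(FakeSession(), model.train_op, model.global_step,
                      step_fn, kwargs)
print(final_loss, model.step, model.mask_updates)  # 0.25 4 4
```

Keeping the mask update in a wrapper leaves any custom train_step_fn untouched. That is why the sketch composes functions rather than editing the step function itself.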
Args:
train_op: A Tensor that, when executed, will apply the gradients and return the loss value.
logdir: The directory where training logs are written. If None, model checkpoints and summaries will not be written.
mask_update_op: Operation that upon execution updates the weight masks and thresholds.
train_step_fn: The function to call in order to execute a single gradient step. The function must take exactly four arguments: the current session, the train_op Tensor, a global step Tensor, and a dictionary.
train_step_kwargs: A dictionary which is passed to the train_step_fn. By default, two Boolean, scalar ops called "should_stop" and "should_log" are provided.
log_every_n_steps: The frequency, in terms of global steps, at which the loss and global step are logged.
graph: The graph to pass to the supervisor. If no graph is supplied, the default graph is used.
master: The address of the TensorFlow master.
is_chief: Specifies whether or not the training is being run by the primary replica during replica training.
global_step: The Tensor representing the global step. If left as None, then slim.variables.get_or_create_global_step() is used.
number_of_steps: The max number of gradient steps to take during training, as measured by 'global_step': training will stop if global_step is greater than 'number_of_steps'. If the value is left as None, training proceeds indefinitely.
init_op: The initialization operation. If left to its default value, then the session is initialized by calling
init_feed_dict: A feed dictionary to use when executing the init_op.
local_init_op: The local initialization operation. If left to its default value, then the session is initialized by calling
init_fn: An optional callable to be executed after init_op is called. The callable must accept one argument, the session being initialized.
ready_op: Operation to check if the model is ready to use. If left to its default value, then the session checks for readiness by calling
summary_op: The summary operation.
save_summaries_secs: How often, in seconds, to save summaries.
summary_writer: The SummaryWriter to use. Can be None to indicate that no summaries should be written. If unset, a SummaryWriter is created.
startup_delay_steps: The number of steps to wait before training begins. Note that this must be 0 if a sync_optimizer is supplied.
saver: Saver to save checkpoints. If None, a default one will be created and used.
save_interval_secs: How often, in seconds, to save the model to
sync_optimizer: An instance of tf.compat.v1.train.SyncReplicasOptimizer, or a list of them. If the argument is supplied, gradient updates will be synchronous. If left as None, gradient updates will be asynchronous.
session_config: An instance of tf.compat.v1.ConfigProto that will be used to configure the Session. If left as None, the default will be used.
trace_every_n_steps: Produce and save a Timeline in Chrome trace format and add it to the summaries every trace_every_n_steps steps. If None, no trace information will be produced or saved.
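The following sketch illustrates the train_step_fn contract described in the arguments above: a function taking exactly four arguments (session, train_op, global step, and the train_step_kwargs dictionary) that reads the Boolean "should_log" and "should_stop" entries. For illustration only, it assumes that "should_log" is true when the global step is a multiple of log_every_n_steps, and that the function returns a (loss, should_stop) pair. FakeSession, logging_train_step, the toy ops, and the stopping rule at step 10 are hypothetical, and plain Python callables stand in for TensorFlow ops.

```python
# Plain-Python sketch of a custom train_step_fn; no TensorFlow required.
# FakeSession, the toy ops, and logging_train_step are hypothetical.


class FakeSession:
    """Mimics Session.run: each 'op' is a zero-argument callable."""

    def run(self, fetches):
        if isinstance(fetches, (list, tuple)):
            return [op() for op in fetches]
        return fetches()


state = {"step": 0}   # stands in for the global step variable
log_every_n_steps = 3
logged = []           # (step, loss) pairs recorded when should_log is true


def train_op():
    # One toy gradient step: advance the step and return a loss value.
    state["step"] += 1
    return 10.0 - state["step"]  # hypothetical loss value


def global_step():
    return state["step"]


def logging_train_step(sess, train_op, global_step, train_step_kwargs):
    """Exactly four arguments; returns (loss, should_stop) by assumption."""
    loss, step = sess.run([train_op, global_step])
    if sess.run(train_step_kwargs["should_log"]):
        logged.append((step, loss))
    return loss, sess.run(train_step_kwargs["should_stop"])


train_step_kwargs = {
    # Assumed semantics: log whenever the global step is a multiple of n.
    "should_log": lambda: state["step"] % log_every_n_steps == 0,
    # Hypothetical stopping rule for the sketch.
    "should_stop": lambda: state["step"] >= 10,
}

sess = FakeSession()
should_stop = False
while not should_stop:
    loss, should_stop = logging_train_step(sess, train_op, global_step,
                                           train_step_kwargs)

print(logged)  # [(3, 7.0), (6, 4.0), (9, 1.0)]
```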
Returns:
The value of the loss function after training.
Raises:
An error if train_op is empty, if startup_delay_steps is non-zero when sync_optimizer is supplied, if number_of_steps is negative, or if
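A hypothetical validation helper covering the error conditions listed above might look like the following. It treats an "empty" train_op as None and raises ValueError; both are assumptions for the sketch, and it checks only the three conditions spelled out above.

```python
def check_train_args(train_op, startup_delay_steps=0, sync_optimizer=None,
                     number_of_steps=None):
    """Hypothetical helper; raising ValueError is an assumption."""
    if train_op is None:  # assumption: "empty" means None
        raise ValueError("train_op must not be empty.")
    if sync_optimizer is not None and startup_delay_steps != 0:
        raise ValueError(
            "startup_delay_steps must be 0 when sync_optimizer is supplied.")
    if number_of_steps is not None and number_of_steps < 0:
        raise ValueError("number_of_steps must not be negative.")


def raises(**kwargs):
    """Return True if check_train_args rejects the given arguments."""
    try:
        check_train_args(**kwargs)
    except ValueError:
        return True
    return False


# "op" and "sync" are placeholder values, not real TensorFlow objects.
print(raises(train_op=None))                                               # True
print(raises(train_op="op", startup_delay_steps=5, sync_optimizer="sync"))  # True
print(raises(train_op="op", number_of_steps=-1))                           # True
print(raises(train_op="op", startup_delay_steps=5))                        # False
```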