tf.contrib.model_pruning.train

Wrapper around tf-slim's train function.

tf.contrib.model_pruning.train(
    train_op,
    logdir,
    mask_update_op,
    train_step_fn=train_step,
    train_step_kwargs=_USE_DEFAULT,
    log_every_n_steps=1,
    graph=None,
    master='',
    is_chief=True,
    global_step=None,
    number_of_steps=None,
    init_op=_USE_DEFAULT,
    init_feed_dict=None,
    local_init_op=_USE_DEFAULT,
    init_fn=None,
    ready_op=_USE_DEFAULT,
    summary_op=_USE_DEFAULT,
    save_summaries_secs=600,
    summary_writer=_USE_DEFAULT,
    startup_delay_steps=0,
    saver=None,
    save_interval_secs=600,
    sync_optimizer=None,
    session_config=None,
    trace_every_n_steps=None
)

Runs a training loop using a TensorFlow supervisor. When sync_optimizer is supplied, gradient updates are applied synchronously; otherwise, they are applied asynchronously.
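The loop and the mask update can be pictured with a small plain-Python sketch. It assumes that the pruning wrapper extends train_step_fn by running mask_update_op once after every training step. FakeSession, plain_step, wrap_with_mask_update and the toy "ops" are hypothetical names, not part of the API. Plain callables stand in for TensorFlow ops so the example runs without TensorFlow.

```python
class FakeSession:
    """Hypothetical stand-in for tf.compat.v1.Session.

    run() "executes" a fetch by calling it and records the fetch's name.
    """

    def __init__(self):
        self.calls = []

    def run(self, fetch):
        self.calls.append(fetch.__name__)
        return fetch()


def plain_step(sess, train_op, global_step, train_step_kwargs):
    """Four-argument step: apply one update, then evaluate the stop flag."""
    total_loss = sess.run(train_op)
    should_stop = sess.run(train_step_kwargs["should_stop"])
    return total_loss, should_stop


def wrap_with_mask_update(train_step_fn, mask_update_op):
    """Return a step function that also runs mask_update_op after each step."""

    def train_step_with_pruning(sess, train_op, global_step, train_step_kwargs):
        total_loss, should_stop = train_step_fn(
            sess, train_op, global_step, train_step_kwargs)
        sess.run(mask_update_op)  # update weight masks and thresholds
        return total_loss, should_stop

    return train_step_with_pruning


# Toy "ops": train_op advances a counter and returns a shrinking loss.
state = {"step": 0}


def train_op():
    state["step"] += 1
    return 1.0 / state["step"]


def mask_update_op():
    return None  # a real op would update the masks and thresholds


def should_stop():
    return state["step"] >= 3  # stand-in for number_of_steps=3


sess = FakeSession()
step_fn = wrap_with_mask_update(plain_step, mask_update_op)
stop = False
while not stop:  # a much-simplified version of the supervised loop
    # global_step is unused by plain_step, so None is passed here.
    loss, stop = step_fn(sess, train_op, None, {"should_stop": should_stop})
```

After three iterations, sess.calls records train_op, should_stop and mask_update_op, in that order, for every step, and loss is 1/3. Wrapping the step function keeps the caller's train_step_fn unchanged, which is one reason such a design suits a thin wrapper around an existing training loop.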

Args:

  • train_op: A Tensor that, when executed, will apply the gradients and return the loss value.
  • logdir: The directory where training logs are written to. If None, model checkpoints and summaries will not be written.
  • mask_update_op: Operation that upon execution updates the weight masks and thresholds.
  • train_step_fn: The function to call in order to execute a single gradient step. The function must take exactly four arguments: the current session, the train_op Tensor, a global step Tensor, and a dictionary.
  • train_step_kwargs: A dictionary which is passed to the train_step_fn. By default, two Boolean, scalar ops called "should_stop" and "should_log" are provided.
  • log_every_n_steps: The frequency, in global steps, at which the loss and global step are logged.
  • graph: The graph to pass to the supervisor. If no graph is supplied the default graph is used.
  • master: The address of the TensorFlow master.
  • is_chief: Specifies whether or not the training is being run by the primary replica during replica training.
  • global_step: The Tensor representing the global step. If left as None, then slim.variables.get_or_create_global_step() is used.
  • number_of_steps: The maximum number of gradient steps to take during training, as measured by global_step: training stops once global_step is greater than or equal to number_of_steps. If left as None, training proceeds indefinitely.
  • init_op: The initialization operation. If left to its default value, then the session is initialized by calling tf.compat.v1.global_variables_initializer().
  • init_feed_dict: A feed dictionary to use when executing the init_op.
  • local_init_op: The local initialization operation. If left to its default value, then the session is initialized by calling tf.compat.v1.local_variables_initializer() and tf.compat.v1.tables_initializer().
  • init_fn: An optional callable to be executed after init_op is called. The callable must accept one argument, the session being initialized.
  • ready_op: Operation to check if the model is ready to use. If left to its default value, then the session checks for readiness by calling tf.compat.v1.report_uninitialized_variables().
  • summary_op: The summary operation.
  • save_summaries_secs: How often, in seconds, to save summaries.
  • summary_writer: SummaryWriter to use. Can be None to indicate that no summaries should be written. If unset, we create a SummaryWriter.
  • startup_delay_steps: The number of steps to wait before training begins. Note that this must be 0 if a sync_optimizer is supplied.
  • saver: Saver to save checkpoints. If None, a default one will be created and used.
  • save_interval_secs: How often, in seconds, to save the model to logdir.
  • sync_optimizer: An instance of tf.compat.v1.train.SyncReplicasOptimizer, or a list of them. If the argument is supplied, gradient updates will be synchronous. If left as None, gradient updates will be asynchronous.
  • session_config: An instance of tf.compat.v1.ConfigProto that will be used to configure the Session. If left as None, the default will be used.
  • trace_every_n_steps: Produce and save a Timeline in Chrome trace format, and add it to the summaries, every trace_every_n_steps steps. If None, no trace information will be produced or saved.
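The interaction between number_of_steps, log_every_n_steps and the "should_stop" / "should_log" flags in train_step_kwargs can be sketched in plain Python. This is a minimal sketch. It assumes the default flags behave as in tf-slim's default step: stop once global_step >= number_of_steps, and log whenever global_step is a multiple of log_every_n_steps. step_flags is a hypothetical name, not part of the API.

```python
def step_flags(global_step, number_of_steps=None, log_every_n_steps=1):
    """Return (should_stop, should_log) for a given global step.

    Hypothetical plain-Python analogue of the Boolean ops passed in
    train_step_kwargs; it is not the library's code.
    """
    # number_of_steps=None means "train indefinitely".
    should_stop = number_of_steps is not None and global_step >= number_of_steps
    # Log whenever the global step is a multiple of log_every_n_steps.
    should_log = global_step % log_every_n_steps == 0
    return should_stop, should_log


# Flags for steps 0..6 with number_of_steps=5 and log_every_n_steps=2.
flags = [step_flags(s, number_of_steps=5, log_every_n_steps=2) for s in range(7)]
```

With these settings, logging happens on even steps and should_stop first becomes True at step 5.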

Returns:

The value of the loss function after training.

Raises:

  • ValueError: if train_op is empty, if startup_delay_steps is non-zero when sync_optimizer is supplied, if number_of_steps is negative, or if trace_every_n_steps is not None and no logdir is provided.
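The documented ValueError conditions can be restated as a small pre-flight check. check_train_args and first_error are hypothetical helpers that mirror only the conditions listed above; they are not the library's own validation code. Interpreting "empty" as train_op being None is an assumption.

```python
def check_train_args(train_op, logdir, number_of_steps=None,
                     startup_delay_steps=0, sync_optimizer=None,
                     trace_every_n_steps=None):
    """Hypothetical helper restating the documented ValueError conditions."""
    # "Empty" is interpreted here as None (an assumption).
    if train_op is None:
        raise ValueError("train_op cannot be None.")
    if sync_optimizer is not None and startup_delay_steps != 0:
        raise ValueError(
            "startup_delay_steps must be 0 when sync_optimizer is supplied.")
    if number_of_steps is not None and number_of_steps < 0:
        raise ValueError("number_of_steps must be None or non-negative.")
    if trace_every_n_steps is not None and logdir is None:
        raise ValueError("trace_every_n_steps requires a logdir.")


def first_error(**kwargs):
    """Return the ValueError message raised for kwargs, or None if they pass."""
    try:
        check_train_args(**kwargs)
    except ValueError as err:
        return str(err)
    return None
```

For example, first_error(train_op="op", logdir=None, trace_every_n_steps=100) returns the trace message. The same call without trace_every_n_steps passes, because a None logdir is allowed on its own.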