tf.contrib.framework.init_from_checkpoint


Initializes current variables with tensors loaded from a checkpoint, according to an assignment map.

The assignment map supports the following syntax:

  • 'checkpoint_scope_name/': 'scope_name/' - will load all variables in the current scope_name from checkpoint_scope_name with matching variable names.
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - will initialize the scope_name/variable_name variable from checkpoint_scope_name/some_other_variable.
  • 'scope_variable_name': variable - will initialize the given tf.Variable object with the tensor 'scope_variable_name' from the checkpoint.
  • 'scope_variable_name': list(variable) - will initialize a list of partitioned variables with the tensor 'scope_variable_name' from the checkpoint.
  • '/': 'scope_name/' - will load all variables in the current scope_name from the checkpoint's root (i.e. no scope).

Supports loading into partitioned variables, which are represented as '<variable>/part_<part #>'.
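
For illustration, a minimal sketch of how the parts of a partitioned variable are named (graph mode is assumed; the variable name and the fixed_size_partitioner are arbitrary choices for this sketch):

  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  # Each partition is an ordinary variable named '<variable>/part_<part #>'.
  v = tf.compat.v1.get_variable(
      'my1', shape=[100, 100],
      partitioner=tf.compat.v1.fixed_size_partitioner(5))
  print([p.name for p in v])
  # -> ['my1/part_0:0', 'my1/part_1:0', ..., 'my1/part_4:0']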

Example:

  # Create variables (the shapes here are illustrative; they must match the
  # corresponding tensors in the checkpoint).
  with tf.compat.v1.variable_scope('test'):
    m = tf.compat.v1.get_variable('my_var', shape=[2, 3])
  with tf.compat.v1.variable_scope('test2'):
    var2 = tf.compat.v1.get_variable('my_var', shape=[2, 3])
  var3 = tf.compat.v1.get_variable(name="my1", shape=[100, 100],
                                   partitioner=lambda shape, dtype: [5, 1])
  ...
  # Specify which variables to initialize from checkpoint.
  init_from_checkpoint(checkpoint_dir, {
    'some_var': 'test/my_var',
    'some_scope/': 'test2/'})
  ...
  # Or use `Variable` objects to identify what to initialize.
  init_from_checkpoint(checkpoint_dir, {
    'some_scope/var2': var2,
  })
  # Initialize partitioned variables
  init_from_checkpoint(checkpoint_dir, {
    'some_var_from_ckpt': 'part_var',
  })
  # Or specifying the list of `Variable` objects.
  init_from_checkpoint(checkpoint_dir, {
    'some_var_from_ckpt': var3._get_variable_list(),
  })
  ...
  # Initialize variables as usual.
  session.run(tf.compat.v1.global_variables_initializer())
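
For a self-contained version of the flow above, here is a minimal sketch. It assumes TF 1.x-style graph execution and uses tf.compat.v1.train.init_from_checkpoint, which takes the same (checkpoint_dir, assignment_map) arguments as this function; the temporary path and the variable names 'source/weights' and 'target/weights' are chosen only for illustration.

  import os
  import tempfile

  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

  # 1) Build a graph with a 'source/weights' variable and save a checkpoint.
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope('source'):
      tf.compat.v1.get_variable(
          'weights', shape=[2, 3], initializer=tf.compat.v1.ones_initializer())
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, ckpt_prefix)

  # 2) In a fresh graph, map the checkpoint tensor onto 'target/weights', then
  #    initialize as usual; the loaded values replace the default initializer.
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope('target'):
      w2 = tf.compat.v1.get_variable('weights', shape=[2, 3])
    tf.compat.v1.train.init_from_checkpoint(
        ckpt_prefix, {'source/weights': 'target/weights'})
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(w2))  # All ones, loaded from 'source/weights'.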

Args:

  • checkpoint_dir: Directory with checkpoints file or path to checkpoint.
  • assignment_map: Dict, where keys are names of the variables in the checkpoint and values are current variables or names of current variables (in default graph).

Raises:

  • tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
  • ValueError: If missing variables in current graph.
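
As a rough sketch (graph mode is assumed, and the checkpoint path and variable names below are placeholders), either error type can be caught around the call:

  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  with tf.compat.v1.variable_scope('test'):
    tf.compat.v1.get_variable('my_var', shape=[2])
  try:
    tf.compat.v1.train.init_from_checkpoint(
        '/path/to/nonexistent_checkpoint', {'some_var': 'test/my_var'})
  except (tf.errors.OpError, ValueError) as e:
    print('init_from_checkpoint failed:', e)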