Training utilities

tf.train.global_step(sess, global_step_tensor)

Small helper to get the global step.

# Creates a variable to hold the global_step.
global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
# Creates a session.
sess = tf.Session()
# Initializes the variable.
sess.run(global_step_tensor.initializer)
print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))

global_step: 10
  • sess: A TensorFlow Session object.
  • global_step_tensor: Tensor or the name of the operation that contains the global step.

Returns the global step value as an integer.
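As a minimal sketch of how the helper behaves while the step counter advances, the snippet below increments the variable a few times and then reads it back. It assumes a TensorFlow build that ships the `tensorflow.compat.v1` module, so the graph-and-session API used in this section is available; the function name `read_step_after_increments` is hypothetical and exists only for illustration.

```python
import tensorflow.compat.v1 as tf

# Assumption: run in graph mode so that tf.Session behaves as in this section.
tf.disable_v2_behavior()


def read_step_after_increments(num_increments):
    """Increment a global step starting at 10 and return its final value."""
    graph = tf.Graph()
    with graph.as_default():
        # Same variable as in the example above: starts at 10, not trainable.
        global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
        # Op that adds 1 to the global step each time it is run.
        increment_op = tf.assign_add(global_step_tensor, 1)
        with tf.Session(graph=graph) as sess:
            sess.run(global_step_tensor.initializer)
            for _ in range(num_increments):
                sess.run(increment_op)
            # The helper evaluates the tensor and converts the result to int.
            return tf.train.global_step(sess, global_step_tensor)


if __name__ == '__main__':
    print('global_step: %s' % read_step_after_increments(3))
```

In a real training loop the increment is usually performed by the optimizer rather than by an explicit op; the explicit `tf.assign_add` here only keeps the sketch short.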

tf.train.write_graph(graph_def, logdir, name, as_text=True)

Writes a graph proto to a file.

The graph is written as a text proto when as_text is True (the default) and as a binary proto otherwise.

v = tf.Variable(0, name='my_variable')
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
  • graph_def: A GraphDef protocol buffer.
  • logdir: Directory in which to write the graph. This can refer to remote filesystems, such as Google Cloud Storage (GCS).
  • name: Filename for the graph.
  • as_text: If True, writes the graph as an ASCII proto.
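To illustrate the effect of as_text, the sketch below writes the same graph once in text form and once in binary form. It assumes a TensorFlow build that provides `tensorflow.compat.v1`; the helper name `write_both_formats` and the file name `train.pb` are hypothetical choices made for this example.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: run in graph mode, matching the session-based API in this section.
tf.disable_v2_behavior()


def write_both_formats(logdir):
    """Write one graph as a text proto and as a binary proto; return both paths."""
    graph = tf.Graph()
    with graph.as_default():
        tf.Variable(0, name='my_variable')
    graph_def = graph.as_graph_def()

    # as_text defaults to True, so this file is a human-readable text proto.
    tf.train.write_graph(graph_def, logdir, 'train.pbtxt')
    # With as_text=False the same GraphDef is serialized in binary form.
    tf.train.write_graph(graph_def, logdir, 'train.pb', as_text=False)

    return os.path.join(logdir, 'train.pbtxt'), os.path.join(logdir, 'train.pb')


if __name__ == '__main__':
    text_path, binary_path = write_both_formats(tempfile.mkdtemp())
    print(text_path, os.path.getsize(text_path))
    print(binary_path, os.path.getsize(binary_path))
```

The text form is convenient for inspection and diffing, while the binary form is typically smaller and can be parsed back with `GraphDef.ParseFromString`.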