
Migration examples: TF1 vs TF2


The TensorFlow team has prepared code examples that demonstrate the equivalence between TF1 and TF2, with a focus on the high-level training APIs. We hope this helps you identify the similarities between your existing TF1 workflow and the available examples, and find a concrete path for migrating to TF2.

This is a work in progress, and more examples are being added.

Setup

Let's start with a couple of necessary TensorFlow imports,

import tensorflow as tf
import tensorflow.compat.v1 as tf1

and prepare some simple data for demonstration:

features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

Example 1: Training and evaluation with a trivial dense layer

TF1: Estimator.train/evaluate

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  # A single dense layer produces one prediction per example.
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  # A loss and a train_op are enough for both `train` and `evaluate`.
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpb_bhpbsh
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpb_bhpbsh', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpb_bhpbsh/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 5.1498923, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpb_bhpbsh/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Loss for final step: 28.89061.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f08bc6d8790>
estimator.evaluate(_eval_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-09T01:22:45
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpb_bhpbsh/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.09314s
INFO:tensorflow:Finished evaluation at 2021-06-09-01:22:45
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 67.420784
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmp/tmpb_bhpbsh/model.ckpt-3
{'loss': 67.420784, 'global_step': 3}
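
Note that `_model_fn` above returns only a loss and a train op; `Estimator.evaluate` reuses the same function with `mode=tf1.estimator.ModeKeys.EVAL` and reports the loss. If you need additional evaluation metrics, they can be attached through `eval_metric_ops`. A minimal sketch, assuming you want a mean-absolute-error metric (the name `'mae'` and the helper `_model_fn_with_metrics` are illustrative choices, not part of the example above):

def _model_fn_with_metrics(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  # Each entry is a (value, update_op) pair from the tf1.metrics helpers;
  # Estimator.evaluate reports them alongside the loss.
  eval_metric_ops = {
      'mae': tf1.metrics.mean_absolute_error(labels=labels, predictions=logits)
  }
  return tf1.estimator.EstimatorSpec(
      mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)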

TF2: Keras training API

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")
model.fit(dataset)
3/3 [==============================] - 0s 1ms/step - loss: 8.6691
<tensorflow.python.keras.callbacks.History at 0x7f0800087390>
model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 1ms/step - loss: 26.6582
{'loss': 26.658172607421875}

TF2: Keras training API with Custom Training Step

Keras allows you to provide a custom training step function for your model's forward and backward passes, while still taking advantage of the built-in training support, such as callbacks and distribution with tf.distribute (see the short sketch at the end of this example).

class CustomModel(tf.keras.Sequential):
  """A custom sequential model that has train_step overridden."""

  def train_step(self, data):
    batch_data, labels = data

    with tf.GradientTape() as tape:
      predictions = self(batch_data, training=True)
      # Compute the loss value (loss function is configured in `compile()`)
      loss = self.compiled_loss(labels, predictions)

    # Compute gradients
    gradients = tape.gradient(loss, self.trainable_variables)
    # Update weights
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(labels, predictions)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

model = CustomModel([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")
model.fit(dataset)
3/3 [==============================] - 0s 1ms/step - loss: 16.5469
<tensorflow.python.keras.callbacks.History at 0x7f080006bbd0>
model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 1ms/step - loss: 72.2276
{'loss': 72.2276382446289}
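
Because `CustomModel.train_step` plugs into the standard Keras training loop, the built-in support mentioned above keeps working unchanged. A minimal sketch, reusing the `model` and `dataset` defined above (the EarlyStopping settings are an illustrative choice):

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)
model.fit(dataset, epochs=10, callbacks=[early_stopping], verbose=0)

Similarly, constructing and compiling the model inside a tf.distribute strategy scope distributes this custom training step across devices.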