Migrate metrics and optimizers

In TF1, tf.metrics is the API namespace for all the metric functions. Each metric is a function that takes labels and predictions as input parameters and returns the corresponding metric tensor. In TF2, tf.keras.metrics contains all the metric functions and objects. A Metric object can be used with tf.keras.Model and tf.keras.layers.Layer to calculate metric values.

Setup

Let's start with a couple of necessary TensorFlow imports,

import tensorflow as tf
import tensorflow.compat.v1 as tf1

and prepare some simple data for demonstration:

features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [0, 0, 1]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [0, 1, 1]

TF1: tf.compat.v1.metrics with Estimator

In TF1, metrics are added to the EstimatorSpec through eval_metric_ops, and the ops are generated by the metric functions defined in tf.metrics. The following example shows how to use tf.metrics.accuracy.

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(2)(features)
  predictions = tf.math.argmax(input=logits, axis=1)
  loss = tf1.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  accuracy = tf1.metrics.accuracy(labels=labels, predictions=predictions)
  return tf1.estimator.EstimatorSpec(mode, 
                                     predictions=predictions,
                                     loss=loss, 
                                     train_op=train_op,
                                     eval_metric_ops={'accuracy': accuracy})

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
estimator.evaluate(_eval_input_fn)

Metrics can also be added to an existing estimator directly via tf.estimator.add_metrics().

def mean_squared_error(labels, predictions):
  labels = tf.cast(labels, predictions.dtype)
  return {"mean_squared_error": 
          tf1.metrics.mean_squared_error(labels=labels, predictions=predictions)}

estimator = tf1.estimator.add_metrics(estimator, mean_squared_error)
estimator.evaluate(_eval_input_fn)

TF2: Keras Metrics API with tf.keras.Model

In TF2, tf.keras.metrics contains all the metric classes and functions. They are designed in an object-oriented style and integrate closely with the other tf.keras APIs. All the metrics can be found in the tf.keras.metrics namespace, and there is usually a direct mapping between tf.compat.v1.metrics and tf.keras.metrics.

In the following example, the metrics are passed to the model.compile() method. You only need to create the metric instance, without specifying the label and prediction tensors. The Keras model routes the model output and the labels to the metric object.

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

inputs = tf.keras.Input((2,))
logits = tf.keras.layers.Dense(2)(inputs)
predictions = tf.math.argmax(input=logits, axis=1)
model = tf.keras.models.Model(inputs, predictions)
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, loss='mse', metrics=[tf.keras.metrics.Accuracy()])
model.evaluate(eval_dataset, return_dict=True)

With eager execution enabled, tf.keras.metrics.Metric instances can be used directly to evaluate NumPy data or eager tensors. tf.keras.metrics.Metric objects are stateful containers. The metric value is updated via metric.update_state(y_true, y_pred), and the result is retrieved with metric.result().

accuracy = tf.keras.metrics.Accuracy()

accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 1])
accuracy.result().numpy()
accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 0])
accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[1, 1, 0, 0])
accuracy.result().numpy()
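
To make the "stateful container" idea concrete, here is a minimal plain-Python sketch of how an accumulating accuracy metric might behave. StreamingAccuracy is a hypothetical stand-in written for illustration; it is not the Keras implementation and it only mirrors the update_state/result call pattern described above.

```python
class StreamingAccuracy:
    """Hypothetical plain-Python stand-in for a stateful accuracy metric."""

    def __init__(self):
        self.correct = 0  # number of matching predictions seen so far
        self.total = 0    # number of examples seen so far

    def update_state(self, y_true, y_pred):
        # Accumulate counts across calls; nothing is returned.
        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")
        self.correct += sum(t == p for t, p in zip(y_true, y_pred))
        self.total += len(y_true)

    def result(self):
        # Accuracy over every example seen so far (0.0 before any update).
        return self.correct / self.total if self.total else 0.0


accuracy = StreamingAccuracy()
accuracy.update_state([0, 0, 1, 1], [0, 0, 0, 1])
first = accuracy.result()    # 3 of 4 correct
accuracy.update_state([0, 0, 1, 1], [0, 0, 0, 0])
accuracy.update_state([0, 0, 1, 1], [1, 1, 0, 0])
overall = accuracy.result()  # 5 of 12 correct across all three batches
print(first, overall)
```

The two values correspond to the two accuracy.result() calls in the example above: 0.75 after the first batch, then 5/12 (about 0.417) once all three batches have been counted, because the state is carried over between updates.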

For more details about tf.keras.metrics.Metric, see the API documentation for tf.keras.metrics.Metric, as well as the migration guide.

Migrate TF1.x optimizers to Keras optimizers

The optimizers in tf.compat.v1.train, such as the Adam optimizer and the gradient descent optimizer, have equivalents in tf.keras.optimizers.

The table below summarizes how you can convert these legacy optimizers to their Keras equivalents. You can directly replace the TF1.x version with the TF2 version unless additional steps (such as updating the default learning rate) are required.

Note that converting your optimizers may make old checkpoints incompatible.

| TF1.x | TF2 | Additional steps |
| --- | --- | --- |
| `tf.compat.v1.train.GradientDescentOptimizer` | `tf.keras.optimizers.SGD` | None |
| `tf.compat.v1.train.MomentumOptimizer` | `tf.keras.optimizers.SGD` | Include the `momentum` argument |
| `tf.compat.v1.train.AdamOptimizer` | `tf.keras.optimizers.Adam` | Rename the `beta1` and `beta2` arguments to `beta_1` and `beta_2` |
| `tf.compat.v1.train.RMSPropOptimizer` | `tf.keras.optimizers.RMSprop` | Rename the `decay` argument to `rho` |
| `tf.compat.v1.train.AdadeltaOptimizer` | `tf.keras.optimizers.Adadelta` | None |
| `tf.compat.v1.train.AdagradOptimizer` | `tf.keras.optimizers.Adagrad` | None |
| `tf.compat.v1.train.FtrlOptimizer` | `tf.keras.optimizers.Ftrl` | Remove the `accum_name` and `linear_name` arguments |
| `tf.contrib.AdamaxOptimizer` | `tf.keras.optimizers.Adamax` | Rename the `beta1` and `beta2` arguments to `beta_1` and `beta_2` |
| `tf.contrib.Nadam` | `tf.keras.optimizers.Nadam` | Rename the `beta1` and `beta2` arguments to `beta_1` and `beta_2` |
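
When many call sites need updating, the argument changes in the table can be written down as a small lookup. The sketch below is one possible way to do that in plain Python. RENAMES, DROPPED, and translate_kwargs are hypothetical names introduced here for illustration, not TensorFlow APIs, and the lookup covers only the renames and removals listed in the table.

```python
# Hypothetical helper: encodes only the argument changes listed in the table.
# Keys are the TF1.x optimizer class names as they appear in the table.
RENAMES = {
    "AdamOptimizer": {"beta1": "beta_1", "beta2": "beta_2"},
    "AdamaxOptimizer": {"beta1": "beta_1", "beta2": "beta_2"},
    "Nadam": {"beta1": "beta_1", "beta2": "beta_2"},
    "RMSPropOptimizer": {"decay": "rho"},
}
DROPPED = {
    "FtrlOptimizer": {"accum_name", "linear_name"},
}


def translate_kwargs(tf1_optimizer, kwargs):
    """Return a copy of kwargs with TF1.x argument names mapped to Keras names."""
    renames = RENAMES.get(tf1_optimizer, {})
    dropped = DROPPED.get(tf1_optimizer, set())
    return {
        renames.get(name, name): value
        for name, value in kwargs.items()
        if name not in dropped
    }


adam_kwargs = translate_kwargs(
    "AdamOptimizer", {"learning_rate": 0.001, "beta1": 0.9, "beta2": 0.999})
rmsprop_kwargs = translate_kwargs(
    "RMSPropOptimizer", {"learning_rate": 0.001, "decay": 0.9})
ftrl_kwargs = translate_kwargs(
    "FtrlOptimizer", {"learning_rate": 0.05, "accum_name": "acc"})
print(adam_kwargs, rmsprop_kwargs, ftrl_kwargs)
```

The translated dictionaries could then be passed as keyword arguments to the matching tf.keras.optimizers class, for example tf.keras.optimizers.Adam(**adam_kwargs). Default values are not handled here, so compare the defaults of both versions separately.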