TensorBoard Scalars: Logging training metrics in Keras

Overview

Machine learning invariably involves understanding key metrics such as loss and how they change as training progresses. These metrics can help you understand if you're overfitting, for example, or if you're unnecessarily training for too long. You may want to compare these metrics across different training runs to help debug and improve your model.

TensorBoard's Time Series Dashboard allows you to visualize these metrics using a simple API with very little effort. This tutorial presents very basic examples to help you learn how to use these APIs with TensorBoard when developing your Keras model. You will learn how to use the Keras TensorBoard callback and TensorFlow Summary APIs to visualize default and custom scalars.

Setup

# Load the TensorBoard notebook extension.
%load_ext tensorboard
from datetime import datetime
from packaging import version

import tensorflow as tf
from tensorflow import keras
from keras import backend as K

import numpy as np

print("TensorFlow version: ", tf.__version__)
assert version.parse(tf.__version__).release[0] >= 2, \
    "This notebook requires TensorFlow 2.0 or above."
TensorFlow version:  2.8.2
# Clear any logs from previous runs
!rm -rf ./logs/

Set up data for a simple regression

You're now going to use Keras to calculate a regression, i.e., find the line of best fit for a paired data set. (While using neural networks and gradient descent is overkill for this kind of problem, it does make for an easy-to-understand example.)

You're going to use TensorBoard to observe how training and test loss change across epochs. Hopefully, you'll see training and test loss decrease over time and then remain steady.

First, generate 1000 data points roughly along the line y = 0.5x + 2. Split these data points into training and test sets. Your hope is that the neural net learns this relationship.

data_size = 1000
# 80% of the data is for training.
train_pct = 0.8

train_size = int(data_size * train_pct)

# Create some input data between -1 and 1 and randomize it.
x = np.linspace(-1, 1, data_size)
np.random.shuffle(x)

# Generate the output data.
# y = 0.5x + 2 + noise
y = 0.5 * x + 2 + np.random.normal(0, 0.05, (data_size, ))

# Split into test and train pairs.
x_train, y_train = x[:train_size], y[:train_size]
x_test, y_test = x[train_size:], y[train_size:]

Training the model and logging loss

You're now ready to define, train and evaluate your model.

To log the loss scalar as you train, you'll do the following:

  1. Create the Keras TensorBoard callback.
  2. Specify a log directory.
  3. Pass the TensorBoard callback to Keras' Model.fit().

TensorBoard reads log data from the log directory hierarchy. In this notebook, the root log directory is logs/scalars, suffixed by a timestamped subdirectory. The timestamped subdirectory enables you to easily identify and select training runs as you use TensorBoard and iterate on your model.
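
After a couple of runs, the resulting directory tree looks roughly like this (the timestamps below are hypothetical; the Keras TensorBoard callback writes separate train and validation subdirectories for each run):

logs/scalars/20220101-120000/train/
logs/scalars/20220101-120000/validation/
logs/scalars/20220102-093000/train/
logs/scalars/20220102-093000/validation/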

logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

model = keras.models.Sequential([
    keras.layers.Dense(16, input_dim=1),
    keras.layers.Dense(1),
])

model.compile(
    loss='mse', # keras.losses.mean_squared_error
    optimizer=keras.optimizers.SGD(learning_rate=0.2),
)

print("Training ... With default parameters, this takes less than 10 seconds.")
training_history = model.fit(
    x_train, # input
    y_train, # output
    batch_size=train_size,
    verbose=0, # Suppress chatty output; use TensorBoard instead
    epochs=100,
    validation_data=(x_test, y_test),
    callbacks=[tensorboard_callback],
)

print("Average test loss: ", np.average(training_history.history['loss']))
Training ... With default parameters, this takes less than 10 seconds.
Average test loss:  0.042797307365108284

Examining loss using TensorBoard

Now, start TensorBoard, specifying the root log directory you used above.

Wait a few seconds for TensorBoard's UI to spin up.

%tensorboard --logdir logs/scalars

You may see TensorBoard display the message "No dashboards are active for the current data set". That's because initial logging data hasn't been saved yet. As training progresses, the Keras model will start logging data. TensorBoard will periodically refresh and show you your scalar metrics. If you're impatient, you can tap the Refresh arrow at the top right.

As you watch the training progress, note how both training and validation loss rapidly decrease and then remain stable. In fact, you could have stopped training after 25 epochs, because the loss didn't improve much after that point.
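
If you want training to stop automatically once the validation loss stops improving, you could add Keras' EarlyStopping callback alongside the TensorBoard callback. This is a minimal sketch, not part of the runs in this tutorial; the monitor and patience values are illustrative assumptions.

# Hypothetical addition: stop training when validation loss stops improving.
# monitor='val_loss' and patience=5 are illustrative choices, not tuned values.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',          # watch the validation loss logged above
    patience=5,                  # stop after 5 epochs without improvement
    restore_best_weights=True,   # roll back to the best weights seen
)

# Then pass it next to the TensorBoard callback:
# model.fit(..., callbacks=[tensorboard_callback, early_stopping])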

Hover over the graph to see specific data points. You can also try zooming in with your mouse, or selecting a region of the graph to view more detail.

Notice the "Runs" selector on the left. A "run" represents a set of logs from a round of training, in this case the result of Model.fit(). Developers typically have many, many runs, as they experiment and develop their model over time.

Use the Runs selector to choose specific runs, or to show only the training or only the validation runs. Comparing runs will help you evaluate which version of your code is solving your problem better.

OK, TensorBoard's loss graph demonstrates that the loss consistently decreased for both training and validation and then stabilized. That means the model's metrics are likely very good! Now see how the model actually behaves in real life.

Given the input data (60, 25, 2), the line y = 0.5x + 2 should yield (32, 14.5, 3). Does the model agree?

print(model.predict([60, 25, 2]))
# True values to compare predictions against: 
# [[32.0]
#  [14.5]
#  [ 3.0]]
[[32.148884 ]
 [14.562463 ]
 [ 3.0056725]]

Not bad!

Logging custom scalars

What if you want to log custom values, such as a dynamic learning rate? To do that, you need to use the TensorFlow Summary API.

Retrain the regression model and log a custom learning rate. Here's how:

  1. Create a file writer, using tf.summary.create_file_writer().
  2. Define a custom learning rate function. This will be passed to the Keras LearningRateScheduler callback.
  3. Inside the learning rate function, use tf.summary.scalar() to log the custom learning rate.
  4. Pass the LearningRateScheduler callback to Model.fit().

In general, to log a custom scalar, you need to use tf.summary.scalar() with a file writer. The file writer is responsible for writing data for this run to the specified directory, and it is used implicitly when you call tf.summary.scalar().
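
In its most basic form, the pattern looks like this. This is a standalone sketch: the directory name "logs/example" and the tag "my_metric" are placeholders, and it uses an explicit writer context, whereas the training code below makes the writer the default with file_writer.set_as_default().

# Standalone sketch of the tf.summary pattern; names here are placeholders.
writer = tf.summary.create_file_writer("logs/example")
with writer.as_default():            # make this writer the active one
    for step in range(100):
        value = 1.0 / (step + 1)     # any scalar you want to track
        tf.summary.scalar("my_metric", value, step=step)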

logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir + "/metrics")
file_writer.set_as_default()

def lr_schedule(epoch):
  """
  Returns a custom learning rate that decreases as epochs progress.
  """
  learning_rate = 0.2
  if epoch > 10:
    learning_rate = 0.02
  if epoch > 20:
    learning_rate = 0.01
  if epoch > 50:
    learning_rate = 0.005

  tf.summary.scalar('learning rate', data=learning_rate, step=epoch)
  return learning_rate

lr_callback = keras.callbacks.LearningRateScheduler(lr_schedule)
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

model = keras.models.Sequential([
    keras.layers.Dense(16, input_dim=1),
    keras.layers.Dense(1),
])

model.compile(
    loss='mse', # keras.losses.mean_squared_error
    optimizer=keras.optimizers.SGD(),
)

training_history = model.fit(
    x_train, # input
    y_train, # output
    batch_size=train_size,
    verbose=0, # Suppress chatty output; use TensorBoard instead
    epochs=100,
    validation_data=(x_test, y_test),
    callbacks=[tensorboard_callback, lr_callback],
)

Let's look at TensorBoard again.

%tensorboard --logdir logs/scalars

Using the "Runs" selector on the left, notice that you have a <timestamp>/metrics run. Selecting this run displays a "learning rate" graph that allows you to verify the progression of the learning rate during this run.

You can also compare this run's training and validation loss curves against your earlier runs. You might also notice that the learning rate schedule returned discrete values, depending on epoch, but the learning rate plot may appear smooth. TensorBoard has a smoothing parameter that you may need to turn down to zero to see the unsmoothed values.

How does this model do?

print(model.predict([60, 25, 2]))
# True values to compare predictions against: 
# [[32.0]
#  [14.5]
#  [ 3.0]]
[[31.958094 ]
 [14.482997 ]
 [ 2.9993598]]

Batch-level logging

The examples so far have logged scalars once per epoch; you can also log metrics for every training batch. First, let's load the MNIST dataset, normalize the data, and write a function that creates a simple Keras model for classifying the images into 10 classes.

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step

Instantaneous batch-level logging

Logging metrics instantaneously at the batch level can show us how much they fluctuate between batches within each epoch, which can be useful for debugging.

Set up a summary writer for a different log directory:

log_dir = 'logs/batch_level/' + datetime.now().strftime("%Y%m%d-%H%M%S") + '/train'
train_writer = tf.summary.create_file_writer(log_dir)

To enable batch-level logging, custom tf.summary metrics should be defined by overriding train_step() in the Model's class definition and enclosed in a summary writer context. This can be built directly into a subclassed Model definition, or it can wrap the model we created earlier, as shown below:

class MyModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__()
    self.model = model

  def train_step(self, data):
    x, y = data
    with tf.GradientTape() as tape:
      y_pred = self.model(x, training=True)
      loss = self.compiled_loss(y, y_pred)
      mse = tf.keras.losses.mean_squared_error(y, K.max(y_pred, axis=-1))
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    with train_writer.as_default(step=self._train_counter):
      tf.summary.scalar('batch_loss', loss)
      tf.summary.scalar('batch_mse', mse)
    return self.compute_metrics(x, y, y_pred, None)

  def call(self, x):
    x = self.model(x)
    return x

# Adds custom batch-level metrics to the model defined earlier
model = MyModel(create_model())
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

Define our TensorBoard callback to log both epoch-level and batch-level metrics to our log directory and call model.fit() with our selected batch_size:

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

model.fit(x=x_train, 
          y=y_train,
          epochs=5,
          batch_size=500, 
          validation_data=(x_test, y_test), 
          callbacks=[tensorboard_callback])
Epoch 1/5
120/120 [==============================] - 5s 36ms/step - loss: 0.4379 - accuracy: 0.8788 - val_loss: 0.2041 - val_accuracy: 0.9430
Epoch 2/5
120/120 [==============================] - 4s 31ms/step - loss: 0.1875 - accuracy: 0.9471 - val_loss: 0.1462 - val_accuracy: 0.9591
Epoch 3/5
120/120 [==============================] - 3s 27ms/step - loss: 0.1355 - accuracy: 0.9613 - val_loss: 0.1170 - val_accuracy: 0.9670
Epoch 4/5
120/120 [==============================] - 3s 27ms/step - loss: 0.1058 - accuracy: 0.9694 - val_loss: 0.0954 - val_accuracy: 0.9723
Epoch 5/5
120/120 [==============================] - 3s 27ms/step - loss: 0.0872 - accuracy: 0.9752 - val_loss: 0.0843 - val_accuracy: 0.9749
<keras.callbacks.History at 0x7fce165a2fd0>

Open TensorBoard with the new log directory and see both the epoch-level and batch-level metrics:

%tensorboard --logdir logs/batch_level

Cumulative batch-level logging

Batch-level logging can also be implemented cumulatively, averaging each batch's metrics with those of previous batches, which results in a smoother training curve.

Set up a summary writer for a different log directory:

log_dir = 'logs/batch_avg/' + datetime.now().strftime("%Y%m%d-%H%M%S") + '/train'
train_writer = tf.summary.create_file_writer(log_dir)

Create stateful metrics that can be logged per batch:

batch_loss = tf.keras.metrics.Mean('batch_loss', dtype=tf.float32)
batch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('batch_accuracy')
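
"Stateful" here means that each call to such a metric updates a running aggregate rather than replacing the previous value. A small illustrative snippet (the values 2.0 and 4.0 are made up):

# Illustrative only: 2.0 and 4.0 are made-up loss values.
demo = tf.keras.metrics.Mean('demo_loss')
demo(2.0)                      # running mean is now 2.0
demo(4.0)                      # running mean is now (2.0 + 4.0) / 2 = 3.0
print(demo.result().numpy())   # prints 3.0
demo.reset_state()             # clears the accumulated state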

As before, add custom tf.summary metrics in the overridden train_step method. To make the batch-level logging cumulative, use the stateful metrics we defined to calculate the cumulative result given each training step's data.

class MyModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__()
    self.model = model

  def train_step(self, data):
    x, y = data
    with tf.GradientTape() as tape:
      y_pred = self.model(x, training=True)
      loss = self.compiled_loss(y, y_pred)
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    batch_loss(loss)
    batch_accuracy(y, y_pred)
    with train_writer.as_default(step=self._train_counter):
      tf.summary.scalar('batch_loss', batch_loss.result())
      tf.summary.scalar('batch_accuracy', batch_accuracy.result())
    return self.compute_metrics(x, y, y_pred, None)

  def call(self, x):
    x = self.model(x)
    return x

# Adds custom batch-level metrics to the model defined earlier
model = MyModel(create_model())
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

As before, define our TensorBoard callback and call model.fit() with our selected batch_size:

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

model.fit(x=x_train, 
          y=y_train,
          epochs=5,
          batch_size=500, 
          validation_data=(x_test, y_test), 
          callbacks=[tensorboard_callback])
Epoch 1/5
120/120 [==============================] - 4s 27ms/step - loss: 0.4266 - accuracy: 0.8813 - val_loss: 0.2055 - val_accuracy: 0.9415
Epoch 2/5
120/120 [==============================] - 3s 26ms/step - loss: 0.1864 - accuracy: 0.9476 - val_loss: 0.1417 - val_accuracy: 0.9613
Epoch 3/5
120/120 [==============================] - 3s 27ms/step - loss: 0.1352 - accuracy: 0.9614 - val_loss: 0.1148 - val_accuracy: 0.9665
Epoch 4/5
120/120 [==============================] - 3s 26ms/step - loss: 0.1066 - accuracy: 0.9702 - val_loss: 0.0932 - val_accuracy: 0.9716
Epoch 5/5
120/120 [==============================] - 3s 27ms/step - loss: 0.0859 - accuracy: 0.9749 - val_loss: 0.0844 - val_accuracy: 0.9754
<keras.callbacks.History at 0x7fce15c39f50>

Open TensorBoard with the new log directory and see both the epoch-level and batch-level metrics:

%tensorboard --logdir logs/batch_avg

That's it! You now know how to create custom training metrics in TensorBoard for a wide variety of use cases.