Exibindo dados de imagem no TensorBoard

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Visão geral

Usando o Resumo API TensorFlow Imagem, você pode tensores e imagens arbitrárias facilmente log e visualizá-los em TensorBoard. Isto pode ser extremamente útil para provar e examinar seus dados de entrada, ou para visualizar pesos camada e tensores gerados . Você também pode registrar dados de diagnóstico como imagens que podem ser úteis no decorrer do desenvolvimento de seu modelo.

Neste tutorial, você aprenderá a usar a API de resumo de imagem para visualizar tensores como imagens. Você também aprenderá como pegar uma imagem arbitrária, convertê-la em um tensor e visualizá-la no TensorBoard. Você trabalhará com um exemplo simples, mas real, que usa resumos de imagens para ajudá-lo a entender o desempenho do seu modelo.

Configurar

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

# Load the TensorBoard notebook extension.
%load_ext tensorboard
TensorFlow 2.x selected.
from datetime import datetime
import io
import itertools
from packaging import version

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics

print("TensorFlow version: ", tf.__version__)
assert version.parse(tf.__version__).release[0] >= 2, \
    "This notebook requires TensorFlow 2.0 or above."
TensorFlow version:  2.2

Baixe o conjunto de dados Fashion-MNIST

Você está indo para construir uma rede neural simples de imagens Classificar no o Fashion-MNIST conjunto de dados. Este conjunto de dados consiste em 70.000 imagens em tons de cinza de 28x28 de produtos de moda de 10 categorias, com 7.000 imagens por categoria.

Primeiro, baixe os dados:

# Download the data. The data is already divided into train and test.
# The labels are integers representing classes.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = \
    fashion_mnist.load_data()

# Names of the integer classes, i.e., 0 -> T-short/top, 1 -> Trouser, etc.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

Visualizando uma única imagem

Para entender como a API de resumo de imagem funciona, agora você simplesmente registrará a primeira imagem de treinamento em seu conjunto de treinamento no TensorBoard.

Antes de fazer isso, examine a forma dos seus dados de treinamento:

print("Shape: ", train_images[0].shape)
print("Label: ", train_labels[0], "->", class_names[train_labels[0]])
Shape:  (28, 28)
Label:  9 -> Ankle boot

Observe que a forma de cada imagem no conjunto de dados é um tensor de classificação 2 de forma (28, 28), representando a altura e a largura.

No entanto, tf.summary.image() espera que um tensor de rank-4 contendo (batch_size, height, width, channels) . Portanto, os tensores precisam ser remodelados.

Você está registrando apenas uma imagem, então batch_size é 1. As imagens são em tons de cinza, para definir channels para 1.

# Reshape the image for the Summary API.
img = np.reshape(train_images[0], (-1, 28, 28, 1))

Agora você está pronto para registrar essa imagem e visualizá-la no TensorBoard.

# Clear out any prior log data.
!rm -rf logs

# Sets up a timestamped log directory.
logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Creates a file writer for the log directory.
file_writer = tf.summary.create_file_writer(logdir)

# Using the file writer, log the reshaped image.
with file_writer.as_default():
  tf.summary.image("Training data", img, step=0)

Agora, use o TensorBoard para examinar a imagem. Aguarde alguns segundos para a IU girar.

%tensorboard --logdir logs/train_data

A guia "Imagens" exibe a imagem que você acabou de registrar. É uma "bota de tornozelo".

A imagem é redimensionada para um tamanho padrão para facilitar a visualização. Se você quiser ver a imagem original fora de escala, marque "Mostrar tamanho real da imagem" no canto superior esquerdo.

Brinque com os controles deslizantes de brilho e contraste para ver como eles afetam os pixels da imagem.

Visualização de várias imagens

Registrar um tensor é ótimo, mas e se você quisesse registrar vários exemplos de treinamento?

Basta especificar o número de imagens que você quer registrar quando passar dados para tf.summary.image() .

with file_writer.as_default():
  # Don't forget to reshape.
  images = np.reshape(train_images[0:25], (-1, 28, 28, 1))
  tf.summary.image("25 training data examples", images, max_outputs=25, step=0)

%tensorboard --logdir logs/train_data

Registro de dados de imagem arbitrários

E se você quiser visualizar uma imagem que não é um tensor, como uma imagem gerada por matplotlib ?

Você precisa de algum código clichê para converter o gráfico em um tensor, mas depois disso, você está pronto para prosseguir.

No código abaixo, você vai registrar as primeiras 25 imagens como uma grade agradável usando de matplotlib subplot() função. Você verá a grade no TensorBoard:

# Clear out prior logging data.
!rm -rf logs/plots

logdir = "logs/plots/" + datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir)

def plot_to_image(figure):
  """Converts the matplotlib plot specified by 'figure' to a PNG image and
  returns it. The supplied figure is closed and inaccessible after this call."""
  # Save the plot to a PNG in memory.
  buf = io.BytesIO()
  plt.savefig(buf, format='png')
  # Closing the figure prevents it from being displayed directly inside
  # the notebook.
  plt.close(figure)
  buf.seek(0)
  # Convert PNG buffer to TF image
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  # Add the batch dimension
  image = tf.expand_dims(image, 0)
  return image

def image_grid():
  """Return a 5x5 grid of the MNIST images as a matplotlib figure."""
  # Create a figure to contain the plot.
  figure = plt.figure(figsize=(10,10))
  for i in range(25):
    # Start next subplot.
    plt.subplot(5, 5, i + 1, title=class_names[train_labels[i]])
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)

  return figure

# Prepare the plot
figure = image_grid()
# Convert to image and log
with file_writer.as_default():
  tf.summary.image("Training data", plot_to_image(figure), step=0)

%tensorboard --logdir logs/plots

Construindo um classificador de imagem

Agora coloque tudo isso junto com um exemplo real. Afinal, você está aqui para fazer aprendizado de máquina e não para criar imagens bonitas!

Você usará resumos de imagens para entender o desempenho do seu modelo enquanto treina um classificador simples para o conjunto de dados Fashion-MNIST.

Primeiro, crie um modelo muito simples e compile-o, configurando o otimizador e a função de perda. A etapa de compilação também especifica que você deseja registrar a precisão do classificador ao longo do caminho.

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam', 
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

Ao treinar um classificador, é útil para ver a matriz de confusão . A matriz de confusão fornece conhecimento detalhado de como seu classificador está se saindo nos dados de teste.

Defina uma função que calcule a matriz de confusão. Você vai usar um conveniente scikit-learn função para fazer isso, e depois traçar-lo usando matplotlib.

def plot_confusion_matrix(cm, class_names):
  """
  Returns a matplotlib figure containing the plotted confusion matrix.

  Args:
    cm (array, shape = [n, n]): a confusion matrix of integer classes
    class_names (array, shape = [n]): String names of the integer classes
  """
  figure = plt.figure(figsize=(8, 8))
  plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
  plt.title("Confusion matrix")
  plt.colorbar()
  tick_marks = np.arange(len(class_names))
  plt.xticks(tick_marks, class_names, rotation=45)
  plt.yticks(tick_marks, class_names)

  # Compute the labels from the normalized confusion matrix.
  labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)

  # Use white text if squares are dark; otherwise black.
  threshold = cm.max() / 2.
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    color = "white" if cm[i, j] > threshold else "black"
    plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

  plt.tight_layout()
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  return figure

Agora você está pronto para treinar o classificador e registrar regularmente a matriz de confusão ao longo do caminho.

Aqui está o que você fará:

  1. Criar o callback Keras TensorBoard para log métricas básicas
  2. Criar um Keras LambdaCallback para registrar a matriz de confusão, no final de cada época
  3. Treine o modelo usando Model.fit (), certificando-se de passar os dois callbacks

Conforme o treinamento avança, role para baixo para ver o TensorBoard inicializar.

# Clear out prior logging data.
!rm -rf logs/image

logdir = "logs/image/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Define the basic TensorBoard callback.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')
def log_confusion_matrix(epoch, logs):
  # Use the model to predict the values from the validation dataset.
  test_pred_raw = model.predict(test_images)
  test_pred = np.argmax(test_pred_raw, axis=1)

  # Calculate the confusion matrix.
  cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)
  # Log the confusion matrix as an image summary.
  figure = plot_confusion_matrix(cm, class_names=class_names)
  cm_image = plot_to_image(figure)

  # Log the confusion matrix as an image summary.
  with file_writer_cm.as_default():
    tf.summary.image("Confusion Matrix", cm_image, step=epoch)

# Define the per-epoch callback.
cm_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
# Start TensorBoard.
%tensorboard --logdir logs/image

# Train the classifier.
model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0, # Suppress chatty output
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels),
)

Observe que a precisão está aumentando nos conjuntos de trem e de validação. Isso é um bom sinal. Mas como está o desempenho do modelo em subconjuntos específicos de dados?

Selecione a guia "Imagens" para visualizar suas matrizes de confusão registradas. Marque "Mostrar tamanho real da imagem" no canto superior esquerdo para ver a matriz de confusão em tamanho real.

Por padrão, o painel mostra o resumo da imagem para a última etapa ou época registrada. Use o controle deslizante para ver as matrizes de confusão anteriores. Observe como a matriz muda significativamente conforme o treinamento avança, com quadrados mais escuros coalescendo ao longo da diagonal e o resto da matriz tendendo para 0 e branco. Isso significa que seu classificador está melhorando conforme o treinamento avança! Ótimo trabalho!

A matriz de confusão mostra que este modelo simples tem alguns problemas. Apesar do grande progresso, camisas, camisetas e pulôveres estão se confundindo. O modelo precisa de mais trabalho.

Se você estiver interessado, tentar melhorar este modelo com uma rede convolutional (CNN).