Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Menampilkan data gambar di TensorBoard

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub

Gambaran

Dengan menggunakan TensorFlow Image Summary API, Anda dapat dengan mudah membuat log tensor dan gambar arbitrer dan melihatnya di TensorBoard. Ini bisa sangat membantu untuk mengambil sampel dan memeriksa data masukan Anda, atau untuk memvisualisasikan bobot lapisan dan tensor yang dihasilkan . Anda juga dapat membuat log data diagnostik sebagai gambar yang dapat membantu selama pengembangan model Anda.

Dalam tutorial ini, Anda akan mempelajari cara menggunakan Image Summary API untuk memvisualisasikan tensor sebagai gambar. Anda juga akan mempelajari cara mengambil gambar arbitrer, mengonversinya menjadi tensor, dan memvisualisasikannya di TensorBoard. Anda akan bekerja melalui contoh sederhana namun nyata yang menggunakan Ringkasan Gambar untuk membantu Anda memahami bagaimana kinerja model Anda.

Mendirikan

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

# Load the TensorBoard notebook extension.
%load_ext tensorboard
TensorFlow 2.x selected.

from datetime import datetime
import io
import itertools
from packaging import version
from six.moves import range

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics

print("TensorFlow version: ", tf.__version__)
assert version.parse(tf.__version__).release[0] >= 2, \
    "This notebook requires TensorFlow 2.0 or above."
TensorFlow version:  2.2

Unduh set data Fashion-MNIST

Anda akan membangun jaringan saraf sederhana untuk mengklasifikasikan gambar dalam set data Fashion-MNIST . Dataset ini terdiri dari 70.000 gambar grayscale produk fashion 28x28 dari 10 kategori, dengan 7.000 gambar per kategori.

Pertama, unduh datanya:

# Download the data. The data is already divided into train and test.
# The labels are integers representing classes.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = \
    fashion_mnist.load_data()

# Names of the integer classes, i.e., 0 -> T-short/top, 1 -> Trouser, etc.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

Memvisualisasikan satu gambar

Untuk memahami cara kerja API Ringkasan Gambar, Anda sekarang cukup membuat log gambar pelatihan pertama dalam set pelatihan Anda di TensorBoard.

Sebelum Anda melakukannya, periksa bentuk data pelatihan Anda:

print("Shape: ", train_images[0].shape)
print("Label: ", train_labels[0], "->", class_names[train_labels[0]])
Shape:  (28, 28)
Label:  9 -> Ankle boot

Perhatikan bahwa bentuk dari setiap gambar dalam kumpulan data adalah tensor rank-2 dari bentuk (28, 28), yang mewakili tinggi dan lebar.

Namun, tf.summary.image() mengharapkan tensor peringkat-4 yang berisi (batch_size, height, width, channels) . Oleh karena itu, tensor perlu dibentuk kembali.

Anda hanya mencatat satu gambar, jadi batch_size adalah 1. Gambar tersebut batch_size abu-abu, jadi setel channels ke 1.

# Reshape the image for the Summary API.
img = np.reshape(train_images[0], (-1, 28, 28, 1))

Anda sekarang siap untuk mencatat gambar ini dan melihatnya di TensorBoard.

# Clear out any prior log data.
!rm -rf logs

# Sets up a timestamped log directory.
logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Creates a file writer for the log directory.
file_writer = tf.summary.create_file_writer(logdir)

# Using the file writer, log the reshaped image.
with file_writer.as_default():
  tf.summary.image("Training data", img, step=0)

Sekarang, gunakan TensorBoard untuk memeriksa gambar. Tunggu beberapa detik hingga UI berputar.

%tensorboard --logdir logs/train_data

Tab "Gambar" menampilkan gambar yang baru saja Anda login. Ini adalah "sepatu bot".

Gambar diskalakan ke ukuran default agar lebih mudah dilihat. Jika Anda ingin melihat gambar asli tanpa skala, centang "Tampilkan ukuran gambar aktual" di kiri atas.

Mainkan penggeser kecerahan dan kontras untuk melihat bagaimana pengaruhnya terhadap piksel gambar.

Memvisualisasikan banyak gambar

Mencatat satu tensor memang bagus, tetapi bagaimana jika Anda ingin mencatat beberapa contoh pelatihan?

Cukup tentukan jumlah gambar yang ingin Anda catat saat meneruskan data ke tf.summary.image() .

with file_writer.as_default():
  # Don't forget to reshape.
  images = np.reshape(train_images[0:25], (-1, 28, 28, 1))
  tf.summary.image("25 training data examples", images, max_outputs=25, step=0)

%tensorboard --logdir logs/train_data
.dll

Mencatat data gambar arbitrer

Bagaimana jika Anda ingin memvisualisasikan gambar yang bukan tensor, seperti gambar yang dihasilkan oleh matplotlib ?

Anda memerlukan beberapa kode boilerplate untuk mengubah plot menjadi tensor, tetapi setelah itu, Anda siap melakukannya.

Dalam kode di bawah ini, Anda akan mencatat 25 gambar pertama sebagai kisi yang bagus menggunakan fungsi subplot() matplotlib. Anda kemudian akan melihat kisi di TensorBoard:

# Clear out prior logging data.
!rm -rf logs/plots

logdir = "logs/plots/" + datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir)

def plot_to_image(figure):
  """Converts the matplotlib plot specified by 'figure' to a PNG image and
  returns it. The supplied figure is closed and inaccessible after this call."""
  # Save the plot to a PNG in memory.
  buf = io.BytesIO()
  plt.savefig(buf, format='png')
  # Closing the figure prevents it from being displayed directly inside
  # the notebook.
  plt.close(figure)
  buf.seek(0)
  # Convert PNG buffer to TF image
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  # Add the batch dimension
  image = tf.expand_dims(image, 0)
  return image

def image_grid():
  """Return a 5x5 grid of the MNIST images as a matplotlib figure."""
  # Create a figure to contain the plot.
  figure = plt.figure(figsize=(10,10))
  for i in range(25):
    # Start next subplot.
    plt.subplot(5, 5, i + 1, title=class_names[train_labels[i]])
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
  
  return figure

# Prepare the plot
figure = image_grid()
# Convert to image and log
with file_writer.as_default():
  tf.summary.image("Training data", plot_to_image(figure), step=0)

%tensorboard --logdir logs/plots

Membangun pengklasifikasi gambar

Sekarang gabungkan semuanya dengan contoh nyata. Lagi pula, Anda di sini untuk melakukan pembelajaran mesin dan tidak membuat plot gambar yang indah!

Anda akan menggunakan ringkasan gambar untuk memahami seberapa baik performa model Anda saat melatih pengklasifikasi sederhana untuk set data Fashion-MNIST.

Pertama, buat model yang sangat sederhana dan kompilasi, atur fungsi optimizer dan loss. Langkah kompilasi juga menentukan bahwa Anda ingin mencatat keakuratan pengklasifikasi di sepanjang jalan.

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam', 
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

Saat melatih pengklasifikasi, sebaiknya lihat matriks kebingungan . Matriks konfusi memberi Anda pengetahuan mendetail tentang performa pengklasifikasi Anda pada data pengujian.

Tentukan fungsi yang menghitung matriks kebingungan. Anda akan menggunakan fungsi Scikit-learn untuk melakukan ini, lalu memplotnya menggunakan matplotlib.

def plot_confusion_matrix(cm, class_names):
  """
  Returns a matplotlib figure containing the plotted confusion matrix.

  Args:
    cm (array, shape = [n, n]): a confusion matrix of integer classes
    class_names (array, shape = [n]): String names of the integer classes
  """
  figure = plt.figure(figsize=(8, 8))
  plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
  plt.title("Confusion matrix")
  plt.colorbar()
  tick_marks = np.arange(len(class_names))
  plt.xticks(tick_marks, class_names, rotation=45)
  plt.yticks(tick_marks, class_names)

  # Compute the labels from the normalized confusion matrix.
  labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)

  # Use white text if squares are dark; otherwise black.
  threshold = cm.max() / 2.
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    color = "white" if cm[i, j] > threshold else "black"
    plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

  plt.tight_layout()
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  return figure

Anda sekarang siap untuk melatih pengklasifikasi dan secara teratur mencatat matriks kebingungan di sepanjang jalan.

Inilah yang akan Anda lakukan:

  1. Buat callback Keras TensorBoard untuk mencatat metrik dasar
  2. Buat Keras LambdaCallback untuk mencatat matriks kebingungan di akhir setiap epoch
  3. Latih model menggunakan Model.fit (), pastikan untuk meneruskan kedua callback

Saat pelatihan berlangsung, gulir ke bawah untuk melihat TensorBoard dimulai.

# Clear out prior logging data.
!rm -rf logs/image

logdir = "logs/image/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Define the basic TensorBoard callback.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')
def log_confusion_matrix(epoch, logs):
  # Use the model to predict the values from the validation dataset.
  test_pred_raw = model.predict(test_images)
  test_pred = np.argmax(test_pred_raw, axis=1)

  # Calculate the confusion matrix.
  cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)
  # Log the confusion matrix as an image summary.
  figure = plot_confusion_matrix(cm, class_names=class_names)
  cm_image = plot_to_image(figure)

  # Log the confusion matrix as an image summary.
  with file_writer_cm.as_default():
    tf.summary.image("Confusion Matrix", cm_image, step=epoch)

# Define the per-epoch callback.
cm_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
# Start TensorBoard.
%tensorboard --logdir logs/image

# Train the classifier.
model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0, # Suppress chatty output
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels),
)

Perhatikan bahwa akurasi meningkat pada set kereta dan validasi. Itu pertanda bagus. Namun, bagaimana performa model pada subkumpulan data tertentu?

Pilih tab "Gambar" untuk memvisualisasikan matriks kebingungan yang Anda catat. Centang "Tampilkan ukuran gambar aktual" di kiri atas untuk melihat matriks kebingungan dalam ukuran penuh.

Secara default, dasbor menampilkan ringkasan gambar untuk langkah atau masa yang dicatat terakhir. Gunakan penggeser untuk melihat matriks kebingungan sebelumnya. Perhatikan bagaimana matriks berubah secara signifikan saat pelatihan berlangsung, dengan kotak yang lebih gelap bergabung di sepanjang diagonal, dan matriks lainnya cenderung ke arah 0 dan putih. Ini berarti pengklasifikasi Anda meningkat seiring dengan kemajuan pelatihan! Kerja bagus!

Matriks konfusi menunjukkan bahwa model sederhana ini memiliki beberapa masalah. Terlepas dari kemajuan besar, Kemeja, Kaos, dan Pullover semakin bingung satu sama lain. Model ini membutuhkan lebih banyak pekerjaan.

Jika Anda tertarik, coba perbaiki model ini dengan jaringan konvolusional (CNN).