Bantuan melindungi Great Barrier Reef dengan TensorFlow pada Kaggle Bergabung Tantangan

Pelatihan Pada Perangkat dengan TensorFlow Lite

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Saat menerapkan model pembelajaran mesin TensorFlow Lite ke perangkat atau aplikasi seluler, Anda mungkin ingin mengaktifkan model untuk ditingkatkan atau dipersonalisasi berdasarkan masukan dari perangkat atau pengguna akhir. Menggunakan on-perangkat teknik pelatihan memungkinkan Anda untuk memperbarui model tanpa data yang meninggalkan perangkat pengguna, meningkatkan privasi pengguna, dan tanpa mengharuskan pengguna untuk memperbarui perangkat lunak perangkat.

Misalnya, Anda mungkin memiliki model di aplikasi seluler Anda yang mengenali item mode, tetapi Anda ingin pengguna mendapatkan kinerja pengenalan yang lebih baik dari waktu ke waktu berdasarkan minat mereka. Mengaktifkan pelatihan di perangkat memungkinkan pengguna yang tertarik dengan sepatu menjadi lebih baik dalam mengenali gaya sepatu atau merek sepatu tertentu, semakin sering mereka menggunakan aplikasi Anda.

Tutorial ini menunjukkan cara membuat model TensorFlow Lite yang dapat dilatih dan ditingkatkan secara bertahap dalam aplikasi Android yang diinstal.

Mempersiapkan

Tutorial ini menggunakan Python untuk melatih dan mengonversi model TensorFlow sebelum memasukkannya ke dalam aplikasi Android. Mulailah dengan menginstal dan mengimpor paket-paket berikut.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

print("TensorFlow version:", tf.__version__)
TensorFlow version: 2.8.0

Mengklasifikasikan gambar pakaian

Kode Contoh ini menggunakan dataset Mode MNIST untuk melatih model jaringan saraf untuk mengklasifikasikan gambar pakaian. Kumpulan data ini berisi 60.000 gambar skala abu-abu kecil (28 x 28 piksel) yang berisi 10 kategori aksesori mode yang berbeda, termasuk gaun, kemeja, dan sandal.

Gambar mode MNIST
Gambar 1: sampel Fashion-MNIST (oleh Zalando, MIT License).

Anda dapat menjelajahi dataset ini secara lebih mendalam di Keras klasifikasi tutorial .

Bangun model untuk pelatihan di perangkat

Model TensorFlow Lite biasanya hanya memiliki metode tunggal terkena fungsi (atau tanda tangan ) yang memungkinkan Anda untuk memanggil model untuk menjalankan sebuah kesimpulan. Agar model dapat dilatih dan digunakan pada perangkat, Anda harus dapat melakukan beberapa operasi terpisah, termasuk melatih, menyimpulkan, menyimpan, dan memulihkan fungsi untuk model tersebut. Anda dapat mengaktifkan fungsi ini dengan terlebih dahulu memperluas model TensorFlow agar memiliki beberapa fungsi, lalu menampilkan fungsi tersebut sebagai tanda tangan saat Anda mengonversi model ke format model TensorFlow Lite.

Contoh kode di bawah ini menunjukkan cara menambahkan fungsi berikut ke model TensorFlow:

  • train fungsi melatih model dengan data pelatihan.
  • infer fungsi memanggil kesimpulan.
  • save fungsi menghemat bobot dilatih ke dalam sistem file.
  • restore beban fungsi bobot dilatih dari sistem file.
IMG_SIZE = 28

class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'),
        tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(10, name='dense_2')
    ])

    self.model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

  # The `train` function takes a batch of input images and labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self.model.loss(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.model.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    return result

  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
  ])
  def infer(self, x):
    logits = self.model(x)
    probabilities = tf.nn.softmax(logits, axis=-1)
    return {
        "output": probabilities,
        "logits": logits
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors

The train fungsi dalam kode di atas menggunakan yang GradientTape kelas untuk merekam operasi untuk diferensiasi otomatis. Untuk informasi lebih lanjut tentang cara menggunakan kelas ini, lihat Pengantar gradien dan diferensiasi otomatis .

Anda bisa menggunakan Model.train_step metode model keras di sini bukan implementasi dari awal. Hanya Perhatikan bahwa kehilangan (dan metrik) dikembalikan oleh Model.train_step adalah berjalan rata-rata, dan harus diatur ulang secara berkala (biasanya setiap zaman). Lihat Customize Model.fit untuk rincian.

Siapkan datanya

Dapatkan set data Fashion MNIST untuk melatih model Anda.

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

Praproses kumpulan data

Nilai piksel dalam kumpulan data ini adalah antara 0 dan 255, dan harus dinormalisasi ke nilai antara 0 dan 1 untuk diproses oleh model. Bagi nilai dengan 255 untuk membuat penyesuaian ini.

train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

Konversikan label data ke nilai kategoris dengan melakukan enkode satu-panas.

train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)

Latih modelnya

Sebelum mengkonversi dan menyiapkan Model TensorFlow Lite Anda, menyelesaikan pelatihan awal model Anda menggunakan dataset preprocessed dan train metode tanda tangan. Kode berikut menjalankan pelatihan model untuk 100 epoch, memproses batch 100 gambar sekaligus, dan menampilkan nilai kerugian setelah setiap 10 epoch. Karena pelatihan ini memproses cukup banyak data, mungkin perlu beberapa menit untuk menyelesaikannya.

NUM_EPOCHS = 100
BATCH_SIZE = 100
epochs = np.arange(1, NUM_EPOCHS + 1, 1)
losses = np.zeros([NUM_EPOCHS])
m = Model()

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_ds = train_ds.batch(BATCH_SIZE)

for i in range(NUM_EPOCHS):
  for x,y in train_ds:
    result = m.train(x, y)

  losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print(f"Finished {i+1} epochs")
    print(f"  loss: {losses[i]:.3f}")

# Save the trained weights to a checkpoint.
m.save('/tmp/model.ckpt')
Finished 10 epochs
  loss: 0.428
Finished 20 epochs
  loss: 0.378
Finished 30 epochs
  loss: 0.344
Finished 40 epochs
  loss: 0.317
Finished 50 epochs
  loss: 0.299
Finished 60 epochs
  loss: 0.283
Finished 70 epochs
  loss: 0.266
Finished 80 epochs
  loss: 0.252
Finished 90 epochs
  loss: 0.240
Finished 100 epochs
  loss: 0.230
{'checkpoint_path': <tf.Tensor: shape=(), dtype=string, numpy=b'/tmp/model.ckpt'>}
plt.plot(epochs, losses, label='Pre-training')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend();

png

Konversi model ke format TensorFlow Lite

Setelah Anda memperluas model TensorFlow untuk mengaktifkan fungsi tambahan untuk pelatihan di perangkat dan menyelesaikan pelatihan awal model, Anda dapat mengonversinya ke format TensorFlow Lite. Bertobat kode berikut dan menyimpan model Anda ke format itu, termasuk set tanda tangan yang Anda gunakan dengan model TensorFlow Lite pada perangkat: train, infer, save, restore .

SAVED_MODEL_DIR = "saved_model"

tf.saved_model.save(
    m,
    SAVED_MODEL_DIR,
    signatures={
        'train':
            m.train.get_concrete_function(),
        'infer':
            m.infer.get_concrete_function(),
        'save':
            m.save.get_concrete_function(),
        'restore':
            m.restore.get_concrete_function(),
    })

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

Siapkan tanda tangan TensorFlow Lite

Model TensorFlow Lite yang Anda simpan di langkah sebelumnya berisi beberapa tanda tangan fungsi. Anda dapat mengaksesnya melalui tf.lite.Interpreter kelas dan meminta setiap restore , train , save , dan infer tanda tangan secara terpisah.

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

infer = interpreter.get_signature_runner("infer")

Bandingkan output dari model asli, dan model lite yang dikonversi:

logits_original = m.infer(x=train_images[:1])['logits'][0]
logits_lite = infer(x=train_images[:1])['logits'][0]

png

Di atas, Anda dapat melihat bahwa perilaku model tidak diubah oleh konversi ke TFLite.

Latih kembali model di perangkat

Setelah mengkonversi model Anda untuk TensorFlow Lite dan menggunakan dengan aplikasi Anda, Anda dapat melatih model pada perangkat menggunakan data baru dan train metode tanda tangan model Anda. Setiap latihan menghasilkan serangkaian bobot baru yang dapat Anda simpan untuk digunakan kembali dan peningkatan model lebih lanjut, seperti yang ditunjukkan di bagian berikutnya.

Di Android, Anda dapat melakukan pelatihan di perangkat dengan TensorFlow Lite menggunakan Java atau C++ API. Di Jawa, menggunakan Interpreter kelas untuk memuat model dan dorongan tugas model pelatihan. Contoh berikut menunjukkan bagaimana menjalankan prosedur pelatihan menggunakan runSignature metode:

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    int NUM_EPOCHS = 100;
    int BATCH_SIZE = 100;
    int IMG_HEIGHT = 28;
    int IMG_WIDTH = 28;
    int NUM_TRAININGS = 60000;
    int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE;

    List<FloatBuffer> trainImageBatches = new ArrayList<>(NUM_BATCHES);
    List<FloatBuffer> trainLabelBatches = new ArrayList<>(NUM_BATCHES);

    // Prepare training batches.
    for (int i = 0; i < NUM_BATCHES; ++i) {
        FloatBuffer trainImages = FloatBuffer.allocateDirect(BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH).order(ByteOrder.nativeOrder());
        FloatBuffer trainLabels = FloatBuffer.allocateDirect(BATCH_SIZE * 10).order(ByteOrder.nativeOrder());

        // Fill the data values...
        trainImageBatches.add(trainImages.rewind());
        trainImageLabels.add(trainLabels.rewind());
    }

    // Run training for a few steps.
    float[] losses = new float[NUM_EPOCHS];
    for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
        for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("x", trainImageBatches.get(batchIdx));
            inputs.put("y", trainLabelBatches.get(batchIdx));

            Map<String, Object> outputs = new HashMap<>();
            FloatBuffer loss = FloatBuffer.allocate(1);
            outputs.put("loss", loss);

            interpreter.runSignature(inputs, outputs, "train");

            // Record the last loss.
            if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
        }

        // Print the loss output for every 10 epochs.
        if ((epoch + 1) % 10 == 0) {
            System.out.println(
              "Finished " + (epoch + 1) + " epochs, current loss: " + loss.get(0));
        }
    }

    // ...
}

Anda dapat melihat contoh kode lengkap model pelatihan kembali dalam sebuah aplikasi Android dalam aplikasi model yang personalisasi demo .

Jalankan pelatihan selama beberapa zaman untuk meningkatkan atau mempersonalisasi model. Dalam praktiknya, Anda akan menjalankan pelatihan tambahan ini menggunakan data yang dikumpulkan di perangkat. Untuk mempermudah, contoh ini menggunakan data pelatihan yang sama dengan langkah pelatihan sebelumnya.

train = interpreter.get_signature_runner("train")

NUM_EPOCHS = 50
BATCH_SIZE = 100
more_epochs = np.arange(epochs[-1]+1, epochs[-1] + NUM_EPOCHS + 1, 1)
more_losses = np.zeros([NUM_EPOCHS])


for i in range(NUM_EPOCHS):
  for x,y in train_ds:
    result = train(x=x, y=y)
  more_losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print(f"Finished {i+1} epochs")
    print(f"  loss: {more_losses[i]:.3f}")
Finished 10 epochs
  loss: 0.223
Finished 20 epochs
  loss: 0.216
Finished 30 epochs
  loss: 0.210
Finished 40 epochs
  loss: 0.204
Finished 50 epochs
  loss: 0.198
plt.plot(epochs, losses, label='Pre-training')
plt.plot(more_epochs, more_losses, label='On device')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend();

png

Di atas Anda dapat melihat bahwa pelatihan pada perangkat mengambil tepat di tempat prapelatihan berhenti.

Simpan beban terlatih

Saat Anda menyelesaikan latihan yang dijalankan di perangkat, model memperbarui kumpulan bobot yang digunakannya di memori. Menggunakan save metode tanda tangan Anda buat dalam model TensorFlow Lite Anda, Anda dapat menyimpan bobot ini ke file pos pemeriksaan untuk kemudian digunakan kembali dan memperbaiki model.

save = interpreter.get_signature_runner("save")

save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}

Di aplikasi Android, Anda dapat menyimpan bobot yang dihasilkan sebagai file pos pemeriksaan di ruang penyimpanan internal yang dialokasikan untuk aplikasi Anda.

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Conduct the training jobs.

    // Export the trained weights as a checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    interpreter.runSignature(inputs, outputs, "save");
}

Kembalikan beban terlatih

Setiap kali Anda membuat juru bahasa dari model TFLite, juru bahasa pada awalnya akan memuat bobot model asli.

Jadi setelah Anda telah melakukan beberapa pelatihan dan disimpan file pos pemeriksaan, Anda akan perlu menjalankan restore metode tanda tangan untuk memuat pos pemeriksaan.

Aturan yang baik adalah "Setiap kali Anda membuat Juru Bahasa untuk model, jika ada pos pemeriksaan, muatlah". Jika Anda perlu mengatur ulang model ke perilaku dasar, hapus saja pos pemeriksaan dan buat penerjemah baru.

another_interpreter = tf.lite.Interpreter(model_content=tflite_model)
another_interpreter.allocate_tensors()

infer = another_interpreter.get_signature_runner("infer")
restore = another_interpreter.get_signature_runner("restore")
logits_before = infer(x=train_images[:1])['logits'][0]

# Restore the trained weights from /tmp/model.ckpt
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))

logits_after = infer(x=train_images[:1])['logits'][0]

compare_logits({'Before': logits_before, 'After': logits_after})

png

Pos pemeriksaan dihasilkan oleh pelatihan dan penyimpanan dengan TFLite. Di atas Anda dapat melihat bahwa menerapkan pos pemeriksaan memperbarui perilaku model.

Di aplikasi Android Anda, Anda dapat memulihkan serial, bobot terlatih dari file pos pemeriksaan yang Anda simpan sebelumnya.

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Load the trained weights from the checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    anotherInterpreter.runSignature(inputs, outputs, "restore");
}

Jalankan Inferensi menggunakan bobot terlatih

Setelah Anda telah dimuat bobot disimpan sebelumnya dari file pos pemeriksaan, menjalankan infer metode penggunaan yang berat dengan model asli Anda untuk meningkatkan prediksi. Setelah memuat bobot disimpan, Anda dapat menggunakan infer metode tanda tangan seperti yang ditunjukkan di bawah ini.

infer = another_interpreter.get_signature_runner("infer")
result = infer(x=test_images)
predictions = np.argmax(result["output"], axis=1)

true_labels = np.argmax(test_labels, axis=1)
result['output'].shape
(10000, 10)

Plot label yang diprediksi.

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def plot(images, predictions, true_labels):
  plt.figure(figsize=(10,10))
  for i in range(25):
      plt.subplot(5,5,i+1)
      plt.xticks([])
      plt.yticks([])
      plt.grid(False)
      plt.imshow(images[i], cmap=plt.cm.binary)
      color = 'b' if predictions[i] == true_labels[i] else 'r'
      plt.xlabel(class_names[predictions[i]], color=color)
  plt.show()

plot(test_images, predictions, true_labels)

png

predictions.shape
(10000,)

Di aplikasi Android Anda, setelah memulihkan bobot terlatih, jalankan inferensi berdasarkan data yang dimuat.

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
    FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", testImages.rewind());
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = testLabels[j];
        }
        testLabels[i] = index;
    }
}

Selamat! Anda sekarang telah membuat model TensorFlow Lite yang mendukung pelatihan di perangkat. Untuk lebih coding detail, memeriksa pelaksanaan contoh dalam aplikasi model yang personalisasi demo .

Jika Anda tertarik untuk belajar lebih banyak tentang klasifikasi citra, periksa Keras klasifikasi tutorial di TensorFlow halaman panduan resmi. Tutorial ini didasarkan pada latihan itu dan memberikan lebih banyak kedalaman tentang subjek klasifikasi.