ช่วยปกป้อง Great Barrier Reef กับ TensorFlow บน Kaggle เข้าร่วมท้าทาย

การฝึกอบรมบนอุปกรณ์ด้วย TensorFlow Lite

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

เมื่อปรับใช้โมเดลแมชชีนเลิร์นนิง TensorFlow Lite กับอุปกรณ์หรือแอพมือถือ คุณอาจต้องการเปิดใช้งานโมเดลเพื่อปรับปรุงหรือปรับแต่งตามอินพุตจากอุปกรณ์หรือผู้ใช้ปลายทาง โดยใช้เทคนิคการฝึกอบรมเกี่ยวกับอุปกรณ์ช่วยให้คุณสามารถปรับปรุงรูปแบบที่ไม่มีข้อมูลออกจากอุปกรณ์ของผู้ใช้ในการปรับปรุงความเป็นส่วนตัวของผู้ใช้และโดยไม่ต้องให้ผู้ใช้สามารถอัพเดตซอฟต์แวร์อุปกรณ์

ตัวอย่างเช่น คุณอาจมีนางแบบในแอพมือถือของคุณที่รู้จักสินค้าแฟชั่น แต่คุณต้องการให้ผู้ใช้ได้รับประสิทธิภาพการจดจำที่ดีขึ้นเมื่อเวลาผ่านไปตามความสนใจของพวกเขา การเปิดใช้งานการฝึกอบรมบนอุปกรณ์ช่วยให้ผู้ใช้ที่สนใจรองเท้าสามารถจดจำสไตล์เฉพาะของรองเท้าหรือแบรนด์รองเท้าได้ดีขึ้น ยิ่งพวกเขาใช้แอปของคุณบ่อยขึ้น

บทช่วยสอนนี้จะแสดงให้คุณเห็นถึงวิธีสร้างโมเดล TensorFlow Lite ที่สามารถฝึกฝนและปรับปรุงแบบค่อยเป็นค่อยไปภายในแอพ Android ที่ติดตั้ง

ติดตั้ง

บทช่วยสอนนี้ใช้ Python ในการฝึกและแปลงโมเดล TensorFlow ก่อนที่จะรวมเข้ากับแอป Android เริ่มต้นด้วยการติดตั้งและนำเข้าแพ็คเกจต่อไปนี้

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

print("TensorFlow version:", tf.__version__)
TensorFlow version: 2.8.0

จำแนกภาพเสื้อผ้า

รหัสตัวอย่างนี้ใช้ ชุดแฟชั่น MNIST ในการฝึกอบรมรูปแบบเครือข่ายประสาทเทียมสำหรับการจำแนกภาพของเสื้อผ้า ชุดข้อมูลนี้ประกอบด้วยรูปภาพระดับสีเทาขนาดเล็ก 60,000 ภาพ (28 x 28 พิกเซล) ซึ่งประกอบด้วยเครื่องประดับแฟชั่น 10 หมวดหมู่ รวมทั้งเดรส เสื้อเชิ้ต และรองเท้าแตะ

ภาพแฟชั่น MNIST
รูปที่ 1: ตัวอย่างแฟชั่น MNIST (โดย Zalando, MIT ใบอนุญาต)

คุณสามารถสำรวจชุดนี้ในเชิงลึกมากขึ้นใน การจัดหมวดหมู่ Keras กวดวิชา

สร้างแบบจำลองสำหรับการฝึกอบรมในอุปกรณ์

รุ่น Lite TensorFlow มักจะมีเพียงวิธีการเดียวสัมผัสฟังก์ชั่น (หรือ ลายเซ็น ) ที่ช่วยให้คุณสามารถโทรรูปแบบในการทำงานอนุมาน สำหรับรุ่นที่จะฝึกและใช้งานบนอุปกรณ์ คุณต้องสามารถดำเนินการแยกกันได้หลายอย่าง รวมถึงฟังก์ชั่นการฝึก อนุมาน บันทึก และกู้คืนสำหรับรุ่นนั้น คุณสามารถเปิดใช้งานฟังก์ชันนี้โดยขยายโมเดล TensorFlow ของคุณให้มีฟังก์ชันหลายอย่างก่อน จากนั้นจึงเปิดเผยฟังก์ชันเหล่านี้เป็นลายเซ็นเมื่อคุณแปลงโมเดลของคุณเป็นรูปแบบโมเดล TensorFlow Lite

ตัวอย่างโค้ดด้านล่างแสดงวิธีเพิ่มฟังก์ชันต่อไปนี้ให้กับโมเดล TensorFlow:

  • train ฟังก์ชั่นรถไฟรุ่นที่มีข้อมูลการฝึกอบรม
  • infer ฟังก์ชั่นเรียกอนุมาน
  • save ฟังก์ชั่นบันทึกน้ำหนักสุวินัยเข้าสู่ระบบไฟล์
  • restore โหลดฟังก์ชั่นน้ำหนักสุวินัยจากระบบแฟ้ม
IMG_SIZE = 28

class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'),
        tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(10, name='dense_2')
    ])

    self.model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

  # The `train` function takes a batch of input images and labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self.model.loss(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.model.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    return result

  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
  ])
  def infer(self, x):
    logits = self.model(x)
    probabilities = tf.nn.softmax(logits, axis=-1)
    return {
        "output": probabilities,
        "logits": logits
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors

train ฟังก์ชั่นในการใช้โค้ดข้างต้นที่ GradientTape ชั้นเรียนเพื่อการดำเนินงานบันทึกสำหรับความแตกต่างโดยอัตโนมัติ สำหรับข้อมูลเพิ่มเติมเกี่ยวกับวิธีการใช้คลาสนี้ดู รู้เบื้องต้นเกี่ยวกับการไล่ระดับสีและความแตกต่างโดยอัตโนมัติ

คุณสามารถใช้ Model.train_step วิธีการรูปแบบ keras นี่แทนการดำเนินงานจากรอยขีดข่วน ทราบเพียงว่าการสูญเสีย (และตัวชี้วัด) ส่งกลับโดย Model.train_step เป็นวิ่งเฉลี่ยและควรได้รับการตั้งค่าอย่างสม่ำเสมอ (โดยปกติแต่ละยุค) ดู กำหนด Model.fit สำหรับรายละเอียด

เตรียมข้อมูล

รับชุดข้อมูล Fashion MNIST สำหรับฝึกโมเดลของคุณ

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

ประมวลผลชุดข้อมูลล่วงหน้า

ค่าพิกเซลในชุดข้อมูลนี้อยู่ระหว่าง 0 ถึง 255 และต้องทำให้เป็นมาตรฐานเป็นค่าระหว่าง 0 ถึง 1 สำหรับการประมวลผลโดยโมเดล หารค่าด้วย 255 เพื่อทำการปรับค่านี้

train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

แปลงป้ายข้อมูลเป็นค่าหมวดหมู่โดยทำการเข้ารหัสแบบใช้ครั้งเดียว

train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)

ฝึกโมเดล

ก่อนที่จะแปลงและการตั้งค่ารูปแบบ TensorFlow Lite ของคุณเสร็จสิ้นการฝึกอบรมครั้งแรกของรูปแบบของคุณโดยใช้ชุดข้อมูลที่ประมวลผลล่วงหน้าและ train วิธีลายเซ็น โค้ดต่อไปนี้รันการฝึกโมเดลสำหรับ 100 epochs ประมวลผลเป็นชุดละ 100 ภาพในแต่ละครั้ง และแสดงค่าที่สูญเสียหลังจากทุกๆ 10 epoch เนื่องจากการฝึกวิ่งนี้กำลังประมวลผลข้อมูลค่อนข้างน้อย จึงอาจใช้เวลาสองสามนาทีจึงจะเสร็จ

NUM_EPOCHS = 100
BATCH_SIZE = 100
epochs = np.arange(1, NUM_EPOCHS + 1, 1)
losses = np.zeros([NUM_EPOCHS])
m = Model()

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_ds = train_ds.batch(BATCH_SIZE)

for i in range(NUM_EPOCHS):
  for x,y in train_ds:
    result = m.train(x, y)

  losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print(f"Finished {i+1} epochs")
    print(f"  loss: {losses[i]:.3f}")

# Save the trained weights to a checkpoint.
m.save('/tmp/model.ckpt')
Finished 10 epochs
  loss: 0.428
Finished 20 epochs
  loss: 0.378
Finished 30 epochs
  loss: 0.344
Finished 40 epochs
  loss: 0.317
Finished 50 epochs
  loss: 0.299
Finished 60 epochs
  loss: 0.283
Finished 70 epochs
  loss: 0.266
Finished 80 epochs
  loss: 0.252
Finished 90 epochs
  loss: 0.240
Finished 100 epochs
  loss: 0.230
{'checkpoint_path': <tf.Tensor: shape=(), dtype=string, numpy=b'/tmp/model.ckpt'>}
plt.plot(epochs, losses, label='Pre-training')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend();

png

แปลงโมเดลเป็นรูปแบบ TensorFlow Lite

หลังจากที่คุณได้ขยายโมเดล TensorFlow ของคุณเพื่อเปิดใช้งานฟังก์ชันเพิ่มเติมสำหรับการฝึกบนอุปกรณ์และทำการฝึกโมเดลเบื้องต้นเสร็จสิ้นแล้ว คุณสามารถแปลงเป็นรูปแบบ TensorFlow Lite ได้ ต่อไปนี้แปลงรหัสและบันทึกรูปแบบของคุณเป็นรูปแบบที่รวมทั้งชุดของลายเซ็นที่คุณใช้กับรูปแบบ TensorFlow Lite บนอุปกรณ์: train, infer, save, restore

SAVED_MODEL_DIR = "saved_model"

tf.saved_model.save(
    m,
    SAVED_MODEL_DIR,
    signatures={
        'train':
            m.train.get_concrete_function(),
        'infer':
            m.infer.get_concrete_function(),
        'save':
            m.save.get_concrete_function(),
        'restore':
            m.restore.get_concrete_function(),
    })

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

ตั้งค่าลายเซ็น TensorFlow Lite

โมเดล TensorFlow Lite ที่คุณบันทึกไว้ในขั้นตอนก่อนหน้านี้มีลายเซ็นของฟังก์ชันหลายอย่าง คุณสามารถเข้าถึงพวกเขาผ่าน tf.lite.Interpreter ระดับและเรียกแต่ละ restore , train , save และ infer ลายเซ็นแยกต่างหาก

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

infer = interpreter.get_signature_runner("infer")

เปรียบเทียบผลลัพธ์ของรุ่นดั้งเดิมและรุ่น Lite ที่แปลงแล้ว:

logits_original = m.infer(x=train_images[:1])['logits'][0]
logits_lite = infer(x=train_images[:1])['logits'][0]

png

ด้านบน คุณจะเห็นว่าพฤติกรรมของโมเดลไม่เปลี่ยนแปลงโดยการแปลงเป็น TFLite

ฝึกโมเดลใหม่บนอุปกรณ์

หลังจากแปลงรูปแบบของคุณเพื่อ TensorFlow Lite และปรับใช้กับแอปของคุณคุณสามารถฝึกให้โมเดลบนอุปกรณ์ที่ใช้ข้อมูลใหม่และ train วิธีลายเซ็นของรูปแบบของคุณ การฝึกวิ่งแต่ละครั้งจะสร้างตุ้มน้ำหนักชุดใหม่ ซึ่งคุณสามารถบันทึกเพื่อนำกลับมาใช้ใหม่และปรับปรุงโมเดลต่อไปได้ ดังแสดงในหัวข้อถัดไป

บน Android คุณสามารถทำการฝึกอบรมบนอุปกรณ์ด้วย TensorFlow Lite โดยใช้ Java หรือ C++ API ใน Java ใช้ Interpreter ระดับการโหลดรูปแบบไดรฟ์และงานฝึกอบรมรุ่น แสดงให้เห็นตัวอย่างต่อไปนี้วิธีการเรียกใช้ขั้นตอนการฝึกอบรมโดยใช้ runSignature วิธีการ:

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    int NUM_EPOCHS = 100;
    int BATCH_SIZE = 100;
    int IMG_HEIGHT = 28;
    int IMG_WIDTH = 28;
    int NUM_TRAININGS = 60000;
    int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE;

    List<FloatBuffer> trainImageBatches = new ArrayList<>(NUM_BATCHES);
    List<FloatBuffer> trainLabelBatches = new ArrayList<>(NUM_BATCHES);

    // Prepare training batches.
    for (int i = 0; i < NUM_BATCHES; ++i) {
        FloatBuffer trainImages = FloatBuffer.allocateDirect(BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH).order(ByteOrder.nativeOrder());
        FloatBuffer trainLabels = FloatBuffer.allocateDirect(BATCH_SIZE * 10).order(ByteOrder.nativeOrder());

        // Fill the data values...
        trainImageBatches.add(trainImages.rewind());
        trainImageLabels.add(trainLabels.rewind());
    }

    // Run training for a few steps.
    float[] losses = new float[NUM_EPOCHS];
    for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
        for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("x", trainImageBatches.get(batchIdx));
            inputs.put("y", trainLabelBatches.get(batchIdx));

            Map<String, Object> outputs = new HashMap<>();
            FloatBuffer loss = FloatBuffer.allocate(1);
            outputs.put("loss", loss);

            interpreter.runSignature(inputs, outputs, "train");

            // Record the last loss.
            if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
        }

        // Print the loss output for every 10 epochs.
        if ((epoch + 1) % 10 == 0) {
            System.out.println(
              "Finished " + (epoch + 1) + " epochs, current loss: " + loss.get(0));
        }
    }

    // ...
}

คุณสามารถดูตัวอย่างโค้ดที่สมบูรณ์ของรูปแบบการฝึกอบรมภายในแอป Android ใน app ตัวอย่างรูปแบบส่วนบุคคล

ดำเนินการฝึกอบรมสักสองสามช่วงเพื่อปรับปรุงหรือปรับแต่งโมเดล ในทางปฏิบัติ คุณจะดำเนินการฝึกอบรมเพิ่มเติมนี้โดยใช้ข้อมูลที่รวบรวมบนอุปกรณ์ เพื่อความง่าย ตัวอย่างนี้ใช้ข้อมูลการฝึกเดียวกันกับขั้นตอนการฝึกก่อนหน้านี้

train = interpreter.get_signature_runner("train")

NUM_EPOCHS = 50
BATCH_SIZE = 100
more_epochs = np.arange(epochs[-1]+1, epochs[-1] + NUM_EPOCHS + 1, 1)
more_losses = np.zeros([NUM_EPOCHS])


for i in range(NUM_EPOCHS):
  for x,y in train_ds:
    result = train(x=x, y=y)
  more_losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print(f"Finished {i+1} epochs")
    print(f"  loss: {more_losses[i]:.3f}")
Finished 10 epochs
  loss: 0.223
Finished 20 epochs
  loss: 0.216
Finished 30 epochs
  loss: 0.210
Finished 40 epochs
  loss: 0.204
Finished 50 epochs
  loss: 0.198
plt.plot(epochs, losses, label='Pre-training')
plt.plot(more_epochs, more_losses, label='On device')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend();

png

ด้านบน คุณจะเห็นว่าการฝึกในอุปกรณ์จะเลือกจุดที่การฝึกเตรียมการหยุดลง

บันทึกน้ำหนักการฝึกอบรม

เมื่อคุณเสร็จสิ้นการฝึกวิ่งบนอุปกรณ์ โมเดลจะอัปเดตชุดน้ำหนักที่ใช้ในหน่วยความจำ ใช้ save ลายเซ็นวิธีคุณสร้างในรูปแบบ TensorFlow Lite ของคุณคุณสามารถบันทึกน้ำหนักเหล่านี้ไปยังแฟ้มจุดตรวจสอบเพื่อนำมาใช้ในภายหลังและปรับปรุงรูปแบบของคุณ

save = interpreter.get_signature_runner("save")

save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}

ในแอปพลิเคชัน Android คุณสามารถจัดเก็บน้ำหนักที่สร้างขึ้นเป็นไฟล์จุดตรวจสอบในพื้นที่จัดเก็บข้อมูลภายในที่จัดสรรให้กับแอปของคุณ

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Conduct the training jobs.

    // Export the trained weights as a checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    interpreter.runSignature(inputs, outputs, "save");
}

คืนค่าน้ำหนักที่ฝึก

ทุกครั้งที่คุณสร้างล่ามจากรุ่น TFLite ล่ามจะโหลดน้ำหนักรุ่นดั้งเดิมในขั้นต้น

ดังนั้นหลังจากที่คุณได้กระทำการฝึกอบรมและบันทึกไฟล์ด่านคุณจะต้องเรียกใช้ restore วิธีลายเซ็นให้กับโหลดด่าน

กฎที่ดีคือ "ทุกครั้งที่คุณสร้างล่ามสำหรับโมเดล หากมีจุดตรวจ ให้โหลดมัน" หากคุณต้องการรีเซ็ตโมเดลเป็นการทำงานพื้นฐาน ให้ลบจุดตรวจสอบและสร้างล่ามใหม่

another_interpreter = tf.lite.Interpreter(model_content=tflite_model)
another_interpreter.allocate_tensors()

infer = another_interpreter.get_signature_runner("infer")
restore = another_interpreter.get_signature_runner("restore")
logits_before = infer(x=train_images[:1])['logits'][0]

# Restore the trained weights from /tmp/model.ckpt
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))

logits_after = infer(x=train_images[:1])['logits'][0]

compare_logits({'Before': logits_before, 'After': logits_after})

png

จุดตรวจถูกสร้างขึ้นโดยการฝึกอบรมและการบันทึกด้วย TFLite ด้านบน คุณจะเห็นว่าการใช้จุดตรวจจะปรับปรุงพฤติกรรมของโมเดล

ในแอป Android คุณสามารถกู้คืนตุ้มน้ำหนักที่ผ่านการฝึกฝนและจัดลำดับจากไฟล์จุดตรวจที่คุณเก็บไว้ก่อนหน้านี้ได้

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Load the trained weights from the checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    anotherInterpreter.runSignature(inputs, outputs, "restore");
}

เรียกใช้การอนุมานโดยใช้ตุ้มน้ำหนักที่ฝึกแล้ว

เมื่อคุณได้โหลดน้ำหนักบันทึกไว้ก่อนหน้าจากแฟ้มด่านวิ่ง infer วิธีการใช้น้ำหนักที่มีรูปแบบเดิมของคุณเพื่อปรับปรุงการคาดการณ์ หลังจากโหลดน้ำหนักที่บันทึกไว้, คุณสามารถใช้ infer วิธีลายเซ็นที่แสดงด้านล่าง

infer = another_interpreter.get_signature_runner("infer")
result = infer(x=test_images)
predictions = np.argmax(result["output"], axis=1)

true_labels = np.argmax(test_labels, axis=1)
result['output'].shape
(10000, 10)

พล็อตฉลากที่คาดการณ์ไว้

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def plot(images, predictions, true_labels):
  plt.figure(figsize=(10,10))
  for i in range(25):
      plt.subplot(5,5,i+1)
      plt.xticks([])
      plt.yticks([])
      plt.grid(False)
      plt.imshow(images[i], cmap=plt.cm.binary)
      color = 'b' if predictions[i] == true_labels[i] else 'r'
      plt.xlabel(class_names[predictions[i]], color=color)
  plt.show()

plot(test_images, predictions, true_labels)

png

predictions.shape
(10000,)

ในแอปพลิเคชัน Android ของคุณ หลังจากกู้คืนตุ้มน้ำหนักที่ฝึกแล้ว ให้เรียกใช้การอนุมานตามข้อมูลที่โหลด

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
    FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", testImages.rewind());
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = testLabels[j];
        }
        testLabels[i] = index;
    }
}

ยินดีด้วย! ตอนนี้คุณได้สร้างโมเดล TensorFlow Lite ที่รองรับการฝึกอบรมในอุปกรณ์แล้ว สำหรับข้อมูลเพิ่มเติมเข้ารหัสรายละเอียดให้ตรวจสอบการดำเนินงานเช่นใน แอปพลิเคสาธิตรูปแบบส่วนบุคคล

หากคุณมีความสนใจในการเรียนรู้เพิ่มเติมเกี่ยวกับการจัดหมวดหมู่ภาพตรวจสอบ Keras จำแนกกวดวิชา ใน TensorFlow หน้าคู่มืออย่างเป็นทางการ บทช่วยสอนนี้อิงตามแบบฝึกหัดนั้นและให้ความลึกมากขึ้นในเรื่องการจัดหมวดหมู่