¡El Día de la Comunidad de ML es el 9 de noviembre! Únase a nosotros para recibir actualizaciones de TensorFlow, JAX, y más Más información

Entrenamiento en el dispositivo en TensorFlow Lite

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Para usar TensorFlow Lite, normalmente debes preparar un modelo de TensorFlow, usar el convertidor para convertirlo al formato de modelo de TensorFlow Lite y ejecutar el modelo con TensorFlow Lite en el dispositivo. Puede seguir el mismo flujo de formación y modelos de inferencia .

El código siguiente ilustra el flujo de alto nivel de preparación de un modelo de formación TensorFlow, convirtiéndola en modelo TensorFlow Lite y funcionando en TensorFlow Lite para tipos de ropa clasificar usando el conjunto de datos moda MNIST .

La aplicación se basa en el ejemplo de clasificación Keras en la página de la guía oficial TensorFlow.

Con este Colab, puede explorar nuevas formas de entrenar modelos de aprendizaje automático listos para usar con TensorFlow Lite.

Configuración

Para que Colab funcione, deberá descargar e instalar los siguientes paquetes.

pip uninstall -y tensorflow keras
pip install tf-nightly
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

Clasifica imágenes de ropa

Este Colab entrena un modelo de red neuronal en el conjunto de datos moda MNIST de imágenes Clasificar de prendas de vestir, como zapatillas de deporte y camisetas.

Usaremos 60,000 imágenes para entrenar la red y 10,000 imágenes para probar la precisión del modelo: ¿qué tan bien aprendió el modelo a clasificar esas imágenes correctamente? Puede acceder a Fashion MNIST directamente desde TensorFlow. Importación y cargar el conjunto de datos moda MNIST directamente de TensorFlow y volverá cuatro matrices numpy:

  • Los train_images y train_labels matrices son el conjunto de entrenamiento -Los datos de los usos modelo de aprender.
  • El modelo se prueba contra el equipo de prueba, los test_images , y test_labels arrays.
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

Las imágenes son 28x28 matrices NumPy, con valores de píxeles que van de 0 a 255. Las etiquetas son una matriz de enteros que van de 0 a 9. Estos corresponden a la clase de prendas de vestir representa la imagen:

train_images.shape
(60000, 28, 28)
Etiqueta Clase
0 Camiseta / top
1 Pantalón
2 Pull-over
3 Vestido
4 Abrigo
5 Sandalia
6 Camisa
7 Zapatilla
8 Bolso
9 Botín

Cada imagen se asigna a una sola etiqueta. Dado que los nombres de las clases no están incluidas en el conjunto de datos, almacenarlos para su uso posterior al trazar las imágenes:

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

Familiaricémonos con el aspecto de las imágenes de nuestro conjunto de formación trazando las imágenes y las etiquetas de formación.

Definir una función plot() para imprimir 25 imágenes y sus etiquetas.

def plot(images, labels):
  plt.figure(figsize=(10,10))
  for i in range(25):
      plt.subplot(5,5,i+1)
      plt.xticks([])
      plt.yticks([])
      plt.grid(False)
      plt.imshow(images[i], cmap=plt.cm.binary)
      plt.xlabel(class_names[labels[i]])
  plt.show()

Imágenes de plot training

plot(train_images, train_labels)

png

Modelo de TensorFlow para entrenamiento

En lugar de convertir un solo modelo de TensorFlow o tf.function en un modelo de TensorFlow Lite con un solo punto de entrada, podemos convertir múltiples tf.function (s) en un modelo de TensorFlow Lite. Para poder hacer eso, estamos ampliando el convertidor y el tiempo de ejecución de TensorFlow Lite para manejar múltiples firmas.

Preparando un modelo de TensorFlow. El código construye un módulo tf con 4 funciones tf:

  • La función de tren entrena el modelo con datos de entrenamiento.
  • inferir la función invoca la inferencia.
  • La función de guardar guarda los pesos entrenables en el sistema de archivos.
  • La función de restauración carga los pesos entrenables del sistema de archivos.

Los pesos se serializarán como un formato de archivo de punto de control de la versión uno de TensorFlow.

IMG_SIZE = 28


class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    self.model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
    self._LOSS_FN = tf.keras.losses.CategoricalCrossentropy()
    self._OPTIM = tf.optimizers.SGD()

  # The `train` function takes a batch of input images and labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    # Gradient tape is used for recording operations for automatic
    # differentiation. You can refer to
    # https://www.tensorflow.org/api_docs/python/tf/GradientTape for more
    # details on how to use it.
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self._LOSS_FN(prediction, y)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self._OPTIM.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    for grad in gradients:
      result[grad.name] = grad
    return result

  @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
  def predict(self, x):
    return {
        "output": self.model(x)
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors

Convertir al modelo de TensorFlow Lite

Ahora tenemos el modelo de TensorFlow, pero para asegurarnos de que se ejecute en TensorFlow Lite, tenemos que convertirlo. Este código convierte nuestro modelo para el formato correcto, y, además, nuestra salida será un conjunto de 4 firmas que vamos a utilizar para ejecutar el modelo Lite TensorFlow a continuación: train, infer, save, restore .

# Export the TensorFlow model to the saved model
SAVED_MODEL_DIR = "saved_model"
m= Model()
tf.saved_model.save(
    m,
    SAVED_MODEL_DIR,
    signatures={
        'train':
            m.train.get_concrete_function(),
        'infer':
            m.predict.get_concrete_function(),
        'save':
            m.save.get_concrete_function(),
        'restore':
            m.restore.get_concrete_function(),
    })

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

Entrena el modelo de TensorFlow Lite

Procesar previamente el conjunto de datos

  • Escale las imágenes de 0 a 255 a un rango de 0 a 1 antes de alimentarlas al modelo de red neuronal. Para hacerlo, divida los valores por 255.
  • Convierta etiquetas en valores categóricos, es decir, realice una codificación one-hot.

Es importante que el conjunto de entrenamiento y el conjunto de las pruebas de ser preprocesados de la misma manera.

train_images = train_images / 255.0
test_images = test_images / 255.0

train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)

Configurar TensorFlow Lite

El convertidor e intérprete de TensorFlow Lite admite firma múltiple. Los desarrolladores pueden optar por invocar la restauración, el entrenamiento, el guardado y la inferencia de firmas por separado.

interpreter = tf.lite.Interpreter(model_content=tflite_model)

train = interpreter.get_signature_runner("train")
infer = interpreter.get_signature_runner("infer")
save = interpreter.get_signature_runner("save")
restore = interpreter.get_signature_runner("restore")

En Android, el entrenamiento en el dispositivo de TensorFlow Lite se puede realizar con las API de Java o C ++. En este documento, describimos cómo funcionará el modelo de TensorFlow Lite anterior con la API de Java.

Entrena el modelo

Utilizando el conjunto de datos preprocesado y el train método de firma.

# Run training for a few steps, this may take a few minutes to finish. The loss
# value will be printed every 10 epochs.
# Within each epoch, we will split the training data into batches of size 100.
NUM_EPOCHS = 100
BATCH_SIZE = 100
epochs = np.arange(1, NUM_EPOCHS + 1, 1)
losses = np.zeros([NUM_EPOCHS])

for i in range(NUM_EPOCHS):
  for batch_idx in range(len(train_images) // BATCH_SIZE):
    batched_images = train_images[BATCH_SIZE*(batch_idx) : BATCH_SIZE * (batch_idx + 1)]
    batched_labels = train_labels[BATCH_SIZE*(batch_idx) : BATCH_SIZE * (batch_idx + 1)]
    result = train(
        x=tf.constant(batched_images, shape=(BATCH_SIZE, IMG_SIZE, IMG_SIZE),
                      dtype=tf.float32),
        y=tf.constant(batched_labels, shape=(BATCH_SIZE, 10), dtype=tf.float32))
  losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print('Finished {0} epochs, current loss: {1}'.format(i + 1, losses[i]))

plt.plot(epochs, losses)
plt.show()
Finished 10 epochs, current loss: 4.234197616577148
Finished 20 epochs, current loss: 4.153818130493164
Finished 30 epochs, current loss: 4.09471321105957
Finished 40 epochs, current loss: 4.049306869506836
Finished 50 epochs, current loss: 4.074450492858887
Finished 60 epochs, current loss: 4.049517631530762
Finished 70 epochs, current loss: 4.011128902435303
Finished 80 epochs, current loss: 3.8494346141815186
Finished 90 epochs, current loss: 3.86380672454834
Finished 100 epochs, current loss: 3.7599244117736816

png

En Java, va a utilizar el Interpreter de clase para cargar un modelo de unidad y tareas de entrenamiento del modelo. El siguiente ejemplo muestra cómo ejecutar el procedimiento de formación mediante el uso de la runSignature método:

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    int NUM_EPOCHS = 100;
    int NUM_TRAININGS = 60000;
    float[][][] trainImages = new float[NUM_TRAININGS][28][28];
    float[][] trainLabels = new float[NUM_TRAININGS][10];

    // Fill the data values.

    // Run training for a few steps.
    for (int i = 0; i < NUM_EPOCHS; ++i) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("x", trainImages);
        inputs.put("y", trainLabels);
        Map<String, Object> outputs = new HashMap<>();
        FloatBuffer loss = FloatBuffer.allocate(1);
        outputs.put("loss", loss);
        interpreter.runSignature(inputs, outputs, "train");
    }

    // Do the other stuffs..
}

Ejecutar inferencia en el modelo entrenado de TensorFlow Lite

Ahora que tenemos un modelo de TensorFlow Lite entrenado, podemos ejecutar inferencias en las imágenes de prueba usando la API de ejecución de firmas:

infer = interpreter.get_signature_runner("infer")
result = infer(
    x=tf.constant(test_images, shape=(len(test_images), IMG_SIZE, IMG_SIZE), dtype=tf.float32))
result_labels = np.argmax(result["output"], axis=1)
plot(test_images, result_labels)

png

Exportar los pesos entrenados al archivo de punto de control

El archivo de controles se puede generar a través del save método de firma.

save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}

En Java, puede almacenar el peso entrenado como un formato de punto de control en el almacenamiento interno de la aplicación. Las tareas de capacitación generalmente se realizan en el tiempo de inactividad (por ejemplo, durante la noche) en el proceso de fondo de vez en cuando.

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Conduct the training jobs.

    // Export the trained weights as a checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    interpreter.runSignature(inputs, outputs, "save");
}

Restaurar los pesos entrenados desde el archivo de punto de control

El archivo de controles exportado puede ser restaurado a través de la restore método de firma.

another_interpreter = tf.lite.Interpreter(model_content=tflite_model)

train = another_interpreter.get_signature_runner("train")
infer = another_interpreter.get_signature_runner("infer")
save = another_interpreter.get_signature_runner("save")
restore = another_interpreter.get_signature_runner("restore")

# Restore the trained weights from /tmp/model.ckpt
# The time spent in weight restoration is proportionate to the checkpoint size
# and the number of variables in the model.
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'dense/bias:0': array([ 0.04012525,  0.2057373 ,  0.2243858 , -0.24513818, -0.02761128,
         0.09906419, -0.11930446,  0.04712639, -0.16470878,  0.01239975,
        -0.00356665, -0.22564562, -0.10880154, -0.01482808,  0.06578001,
        -0.04510193, -0.03627148,  0.2789909 ,  0.13839999,  0.00209918,
         0.17772792, -0.04050817,  0.00417841,  0.05942791,  0.12545821,
        -0.00292166, -0.00278226,  0.05466946,  0.22082086, -0.00598721,
        -0.01459187,  0.02165319,  0.18708995, -0.02477028, -0.08105396,
        -0.04406855,  0.06384442,  0.05739888,  0.08162292, -0.02138082,
        -0.0134026 ,  0.00778261,  0.14147109, -0.00077838,  0.28427938,
        -0.15945041,  0.06246428, -0.01937703,  0.1309807 ,  0.1728773 ,
         0.10792284,  0.06019975, -0.19701743,  0.0174903 ,  0.1654197 ,
        -0.0681069 , -0.00954542,  0.0846714 , -0.07780166,  0.09400689,
         0.05912594, -0.2068568 , -0.19693165, -0.00241943,  0.05349341,
         0.15399657,  0.07771349, -0.13121206, -0.14502516,  0.05140257,
         0.02792422, -0.01451887,  0.13365516, -0.16654228, -0.00853377,
         0.28775522,  0.04611694, -0.0695993 ,  0.03202499,  0.12142272,
        -0.07113872,  0.16598509, -0.19415075,  0.18359615,  0.2454954 ,
         0.01218397, -0.01503987, -0.08677252,  0.20013258, -0.02418128,
         0.10109005, -0.01130499, -0.06901348, -0.06240098,  0.2581603 ,
         0.05776929,  0.12293585, -0.12227609,  0.27185693,  0.00404145,
         0.08342918,  0.16701412,  0.15018363, -0.00610264,  0.19649959,
        -0.07236976,  0.08896685, -0.06355023,  0.17494006, -0.00094923,
         0.14586641,  0.12557171,  0.00379796,  0.29036382,  0.12231576,
         0.03147348,  0.05118695,  0.02248127, -0.02469863, -0.03478021,
        -0.08656331, -0.10407642, -0.00477005, -0.03330058, -0.07830597,
         0.08219373, -0.19386095,  0.13342676], dtype=float32),
 'dense/kernel:0': array([[-0.06770029, -0.01710479,  0.0432684 , ..., -0.01258496,
          0.02522793, -0.0021971 ],
        [-0.05924093, -0.05714141,  0.06765721, ..., -0.00604145,
         -0.00119626, -0.05456134],
        [ 0.0551593 ,  0.07199081, -0.05804528, ...,  0.0130715 ,
          0.01361321,  0.03524218],
        ...,
        [-0.01654845, -0.04120301, -0.02618122, ...,  0.03981541,
          0.00753146,  0.02586931],
        [ 0.03664923,  0.03375211,  0.06792767, ..., -0.01346136,
          0.07872342,  0.04055472],
        [-0.00415883, -0.0521685 ,  0.019919  , ..., -0.00344475,
         -0.02212732,  0.04830738]], dtype=float32),
 'dense_1/bias:0': array([-0.03116724, -0.1191585 , -0.08460724,  0.11861875, -0.05077216,
         0.6267651 ,  0.03524619,  0.04212384, -0.04703877, -0.49000773],
       dtype=float32),
 'dense_1/kernel:0': array([[-0.05368271, -0.17174229, -0.02926885, ..., -0.07829902,
         -0.212175  , -0.25497243],
        [-0.1964751 , -0.21575639,  0.11350437, ..., -0.551292  ,
         -0.2187808 , -0.18261924],
        [-0.0016375 ,  0.01338738, -0.15002763, ..., -0.11356883,
          0.1846354 , -0.29634455],
        ...,
        [-0.03453091, -0.20565349,  0.03822452, ..., -0.14005761,
          0.04568734, -0.30656683],
        [ 0.13269441,  0.48917156, -0.13545942, ...,  0.1056123 ,
          0.14475733,  0.23750143],
        [ 0.15468922, -0.39078405,  0.19439282, ...,  0.3861173 ,
          0.03167873, -0.28202772]], dtype=float32)}

En Java, puede restaurar los pesos entrenados serializados desde el archivo, almacenado en el almacenamiento interno. Cuando se reinicia la aplicación, los pesos entrenados generalmente deben restaurarse antes de las inferencias.

try (Interpreter another_interpreter = new Interpreter(modelBuffer)) {
    // Load the trained weights from the checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    another_interpreter.runSignature(inputs, outputs, "restore");
}

Ejecute la inferencia usando los pesos entrenados

Los desarrolladores pueden utilizar el modelo entrenado para funcionar a través de la inferencia infer método de firma.

result = infer(
    x=tf.constant(test_images, shape=(len(test_images), IMG_SIZE, IMG_SIZE), dtype=tf.float32))
result_labels = np.argmax(result["output"], axis=1)

Trazar las etiquetas predichas

plot(test_images, result_labels)

png

En Java, después de restaurar los pesos entrenados, los desarrolladores pueden ejecutar las inferencias basándose en los datos cargados.

try (Interpreter another_interpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    float[][][] testImages = new float[NUM_TESTS][28][28];
    float[][] output = new float[NUM_TESTS][10];

    // Fill the test data.

    // Run the inference.
    inputs = new HashMap<>();
    inputs.put("x", testImages);
    outputs = new HashMap<>();
    outputs.put("output", output);
    another_interpreter.runSignature(inputs, outputs, "infer");

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output[i][index] < output[i][j]) index = testLabels[j];
        }
        testLabels[i] = index;
    }
}

¡Y eso es! Ahora tiene un modelo de TensorFlow Lite que puede realizar entrenamiento en el dispositivo. Esperamos que este tutorial de código le brinde una buena idea sobre cómo ejecutar el entrenamiento en el dispositivo en TensorFlow Lite, y estamos emocionados de ver a dónde lo lleva.