Overview
This CodeLab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the Jax-converted TFLite model with post-training quantization.
Prerequisites
It is recommended to try this feature with the newest TensorFlow nightly pip build.
pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
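Optionally, you can confirm which builds were picked up before going further; this is just a small sanity-check sketch, not part of the original notebook:
import tensorflow as tf
import jax
import jaxlib
# Print the installed versions to confirm the nightly TensorFlow build is active.
print("TensorFlow:", tf.__version__)
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)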
Data preparation
Download the MNIST data with the Keras dataset helper and do the pre-processing.
import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
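Before building the model, you may want to verify what the pre-processing produced; a minimal check (the shapes in the comments are what MNIST should yield):
# MNIST should give 60,000 training and 10,000 test images of 28x28 pixels,
# with labels one-hot encoded into 10 classes.
print(train_images.shape, train_labels.shape)  # (60000, 28, 28) (60000, 10)
print(test_images.shape, test_labels.shape)    # (10000, 28, 28) (10000, 10)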
Build the MNIST model with Jax
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -jnp.mean(jnp.sum(preds * targets, axis=1))
def accuracy(params, batch):
inputs, targets = batch
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(predict(params, inputs), axis=1)
return jnp.mean(predicted_class == target_class)
init_random_params, predict = stax.serial(
stax.Flatten,
stax.Dense(1024), stax.Relu,
stax.Dense(1024), stax.Relu,
stax.Dense(10), stax.LogSoftmax)
rng = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Train and evaluate the model
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
def data_stream():
rng = npr.RandomState(0)
while True:
perm = rng.permutation(num_train)
for i in range(num_batches):
batch_idx = perm[i * batch_size:(i + 1) * batch_size]
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
@jit
def update(i, opt_state, batch):
params = get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()
print("\nStarting training...")
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
opt_state = update(next(itercount), opt_state, next(batches))
epoch_time = time.time() - start_time
params = get_params(opt_state)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Starting training...
Epoch 0 in 4.69 sec
Training set accuracy 0.8729000091552734
Test set accuracy 0.880299985408783
Epoch 1 in 3.83 sec
Training set accuracy 0.8983666896820068
Test set accuracy 0.9047999978065491
Epoch 2 in 3.81 sec
Training set accuracy 0.9102166891098022
Test set accuracy 0.9138000011444092
Epoch 3 in 3.85 sec
Training set accuracy 0.9172500371932983
Test set accuracy 0.9218999743461609
Epoch 4 in 3.79 sec
Training set accuracy 0.9224500060081482
Test set accuracy 0.9253999590873718
Epoch 5 in 3.72 sec
Training set accuracy 0.9272000193595886
Test set accuracy 0.930899977684021
Epoch 6 in 3.77 sec
Training set accuracy 0.9327666759490967
Test set accuracy 0.9334999918937683
Epoch 7 in 3.77 sec
Training set accuracy 0.9360166788101196
Test set accuracy 0.9370999932289124
Epoch 8 in 3.77 sec
Training set accuracy 0.9390000104904175
Test set accuracy 0.939300000667572
Epoch 9 in 3.73 sec
Training set accuracy 0.9425666928291321
Test set accuracy 0.9429999589920044
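Before converting, it can help to look at the trained parameter pytree, since these are the values that will be inlined into the serving function in the next step. An optional sketch:
import jax
# Print the shape of every weight and bias in the trained stax pytree.
# The three Dense layers appear as (W, b) pairs; Flatten, Relu and LogSoftmax
# contribute empty entries because they have no parameters.
print(jax.tree_util.tree_map(lambda p: p.shape, params))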
Convert to the TFLite model
Note that here we:

- Inline the params into the Jax predict function with functools.partial.
- Build a jnp.zeros; this is a "placeholder" tensor that Jax uses to trace the model.
- Call experimental_from_jax:
  - The serving_func is wrapped in a list.
  - The input is associated with a given name and passed in as an array wrapped in a list.
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
[serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
f.write(tflite_model)
2021-10-30 11:51:13.208329: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-10-30 11:51:13.208375: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-10-30 11:51:13.208383: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
Verify the converted TFLite model
Compare the results of the converted model with the Jax model.
expected = serving_func(train_images[0:1])
# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])
# Assert that the result of the TFLite model is consistent with the JAX model.
np.testing.assert_almost_equal(expected, result, 1e-5)
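Beyond checking a single example, you can also measure the TFLite model's accuracy on part of the test set and compare it with the Jax accuracy printed during training. This is an optional sketch (not in the original notebook) that reuses the interpreter created above and feeds one image at a time, since the model was traced with batch size 1:
# Evaluate the converted model on the first 1000 test images.
num_eval = 1000
correct = 0
for i in range(num_eval):
  interpreter.set_tensor(input_details[0]["index"], test_images[i:i+1])
  interpreter.invoke()
  logits = interpreter.get_tensor(output_details[0]["index"])
  correct += int(np.argmax(logits) == np.argmax(test_labels[i]))
print("TFLite accuracy on {} test images: {:.4f}".format(num_eval, correct / num_eval))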
Optimize the model
We will provide a representative_dataset to do post-training quantization to optimize the model.
def representative_dataset():
for i in range(1000):
x = train_images[i:i+1]
yield [x]
converter = tf.lite.TFLiteConverter.experimental_from_jax(
[serving_func], [[('x', x_input)]])
tflite_model = converter.convert()
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
with open('jax_mnist_quant.tflite', 'wb') as f:
f.write(tflite_quant_model)
2021-10-30 11:51:14.202412: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-10-30 11:51:14.202455: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-10-30 11:51:14.202461: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
2021-10-30 11:51:14.293677: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-10-30 11:51:14.293768: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-10-30 11:51:14.293776: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
Evaluate the optimized model
expected = serving_func(train_images[0:1])
# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])
# Assert that the result of the TFLite model is consistent with the Jax model.
np.testing.assert_almost_equal(expected, result, 1e-5)
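To confirm that the weights were actually quantized, you can inspect the tensor types stored in the quantized interpreter. A small optional sketch; if the post-training quantization took effect, the Dense weight tensors should show up as int8:
# Count how many tensors in the quantized model are stored as int8 vs float32.
dtypes = [t["dtype"] for t in interpreter.get_tensor_details()]
print("int8 tensors:", sum(d == np.int8 for d in dtypes))
print("float32 tensors:", sum(d == np.float32 for d in dtypes))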
Compare the quantized model size
We should be able to see that the quantized model is about four times smaller than the original model.
du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M jax_mnist.tflite
1.8M jax_mnist_quant.tflite
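If you prefer to check the size ratio from Python rather than the shell, a minimal sketch using the files written above:
import os
# Compare the on-disk sizes of the float and quantized models.
float_size = os.path.getsize('jax_mnist.tflite')
quant_size = os.path.getsize('jax_mnist_quant.tflite')
print("float model: {:.1f} MiB".format(float_size / 2**20))
print("quantized model: {:.1f} MiB".format(quant_size / 2**20))
print("size ratio: {:.1f}x".format(float_size / quant_size))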