Jax Model Conversion For TFLite

Overview

This codelab demonstrates how to build a model for MNIST recognition using Jax and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the Jax-converted TFLite model with post-training quantization.

Prerequisites

It's recommended to try this feature with the newest TensorFlow nightly pip build:

pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
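
You can confirm what was installed with a quick version check (an optional sketch, not part of the original codelab):

import tensorflow as tf
import jax

# Print the installed versions; this codelab was authored against a
# September 2021 tf-nightly build, so newer builds may differ slightly.
print("TensorFlow:", tf.__version__)
print("Jax:", jax.__version__)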

Data preparation

Download the MNIST data with the Keras datasets API and preprocess it.

import numpy as np
import tensorflow as tf
import functools

import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax

def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)

train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
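
Before building the model, a quick sanity check on the preprocessed arrays can help (a sketch, not part of the original codelab):

# Images are (N, 28, 28) float32 in [0, 1]; labels are (N, 10) one-hot rows.
print(train_images.shape, train_labels.shape)  # (60000, 28, 28) (60000, 10)
print(test_images.shape, test_labels.shape)    # (10000, 28, 28) (10000, 10)
assert train_labels.sum(axis=1).min() == train_labels.sum(axis=1).max() == 1.0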

Build the MNIST model with Jax

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

rng = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
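
As a quick smoke test (a sketch under the stax setup above, not part of the original codelab), you can initialize throwaway parameters and run one forward pass to confirm the output shape and that the log-softmax outputs exponentiate to a valid distribution:

# Initialize demo parameters and run a forward pass on dummy input.
_, demo_params = init_random_params(random.PRNGKey(1), (-1, 28 * 28))
demo_logits = predict(demo_params, jnp.zeros((2, 28, 28)))
print(demo_logits.shape)                 # (2, 10) log-probabilities
print(jnp.exp(demo_logits).sum(axis=1))  # each row sums to ~1.0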

Train and evaluate the model

step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9


num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]

batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  epoch_time = time.time() - start_time

  params = get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Starting training...
Epoch 0 in 4.26 sec
Training set accuracy 0.8729000091552734
Test set accuracy 0.880299985408783
Epoch 1 in 3.55 sec
Training set accuracy 0.8983666896820068
Test set accuracy 0.9047999978065491
Epoch 2 in 3.79 sec
Training set accuracy 0.9102166891098022
Test set accuracy 0.9138000011444092
Epoch 3 in 3.63 sec
Training set accuracy 0.9172499775886536
Test set accuracy 0.9218999743461609
Epoch 4 in 3.72 sec
Training set accuracy 0.9224500060081482
Test set accuracy 0.9254000186920166
Epoch 5 in 3.62 sec
Training set accuracy 0.9272000193595886
Test set accuracy 0.930899977684021
Epoch 6 in 3.74 sec
Training set accuracy 0.9327666759490967
Test set accuracy 0.9334999918937683
Epoch 7 in 3.55 sec
Training set accuracy 0.9360166788101196
Test set accuracy 0.9370999932289124
Epoch 8 in 3.76 sec
Training set accuracy 0.9390000104904175
Test set accuracy 0.939300000667572
Epoch 9 in 3.60 sec
Training set accuracy 0.9425666928291321
Test set accuracy 0.9430000185966492

Convert to the TFLite model

Note that here we:

  1. Inline the params into the Jax predict function with functools.partial.
  2. Build a jnp.zeros array; this is a "placeholder" tensor that Jax uses to trace the model.
  3. Call experimental_from_jax:
     * The serving_func is wrapped in a list.
     * The input is associated with a given name and passed in as an array wrapped in a list.
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)
2021-09-08 11:23:09.594301: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-09-08 11:23:09.594350: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
2021-09-08 11:23:09.594359: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.

Check the converted TFLite model

Compare the converted model's results with the Jax model.

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Verify that the TFLite output matches the Jax model. Note that the third
# argument of assert_almost_equal is `decimal` (an int), not an absolute
# tolerance, so decimal=5 gives the intended ~1e-5 agreement.
np.testing.assert_almost_equal(expected, result, decimal=5)
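
The single-example check above generalizes to the whole test set. The helper below is a sketch (the tflite_accuracy name is ours, not part of the original codelab); it feeds the interpreter one image at a time and reports classification accuracy, which should land close to the Jax test accuracy from the final epoch (~0.943):

def tflite_accuracy(model_content, images, labels):
  """Run a TFLite model one example at a time and return accuracy."""
  interp = tf.lite.Interpreter(model_content=model_content)
  interp.allocate_tensors()
  inp = interp.get_input_details()[0]["index"]
  out = interp.get_output_details()[0]["index"]
  correct = 0
  for i in range(len(images)):
    interp.set_tensor(inp, images[i:i + 1])
    interp.invoke()
    pred = np.argmax(interp.get_tensor(out))
    correct += int(pred == np.argmax(labels[i]))
  return correct / len(images)

print(tflite_accuracy(tflite_model, test_images, test_labels))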

Optimize the model

We will provide a representative_dataset to do post-training quantization and optimize the model.

def representative_dataset():
  for i in range(1000):
    x = train_images[i:i+1]
    yield [x]

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
with open('jax_mnist_quant.tflite', 'wb') as f:
  f.write(tflite_quant_model)
2021-09-08 11:23:11.580031: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-09-08 11:23:11.580077: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-09-08 11:23:11.580084: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Evaluate the optimized model

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Quantization perturbs the outputs, so exact agreement with the Jax model
# is not expected; check that the predicted class matches instead.
np.testing.assert_equal(np.argmax(expected), np.argmax(result))
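
If you defined the hypothetical tflite_accuracy helper from the earlier verification step, you can reuse it to confirm that post-training quantization barely moves the test accuracy:

# Accuracy of the quantized model on the full test set; expect a value
# close to the float model's accuracy.
print(tflite_accuracy(tflite_quant_model, test_images, test_labels))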

Compare the quantized model size

We should be able to see that the quantized model is four times smaller than the original model.

du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M    jax_mnist.tflite
1.8M    jax_mnist_quant.tflite
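
The same comparison can be done in Python (a sketch using the file names written above):

import os

float_size = os.path.getsize('jax_mnist.tflite')
quant_size = os.path.getsize('jax_mnist_quant.tflite')
print(f"float: {float_size / 1e6:.1f} MB, "
      f"quantized: {quant_size / 1e6:.1f} MB, "
      f"ratio: {float_size / quant_size:.1f}x")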