TFLite向けのJaxモデル変換

概要

注意: これは新しい API であり、pip install tf-nightly 経由でのみ使用できます。また、TensorFlow バージョン 2.7 で提供される予定です。この API は実験段階であり、変更される可能性があります。

この CodeLab では、Jax を使用して MNIST 認識のモデルを構築する方法と、それを TensorFlow Lite に変換する方法について説明します。また、トレーニング後の量子化によって、Jaxに変換されたモデルを最適化する方法についても説明します。

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

前提条件

最新の TensorFlow nightly pip ビルドでこの機能を試すことをお勧めします。

pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade

データの準備

MNIST データ、Keras データセット、プリプロセスをダウンロードします。

import numpy as np
import tensorflow as tf
import functools

import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)

train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

Jax で MNIST モデルを構築する

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

rng = random.PRNGKey(0)

モデルのトレーニングと評価

step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9


num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  epoch_time = time.time() - start_time

  params = get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

TFLite モデルに変換する

次の手順を実行します。

  1. Jax predict 関数へのパラメーターを functools.partial でインライン化します。
  2. jnp.zeros を作成します。これは Jax でモデルを追跡するために使用される「プレースホルダー」テンソルです。
  3. experimental_from_jax を呼び出します。
  • serving_func がリストでラップされます。
  • 入力は特定の名前に関連付けられ、リストでラップされた配列として渡されます。
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)

変換された TFLite モデルを確認する

変換されたモデルの結果を Jax モデルと比較します。

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert if the result of TFLite model is consistent with the JAX model.
np.testing.assert_almost_equal(expected, result, 1e-5)

モデルを最適化する

モデルを最適化するために、representative_dataset を提供してトレーニング後の量子化を実行します。

def representative_dataset():
  for i in range(1000):
    x = train_images[i:i+1]
    yield [x]

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])
tflite_model = converter.convert()
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
with open('jax_mnist_quant.tflite', 'wb') as f:
  f.write(tflite_quant_model)

最適化されたモデルを評価する

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert if the result of TFLite model is consistent with the Jax model.
np.testing.assert_almost_equal(expected, result, 1e-5)

量子化されたモデルサイズを比較する

量子化されたモデルのサイズは元のモデルの 4 分の 1 のサイズになることがわかります。

du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite