Jax Model Conversion for TFLite

Overview

Note: This API is new and only available via pip install tf-nightly. It will be available in TensorFlow version 2.7. Also, the API is still experimental and subject to changes.

This CodeLab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. It also shows how to optimize the Jax-converted TFLite model with post-training quantization.


Prerequisites

We recommend trying this feature with the latest TensorFlow nightly pip build.

pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade

Data Preparation

Download the MNIST data with the Keras datasets API and preprocess it.

import numpy as np
import tensorflow as tf
import functools

import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)

train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
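The `_one_hot` helper relies on NumPy broadcasting: comparing the labels reshaped to a column of shape `(n, 1)` against `np.arange(k)` of shape `(k,)` produces an `(n, k)` boolean matrix with exactly one `True` per row, which is then cast to `float32`. A minimal standalone check, using a few made-up labels rather than MNIST data:

```python
import numpy as np

def one_hot(x, k, dtype=np.float32):
    # x[:, None] has shape (n, 1) and np.arange(k) has shape (k,);
    # broadcasting the comparison yields an (n, k) boolean matrix.
    return np.array(x[:, None] == np.arange(k), dtype)

labels = np.array([3, 0, 9])  # hypothetical example labels
encoded = one_hot(labels, 10)
print(encoded.shape)           # (3, 10)
print(encoded.argmax(axis=1))  # [3 0 9]
```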

Build the MNIST Model with Jax

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

rng = random.PRNGKey(0)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
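To make explicit what `predict(params, inputs)` computes, here is a NumPy re-implementation of the same layer stack. It assumes `params` follows stax's layout (a list with a `(W, b)` tuple for each `Dense` layer and an empty tuple for `Flatten`, `Relu` and `LogSoftmax`), and random weights stand in for trained ones. Because the network ends in `LogSoftmax`, the `loss` function above is the mean negative log-likelihood (cross-entropy) of the one-hot targets.

```python
import numpy as np

def log_softmax(z):
    # Numerically stable log-softmax over the class axis.
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def predict_np(params, inputs):
    # Keep only the Dense (W, b) pairs; the other layers have empty params.
    (w1, b1), (w2, b2), (w3, b3) = [p for p in params if p]
    x = inputs.reshape(inputs.shape[0], -1)  # Flatten: (n, 28, 28) -> (n, 784)
    x = np.maximum(x @ w1 + b1, 0.0)         # Dense(1024) + Relu
    x = np.maximum(x @ w2 + b2, 0.0)         # Dense(1024) + Relu
    return log_softmax(x @ w3 + b3)          # Dense(10) + LogSoftmax

def loss_np(params, inputs, targets):
    # Same formula as `loss`: mean negative log-probability of the true class.
    return -np.mean(np.sum(predict_np(params, inputs) * targets, axis=1))

# Random stand-in weights laid out like stax's params list (assumption).
rng = np.random.default_rng(0)

def dense(n_in, n_out):
    return (rng.normal(0.0, 0.01, (n_in, n_out)), np.zeros(n_out))

params = [(), dense(784, 1024), (), dense(1024, 1024), (), dense(1024, 10), ()]

images = rng.random((4, 28, 28))
targets = np.eye(10)[[5, 0, 4, 1]]
logp = predict_np(params, images)
print(logp.shape)                                  # (4, 10)
print(np.allclose(np.exp(logp).sum(axis=1), 1.0))  # True
print(loss_np(params, images, targets))
```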

Train and Evaluate the Model

step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9


num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  epoch_time = time.time() - start_time

  params = get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Starting training...
Epoch 0 in 2.79 sec
Training set accuracy 0.8728833198547363
Test set accuracy 0.880299985408783
Epoch 1 in 2.31 sec
Training set accuracy 0.8983833193778992
Test set accuracy 0.9047999978065491
Epoch 2 in 2.32 sec
Training set accuracy 0.9102333188056946
Test set accuracy 0.9138000011444092
Epoch 3 in 2.37 sec
Training set accuracy 0.9172333478927612
Test set accuracy 0.9218999743461609
Epoch 4 in 2.31 sec
Training set accuracy 0.9224833250045776
Test set accuracy 0.9253999590873718
Epoch 5 in 2.31 sec
Training set accuracy 0.9272000193595886
Test set accuracy 0.9309999942779541
Epoch 6 in 2.33 sec
Training set accuracy 0.9328166842460632
Test set accuracy 0.9334999918937683
Epoch 7 in 2.31 sec
Training set accuracy 0.9360166788101196
Test set accuracy 0.9370999932289124
Epoch 8 in 2.32 sec
Training set accuracy 0.939050018787384
Test set accuracy 0.939300000667572
Epoch 9 in 2.31 sec
Training set accuracy 0.9425666928291321
Test set accuracy 0.9429000020027161
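Two details of the training loop, sketched standalone in NumPy. First, with 60,000 training images and `batch_size = 128`, `divmod` yields 468 full batches plus 96 leftover images, so `num_batches` is 469 and the last batch of each pass is smaller; `data_stream` reshuffles at the start of every pass. Second, `optimizers.momentum` in `jax.example_libraries` keeps a velocity per parameter and updates it as `velocity = mass * velocity + g`, then `x = x - step_size * velocity`. The toy quadratic objective and the step size of 0.1 below are illustrative assumptions, not values from the tutorial.

```python
import numpy as np

# Batching: same arithmetic as num_complete_batches / num_batches above.
print(divmod(60000, 128))  # (468, 96)

def batch_indices(num_train, batch_size, seed=0):
    # Reshuffle on every pass; the final batch of a pass may be smaller.
    num_complete, leftover = divmod(num_train, batch_size)
    num_batches = num_complete + bool(leftover)
    rng = np.random.RandomState(seed)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            yield perm[i * batch_size:(i + 1) * batch_size]

stream = batch_indices(10, 4)
first_pass = [next(stream) for _ in range(3)]
print([len(b) for b in first_pass])         # [4, 4, 2]
print(np.sort(np.concatenate(first_pass)))  # [0 1 2 3 4 5 6 7 8 9]

# Momentum: accumulate gradients into a velocity, then step along it.
def momentum_step(x, velocity, g, step_size, mass=0.9):
    velocity = mass * velocity + g
    return x - step_size * velocity, velocity

# Toy objective f(x) = 0.5 * ||x||^2, whose gradient is x itself.
x = np.array([1.0, -2.0])
velocity = np.zeros_like(x)
for _ in range(200):
    x, velocity = momentum_step(x, velocity, g=x, step_size=0.1)
print(np.linalg.norm(x) < 1e-2)  # True
```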

Convert to a TFLite Model

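Note that here we:

  1. Inline the params into the Jax predict function with functools.partial.
  2. Build a jnp.zeros array. This is a "placeholder" tensor that Jax uses to trace the model.
  3. Call experimental_from_jax:
  • The serving_func is wrapped in a list.
  • The input is associated with a given name and passed in as an array wrapped in a list.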
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)
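`functools.partial(predict, params)` returns a callable that takes only the input, with the trained weights bound in, which is the single-argument form passed to `experimental_from_jax`. The same idea with a hypothetical `toy_predict` (one linear layer, not the tutorial's model) in plain NumPy:

```python
import functools
import numpy as np

def toy_predict(params, inputs):
    # Hypothetical stand-in for stax's `predict`: a single linear layer.
    w, b = params
    return inputs.reshape(inputs.shape[0], -1) @ w + b

params = (np.ones((4, 3)), np.zeros(3))
serving_func = functools.partial(toy_predict, params)  # params are now baked in

# Shape-only "placeholder" input, analogous to jnp.zeros((1, 28, 28)).
x = np.zeros((1, 2, 2))
print(serving_func(x).shape)  # (1, 3)
```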

Check the Converted TFLite Model

Compare the converted model's results with those of the Jax model.

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

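# Assert that the result of the TFLite model is consistent with the JAX model.
# assert_almost_equal's third positional argument is `decimal`, not a tolerance,
# so state the intended 1e-5 tolerance explicitly.
np.testing.assert_allclose(expected, result, rtol=1e-5, atol=1e-5)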
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
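The third positional parameter of `np.testing.assert_almost_equal` is `decimal` (a number of decimal places), not an absolute tolerance. The check passes when `abs(desired - actual) < 1.5 * 10**(-decimal)`, so passing `1e-5` in that position accepts differences of almost 1.5. A standalone comparison with made-up arrays:

```python
import numpy as np

expected = np.array([[-0.1, -2.3]])
result = expected + 0.5  # deliberately far off

# decimal=1e-5 gives a threshold of 1.5 * 10**(-1e-5) ≈ 1.49997, so this passes.
np.testing.assert_almost_equal(expected, result, 1e-5)

# An explicit absolute tolerance catches the mismatch.
try:
    np.testing.assert_allclose(expected, result, rtol=0, atol=1e-5)
    caught = False
except AssertionError:
    caught = True
print(caught)  # True

# decimal=5 is the "five decimal places" check (threshold 1.5e-5).
np.testing.assert_almost_equal(expected, expected + 1e-6, decimal=5)
```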

Optimize the Model

We will provide a representative_dataset so that the converter can apply post-training quantization to optimize the model.

def representative_dataset():
  for i in range(1000):
    x = train_images[i:i+1]
    yield [x]

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
with open('jax_mnist_quant.tflite', 'wb') as f:
  f.write(tflite_quant_model)
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
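Post-training integer quantization maps each float value r to an 8-bit integer q through an affine relation r ≈ scale * (q - zero_point); the representative dataset is what lets the converter estimate value ranges for the activations. The NumPy sketch below illustrates the idea with simple min/max calibration over random stand-in data. It is a simplified assumption for illustration, not the converter's exact calibration algorithm.

```python
import numpy as np

def calibrate(samples, qmin=-128, qmax=127):
    # Asymmetric int8 parameters from the observed range (widened to include 0).
    lo = min(float(samples.min()), 0.0)
    hi = max(float(samples.max()), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    zero_point = int(round(qmin - lo / scale))
    return scale, zero_point

def quantize(r, scale, zero_point, qmin=-128, qmax=127):
    q = np.round(r / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.int8)

def dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float32) - zero_point)

rng = np.random.default_rng(0)
# Stand-in for representative_dataset: values in [0, 1) like the scaled MNIST pixels.
calibration = rng.random((1000, 28, 28)).astype(np.float32)
scale, zp = calibrate(calibration)

x = calibration[:1]
x_hat = dequantize(quantize(x, scale, zp), scale, zp)
print(scale, zp)
# Rounding error is bounded by half a quantization step.
print(np.abs(x - x_hat).max() <= scale / 2 + 1e-6)  # True
```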

Evaluate the Optimized Model

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

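# Quantization introduces approximation error, so the outputs are not expected
# to match the Jax model to 1e-5. Check that the predicted class agrees instead.
np.testing.assert_array_equal(np.argmax(expected, axis=1), np.argmax(result, axis=1))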

Compare the Quantized Model Size

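You should see that the quantized model is about four times smaller than the original model, because float32 weights (4 bytes each) are stored as int8 values (1 byte each).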

du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M    jax_mnist.tflite
1.8M    jax_mnist_quant.tflite
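The roughly 4x ratio can be checked from the parameter count of the three `Dense` layers. The sketch assumes float32 storage for the original model and, for the quantized one, int8 weights with 32-bit integer biases; the small gap to the `du` figures would come from the model's graph structure and metadata and from `du -h` rounding up.

```python
# Dense layer shapes of stax.serial(Flatten, Dense(1024), Relu,
#                                   Dense(1024), Relu, Dense(10), LogSoftmax).
layers = [(28 * 28, 1024), (1024, 1024), (1024, 10)]

weights = sum(n_in * n_out for n_in, n_out in layers)
biases = sum(n_out for _, n_out in layers)
print(weights, biases)  # 1861632 2058

float_bytes = 4 * (weights + biases)   # float32 weights and biases
int8_bytes = 1 * weights + 4 * biases  # int8 weights, int32 biases (assumed)
MiB = 1024 ** 2
print(f"{float_bytes / MiB:.2f} MiB vs {int8_bytes / MiB:.2f} MiB")  # 7.11 MiB vs 1.78 MiB
print(f"ratio {float_bytes / int8_bytes:.2f}")                      # ratio 3.99
```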