Jax model conversion for TFLite

Overview

This CodeLab demonstrates how to build a model for MNIST recognition using Jax and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the Jax-converted TFLite model with post-training quantization.

Prerequisites

It is recommended to try this feature with the latest TensorFlow nightly pip build.

pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade

Data preparation

Download the MNIST data with the Keras dataset API and pre-process it.

import numpy as np
import tensorflow as tf
import functools

import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
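# Newer JAX releases ship optimizers and stax under jax.example_libraries
# (they were previously in jax.experimental).
from jax.example_libraries import optimizers
from jax.example_libraries import stax
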
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)

train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
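
The `_one_hot` helper above relies on NumPy broadcasting: a column of labels is compared against a row of class indices, which yields one boolean row per label. The following is a minimal sketch of that comparison on a toy input; the `labels` values and `k = 4` are made up for illustration and are not part of the tutorial's data.

```python
import numpy as np

# Hypothetical toy labels; the tutorial itself uses the MNIST labels with k = 10.
labels = np.array([3, 0, 2])
k = 4

# labels[:, None] has shape (3, 1) and np.arange(k) has shape (4,),
# so the comparison broadcasts to a (3, 4) boolean matrix.
one_hot = np.array(labels[:, None] == np.arange(k), dtype=np.float32)

print(one_hot.shape)
print(one_hot)
```

Each row contains a single 1.0 at the position of its label, which is the format the loss function later multiplies against the model's log-probabilities.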

Build the MNIST model with Jax

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

rng = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
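
Because the network ends in `stax.LogSoftmax` and the targets are one-hot, the `loss` function above amounts to the mean negative log-likelihood of the true class. The sketch below checks that equivalence with plain NumPy; the `log_softmax` helper and the sample `logits` are hypothetical stand-ins, not the tutorial's model.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis (illustrative helper)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# Hypothetical logits for two examples and three classes.
logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

log_probs = log_softmax(logits)

# Same expression as the tutorial's loss: sum(preds * targets) keeps only
# the log-probability of the true class in each row.
nll = -np.mean(np.sum(log_probs * targets, axis=1))

# Equivalent formulation: index the true class directly.
true_class = targets.argmax(axis=1)
picked = -np.mean(log_probs[np.arange(len(true_class)), true_class])

print(nll, picked)
```

Both expressions give the same value, which is why the tutorial can compute the loss with a single multiply-and-sum instead of explicit indexing.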

Train and evaluate the model

step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9


num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  epoch_time = time.time() - start_time

  params = get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Starting training...
Epoch 0 in 4.69 sec
Training set accuracy 0.8729000091552734
Test set accuracy 0.880299985408783
Epoch 1 in 3.83 sec
Training set accuracy 0.8983666896820068
Test set accuracy 0.9047999978065491
Epoch 2 in 3.81 sec
Training set accuracy 0.9102166891098022
Test set accuracy 0.9138000011444092
Epoch 3 in 3.85 sec
Training set accuracy 0.9172500371932983
Test set accuracy 0.9218999743461609
Epoch 4 in 3.79 sec
Training set accuracy 0.9224500060081482
Test set accuracy 0.9253999590873718
Epoch 5 in 3.72 sec
Training set accuracy 0.9272000193595886
Test set accuracy 0.930899977684021
Epoch 6 in 3.77 sec
Training set accuracy 0.9327666759490967
Test set accuracy 0.9334999918937683
Epoch 7 in 3.77 sec
Training set accuracy 0.9360166788101196
Test set accuracy 0.9370999932289124
Epoch 8 in 3.77 sec
Training set accuracy 0.9390000104904175
Test set accuracy 0.939300000667572
Epoch 9 in 3.73 sec
Training set accuracy 0.9425666928291321
Test set accuracy 0.9429999589920044
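
The batch bookkeeping in the training code can be checked by hand. The sketch below repeats the `divmod` arithmetic, assuming the Keras MNIST training split holds 60,000 images, and uses a hypothetical `epoch_batches` generator (standard library only) to show how one shuffled pass ends with a smaller final batch.

```python
import random

# Assumption: the Keras MNIST training split holds 60,000 images.
num_train = 60000
batch_size = 128

num_complete_batches, leftover = divmod(num_train, batch_size)
# bool(leftover) adds one extra batch when the last batch is partial.
num_batches = num_complete_batches + bool(leftover)
print(num_complete_batches, leftover, num_batches)

def epoch_batches(num_examples, batch_size, seed=0):
    """Yield shuffled index batches that cover every example exactly once."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    for start in range(0, num_examples, batch_size):
        yield indices[start:start + batch_size]

# Toy pass: 10 examples in batches of 4 gives two full batches and one partial.
toy_batches = list(epoch_batches(num_examples=10, batch_size=4))
toy_sizes = [len(batch) for batch in toy_batches]
print(toy_sizes)
```

The tutorial's `data_stream` follows the same pattern, except that it loops forever and draws a fresh permutation at the start of each pass.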

Convert to a TFLite model

Note that here we:

  1. Inline the params into the Jax predict function with functools.partial.
  2. Build a jnp.zeros array, a "placeholder" tensor that Jax uses to trace the model.
  3. Call experimental_from_jax:
     * The serving_func is wrapped in a list.
     * The input is associated with a given name and passed in as an array wrapped in a list.

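
The first step above can be sketched in isolation. `functools.partial` binds the trained parameters as the first argument, leaving a function of the inputs only. The toy `predict` below is a hypothetical linear function used for illustration, not the stax model.

```python
import functools

def predict(params, inputs):
    """Toy stand-in for the model: applies weight * x + bias to each input."""
    weight, bias = params
    return [weight * x + bias for x in inputs]

# Hypothetical "trained" parameters.
params = (2.0, 1.0)

# Bind params so the resulting callable only takes the inputs.
serving_func = functools.partial(predict, params)

print(serving_func([1.0, 2.0]))
```

The converter then only has to trace a single-argument function, with the parameters captured as constants.
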
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)
2021-10-30 11:51:13.208329: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-10-30 11:51:13.208375: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-10-30 11:51:13.208383: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.

Check the converted TFLite model

Compare the converted model's results with the Jax model.

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

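# Check that the TFLite output matches the JAX output within a 1e-5 tolerance.
# (Passing 1e-5 positionally to assert_almost_equal would set `decimal`, not a tolerance.)
np.testing.assert_allclose(expected, result, rtol=1e-5, atol=1e-5)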
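
When comparing model outputs it helps to know how NumPy's testing helpers interpret their tolerance arguments. The third positional argument of `np.testing.assert_almost_equal` is `decimal` (a number of decimal places), not an absolute tolerance, whereas `np.testing.assert_allclose` takes explicit `rtol` and `atol`. The sketch below contrasts the two on made-up arrays.

```python
import numpy as np

# Hypothetical arrays that differ by 1.0 in every element.
a = np.array([0.0, 1.0])
b = np.array([1.0, 2.0])

# With decimal=1e-5 the allowed difference is 1.5 * 10**(-1e-5), about 1.5,
# so this check passes even though the arrays are far apart.
threshold = 1.5 * 10.0 ** (-1e-5)
np.testing.assert_almost_equal(a, b, 1e-5)

# An explicit absolute tolerance of 1e-5 rejects the same pair.
try:
    np.testing.assert_allclose(a, b, rtol=0, atol=1e-5)
    strict_check_passed = True
except AssertionError:
    strict_check_passed = False

print(threshold, strict_check_passed)
```

Passing the tolerance by keyword (`decimal=5`, or `atol=1e-5` with `assert_allclose`) makes the intended strictness explicit.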

Optimize the model

We will provide a representative_dataset to perform post-training quantization and optimize the model.

def representative_dataset():
  for i in range(1000):
    x = train_images[i:i+1]
    yield [x]

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])
tflite_model = converter.convert()
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
with open('jax_mnist_quant.tflite', 'wb') as f:
  f.write(tflite_quant_model)
2021-10-30 11:51:14.202412: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-10-30 11:51:14.202455: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-10-30 11:51:14.202461: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
2021-10-30 11:51:14.293677: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-10-30 11:51:14.293768: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2021-10-30 11:51:14.293776: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
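
To give some intuition for why a representative dataset is needed: the converter has to observe typical value ranges in order to map floats onto 8-bit integers. The following is a simplified sketch of affine int8 quantization in pure Python. The helper names are hypothetical, and TFLite's actual calibration is more involved than this min/max scheme.

```python
def affine_params(lo, hi, qmin=-128, qmax=127):
    """Derive a scale and zero point from an observed float range (simplified)."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # the range must contain zero
    scale = (hi - lo) / (qmax - qmin)    # assumes hi > lo
    zero_point = int(round(qmin - lo / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a float to an int8 code, clamping to the representable range."""
    q = int(round(x / scale)) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q, scale, zero_point):
    """Map an int8 code back to an approximate float."""
    return scale * (q - zero_point)

# Assumption: inputs lie in [0, 1], as the MNIST pixels do after dividing by 255.
scale, zero_point = affine_params(0.0, 1.0)

samples = [0.0, 0.25, 0.5, 1.0]
errors = [abs(x - dequantize(quantize(x, scale, zero_point), scale, zero_point))
          for x in samples]
print(scale, zero_point, max(errors))
```

For values inside the calibrated range, the round-trip error stays within about half a quantization step, which is why a dataset that covers realistic inputs matters.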

Evaluate the optimized model

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

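# Check that the quantized TFLite output stays close to the Jax output.
# The third positional argument is `decimal`, so 1e-5 here allows differences
# of up to about 1.5, a loose bound that tolerates quantization error.
np.testing.assert_almost_equal(expected, result, 1e-5)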
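
Since a quantized model is not expected to reproduce the float outputs exactly, a complementary check is whether both models pick the same class. The sketch below uses a hypothetical `top1_agreement` helper and made-up log-probability arrays; in practice the two arrays would come from `serving_func` and the TFLite interpreter.

```python
import numpy as np

def top1_agreement(reference, candidate):
    """Fraction of rows where both (N, classes) arrays have the same argmax."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    return float(np.mean(reference.argmax(axis=1) == candidate.argmax(axis=1)))

# Hypothetical outputs for two examples and three classes.
float_out = np.array([[-0.1, -2.5, -4.0], [-3.0, -0.2, -2.0]])
quant_out = np.array([[-0.12, -2.4, -3.9], [-2.9, -0.25, -2.1]])

agreement = top1_agreement(float_out, quant_out)
max_abs_err = float(np.max(np.abs(float_out - quant_out)))
print(agreement, max_abs_err)
```

Reporting both numbers separates harmless numeric drift from changes that would actually alter predictions.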

Compare the quantized model's size

The quantized model should be about four times smaller than the original model (7.2M versus 1.8M in the output below).

du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M    jax_mnist.tflite
1.8M    jax_mnist_quant.tflite
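
The roughly fourfold reduction is consistent with a back-of-the-envelope parameter count. The sketch below takes the layer widths from the `stax.serial` definition and assumes 4 bytes per float32 weight and 1 byte per int8 weight; it ignores file-format overhead and the fact that some tensors, such as biases, may be stored at a wider precision.

```python
# Layer widths from the stax.serial definition: 28*28 inputs, two hidden
# layers of 1024 units, and 10 output classes.
layer_sizes = [28 * 28, 1024, 1024, 10]

# Each Dense layer holds a weight matrix plus one bias per output unit.
num_params = sum(
    fan_in * fan_out + fan_out
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:])
)

BYTES_FLOAT32 = 4  # assumption: unquantized weights are stored as float32
BYTES_INT8 = 1     # assumption: quantized weights are stored as int8

float_mib = num_params * BYTES_FLOAT32 / 2**20
int8_mib = num_params * BYTES_INT8 / 2**20

print(f"parameters: {num_params}")
print(f"float32 estimate: {float_mib:.2f} MiB")
print(f"int8 estimate: {int8_mib:.2f} MiB")
```

The estimates land close to the sizes reported by `du -h`, with the remaining difference attributable to rounding and model metadata.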