Genel Bakış
Bu CodeLab, Jax kullanılarak MNIST tanıma için bir modelin nasıl oluşturulacağını ve bunun TensorFlow Lite'a nasıl dönüştürüleceğini gösterir. Bu kod laboratuvarı ayrıca, Jax'ten dönüştürülmüş TFLite modelinin eğitim sonrası niceleme (post-training quantization) ile nasıl optimize edileceğini de gösterir.
![]() | ![]() | ![]() | ![]() |
Önkoşullar
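Bu özelliği en güncel TensorFlow nightly (gecelik) pip paketiyle denemeniz önerilir. Not: Aşağıdaki kod `jax.experimental.stax` ve `jax.experimental.optimizers` modüllerini kullanır; daha yeni JAX sürümlerinde bu modüller `jax.example_libraries` altına taşınmıştır, bu nedenle JAX'i yükselttikten sonra içe aktarma satırlarını buna göre güncellemeniz gerekebilir.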
pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
Veri Hazırlama
MNIST verilerini Keras veri kümesi API'siyle indirin ve ön işleyin.
import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
# Load the MNIST train/test splits through the Keras datasets API.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Divide pixel values by 255 (presumably mapping 0..255 onto [0, 1]) and
# store the images as float32.
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

# Replace the integer labels with 10-class one-hot rows.
train_labels, test_labels = _one_hot(train_labels, 10), _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step 11501568/11490434 [==============================] - 0s 0us/step
Jax ile MNIST modelini oluşturun
def loss(params, batch):
    """Negative mean log-likelihood of the one-hot targets under the model.

    `predict` ends in a LogSoftmax layer, so multiplying by the one-hot
    targets and summing over the class axis selects each example's
    log-probability for its target class.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
def accuracy(params, batch):
    """Fraction of examples whose highest-scoring class equals the target class.

    `batch` is an (inputs, one-hot targets) pair.
    """
    inputs, targets = batch
    predicted = jnp.argmax(predict(params, inputs), axis=1)
    actual = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted == actual)
# Multilayer perceptron: flatten the input, two 1024-unit ReLU hidden layers,
# then a 10-way dense layer followed by log-softmax.
# stax.serial composes the layers into an (init_fn, apply_fn) pair.
init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax,
)

# Fixed seed so that parameter initialization is reproducible.
rng = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Modeli Eğitin ve Değerlendirin
# Optimizer and training hyper-parameters.
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9

# Number of minibatches per epoch, counting a final partial batch when the
# training set size is not a multiple of batch_size.
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)
def data_stream():
    """Yield shuffled (images, labels) minibatches forever.

    Each pass draws a new permutation of the training indices from a
    RandomState seeded with 0 and slices it into `num_batches` chunks of at
    most `batch_size` indices (the last chunk may be shorter).
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for batch_no in range(num_batches):
            lo = batch_no * batch_size
            idx = order[lo:lo + batch_size]
            yield train_images[idx], train_labels[idx]
# Infinite iterator of shuffled training minibatches.
batches = data_stream()

# SGD with momentum; yields the (init, update, get_params) triple.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)


@jit
def update(i, opt_state, batch):
    """Apply one optimizer step for `batch` at iteration `i`; return the new state."""
    current = get_params(opt_state)
    grads = grad(loss)(current, batch)
    return opt_update(i, grads, opt_state)
# Initialize parameters for inputs with 28 * 28 features; the leading -1
# leaves the batch dimension unspecified. Wrap them in optimizer state.
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
# Global step counter passed to every update call.
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    # One epoch consumes num_batches minibatches from the infinite stream.
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    # Evaluate the current parameters on the full training and test sets.
    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))

    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Starting training... Epoch 0 in 4.69 sec Training set accuracy 0.8729000091552734 Test set accuracy 0.880299985408783 Epoch 1 in 3.83 sec Training set accuracy 0.8983666896820068 Test set accuracy 0.9047999978065491 Epoch 2 in 3.81 sec Training set accuracy 0.9102166891098022 Test set accuracy 0.9138000011444092 Epoch 3 in 3.85 sec Training set accuracy 0.9172500371932983 Test set accuracy 0.9218999743461609 Epoch 4 in 3.79 sec Training set accuracy 0.9224500060081482 Test set accuracy 0.9253999590873718 Epoch 5 in 3.72 sec Training set accuracy 0.9272000193595886 Test set accuracy 0.930899977684021 Epoch 6 in 3.77 sec Training set accuracy 0.9327666759490967 Test set accuracy 0.9334999918937683 Epoch 7 in 3.77 sec Training set accuracy 0.9360166788101196 Test set accuracy 0.9370999932289124 Epoch 8 in 3.77 sec Training set accuracy 0.9390000104904175 Test set accuracy 0.939300000667572 Epoch 9 in 3.73 sec Training set accuracy 0.9425666928291321 Test set accuracy 0.9429999589920044
TFLite modeline dönüştürün.
Burada şunlara dikkat edin:
- Jax `predict` fonksiyonunun parametrelerini (`params`) `functools.partial` ile fonksiyonun içine gömüyoruz.
- `jnp.zeros` ile bir tensör oluşturuyoruz; bu, Jax'in modeli izlemek (trace) için kullandığı bir "yer tutucu"dur (placeholder).
- `experimental_from_jax` çağrılırken:
  - `serving_func` bir listeye sarılır.
  - Girdi belirli bir adla ilişkilendirilir ve bir listeye sarılmış bir dizi olarak iletilir.
# Bake the trained parameters into a function of the input alone.
serving_func = functools.partial(predict, params)
# Placeholder input whose shape Jax uses when tracing the model.
x_input = jnp.zeros((1, 28, 28))

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func],               # functions to convert, wrapped in a list
    [[('input1', x_input)]],      # per function: list of (name, example input)
)
tflite_model = converter.convert()

# Persist the float model so its size can be compared later.
with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
2021-10-30 11:51:13.208329: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-10-30 11:51:13.208375: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2021-10-30 11:51:13.208383: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
Dönüştürülen TFLite Modelini Kontrol Edin
Dönüştürülen modelin sonuçlarını Jax modeliyle karşılaştırın.
# Reference output of the Jax model for the first training image.
expected = serving_func(train_images[0:1])

# Run the same image through the converted TensorFlow Lite model.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert that the TFLite result is consistent with the Jax model.
# Note: the third positional parameter of np.testing.assert_almost_equal is
# `decimal` (a digit count), not a tolerance; passing 1e-5 there allows an
# absolute error of ~1.5 and checks almost nothing. Use explicit tolerances
# that leave room only for float32 summation-order differences.
np.testing.assert_allclose(expected, result, rtol=1e-4, atol=1e-4)
Modeli Optimize Edin
Modeli optimize etmek için eğitim sonrası niceleme (post-training quantization) yapacağız; bunun için bir `representative_dataset` sağlayacağız.
def representative_dataset():
    """Yield calibration samples for post-training quantization.

    Produces the first 1000 training images one at a time, each as a
    single-element list holding a batch of one image.
    """
    for i in range(1000):
        yield [train_images[i:i + 1]]


# Use the same input name as the float conversion for consistency.
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
# Configure quantization *before* the only convert() call: converting earlier
# would just build (and discard) another unquantized model and overwrite the
# float `tflite_model` produced above.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()

with open('jax_mnist_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
2021-10-30 11:51:14.202412: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-10-30 11:51:14.202455: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2021-10-30 11:51:14.202461: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. 2021-10-30 11:51:14.293677: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-10-30 11:51:14.293768: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2021-10-30 11:51:14.293776: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
Optimize Edilmiş Modeli Değerlendirin
# Reference output of the (unquantized) Jax model for the first training image.
expected = serving_func(train_images[0:1])

# Run the same image through the quantized TensorFlow Lite model.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Quantization perturbs the outputs, so a tight elementwise match is not
# expected. Require that both models predict the same class, and bound the
# elementwise gap with an explicit absolute tolerance.
# Note: np.testing.assert_almost_equal's third positional parameter is
# `decimal`, not a tolerance, so it must not be given a value like 1e-5.
np.testing.assert_array_equal(np.argmax(expected, axis=-1), np.argmax(result, axis=-1))
np.testing.assert_allclose(expected, result, rtol=0, atol=1.5)
Nicelleştirilmiş Modelin Boyutunu Karşılaştırın
Nicelleştirilmiş modelin orijinal modelden yaklaşık dört kat daha küçük olduğunu görebilmeliyiz.
du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M jax_mnist.tflite 1.8M jax_mnist_quant.tflite
Genel Bakış
Bu CodeLab, Jax kullanılarak MNIST tanıma için bir modelin nasıl oluşturulacağını ve bunun TensorFlow Lite'a nasıl dönüştürüleceğini gösterir. Bu kod laboratuvarı ayrıca, Jax'ten dönüştürülmüş TFLite modelinin eğitim sonrası niceleme (post-training quantization) ile nasıl optimize edileceğini de gösterir.
![]() | ![]() | ![]() | ![]() |
Önkoşullar
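Bu özelliği en güncel TensorFlow nightly (gecelik) pip paketiyle denemeniz önerilir. Not: Aşağıdaki kod `jax.experimental.stax` ve `jax.experimental.optimizers` modüllerini kullanır; daha yeni JAX sürümlerinde bu modüller `jax.example_libraries` altına taşınmıştır, bu nedenle JAX'i yükselttikten sonra içe aktarma satırlarını buna göre güncellemeniz gerekebilir.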
pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
Veri Hazırlama
MNIST verilerini Keras veri kümesi API'siyle indirin ve ön işleyin.
import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
# Load the MNIST train/test splits through the Keras datasets API.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Divide pixel values by 255 (presumably mapping 0..255 onto [0, 1]) and
# store the images as float32.
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

# Replace the integer labels with 10-class one-hot rows.
train_labels, test_labels = _one_hot(train_labels, 10), _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step 11501568/11490434 [==============================] - 0s 0us/step
Jax ile MNIST modelini oluşturun
def loss(params, batch):
    """Negative mean log-likelihood of the one-hot targets under the model.

    `predict` ends in a LogSoftmax layer, so multiplying by the one-hot
    targets and summing over the class axis selects each example's
    log-probability for its target class.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
def accuracy(params, batch):
    """Fraction of examples whose highest-scoring class equals the target class.

    `batch` is an (inputs, one-hot targets) pair.
    """
    inputs, targets = batch
    predicted = jnp.argmax(predict(params, inputs), axis=1)
    actual = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted == actual)
# Multilayer perceptron: flatten the input, two 1024-unit ReLU hidden layers,
# then a 10-way dense layer followed by log-softmax.
# stax.serial composes the layers into an (init_fn, apply_fn) pair.
init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax,
)

# Fixed seed so that parameter initialization is reproducible.
rng = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Modeli Eğitin ve Değerlendirin
# Optimizer and training hyper-parameters.
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9

# Number of minibatches per epoch, counting a final partial batch when the
# training set size is not a multiple of batch_size.
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)
def data_stream():
    """Yield shuffled (images, labels) minibatches forever.

    Each pass draws a new permutation of the training indices from a
    RandomState seeded with 0 and slices it into `num_batches` chunks of at
    most `batch_size` indices (the last chunk may be shorter).
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for batch_no in range(num_batches):
            lo = batch_no * batch_size
            idx = order[lo:lo + batch_size]
            yield train_images[idx], train_labels[idx]
# Infinite iterator of shuffled training minibatches.
batches = data_stream()

# SGD with momentum; yields the (init, update, get_params) triple.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)


@jit
def update(i, opt_state, batch):
    """Apply one optimizer step for `batch` at iteration `i`; return the new state."""
    current = get_params(opt_state)
    grads = grad(loss)(current, batch)
    return opt_update(i, grads, opt_state)
# Initialize parameters for inputs with 28 * 28 features; the leading -1
# leaves the batch dimension unspecified. Wrap them in optimizer state.
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
# Global step counter passed to every update call.
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    # One epoch consumes num_batches minibatches from the infinite stream.
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    # Evaluate the current parameters on the full training and test sets.
    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))

    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Starting training... Epoch 0 in 4.69 sec Training set accuracy 0.8729000091552734 Test set accuracy 0.880299985408783 Epoch 1 in 3.83 sec Training set accuracy 0.8983666896820068 Test set accuracy 0.9047999978065491 Epoch 2 in 3.81 sec Training set accuracy 0.9102166891098022 Test set accuracy 0.9138000011444092 Epoch 3 in 3.85 sec Training set accuracy 0.9172500371932983 Test set accuracy 0.9218999743461609 Epoch 4 in 3.79 sec Training set accuracy 0.9224500060081482 Test set accuracy 0.9253999590873718 Epoch 5 in 3.72 sec Training set accuracy 0.9272000193595886 Test set accuracy 0.930899977684021 Epoch 6 in 3.77 sec Training set accuracy 0.9327666759490967 Test set accuracy 0.9334999918937683 Epoch 7 in 3.77 sec Training set accuracy 0.9360166788101196 Test set accuracy 0.9370999932289124 Epoch 8 in 3.77 sec Training set accuracy 0.9390000104904175 Test set accuracy 0.939300000667572 Epoch 9 in 3.73 sec Training set accuracy 0.9425666928291321 Test set accuracy 0.9429999589920044
TFLite modeline dönüştürün.
Burada şunlara dikkat edin:
- Jax `predict` fonksiyonunun parametrelerini (`params`) `functools.partial` ile fonksiyonun içine gömüyoruz.
- `jnp.zeros` ile bir tensör oluşturuyoruz; bu, Jax'in modeli izlemek (trace) için kullandığı bir "yer tutucu"dur (placeholder).
- `experimental_from_jax` çağrılırken:
  - `serving_func` bir listeye sarılır.
  - Girdi belirli bir adla ilişkilendirilir ve bir listeye sarılmış bir dizi olarak iletilir.
# Bake the trained parameters into a function of the input alone.
serving_func = functools.partial(predict, params)
# Placeholder input whose shape Jax uses when tracing the model.
x_input = jnp.zeros((1, 28, 28))

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func],               # functions to convert, wrapped in a list
    [[('input1', x_input)]],      # per function: list of (name, example input)
)
tflite_model = converter.convert()

# Persist the float model so its size can be compared later.
with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
2021-10-30 11:51:13.208329: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-10-30 11:51:13.208375: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2021-10-30 11:51:13.208383: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
Dönüştürülen TFLite Modelini Kontrol Edin
Dönüştürülen modelin sonuçlarını Jax modeliyle karşılaştırın.
# Reference output of the Jax model for the first training image.
expected = serving_func(train_images[0:1])

# Run the same image through the converted TensorFlow Lite model.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert that the TFLite result is consistent with the Jax model.
# Note: the third positional parameter of np.testing.assert_almost_equal is
# `decimal` (a digit count), not a tolerance; passing 1e-5 there allows an
# absolute error of ~1.5 and checks almost nothing. Use explicit tolerances
# that leave room only for float32 summation-order differences.
np.testing.assert_allclose(expected, result, rtol=1e-4, atol=1e-4)
Modeli Optimize Edin
Modeli optimize etmek için eğitim sonrası niceleme (post-training quantization) yapacağız; bunun için bir `representative_dataset` sağlayacağız.
def representative_dataset():
    """Yield calibration samples for post-training quantization.

    Produces the first 1000 training images one at a time, each as a
    single-element list holding a batch of one image.
    """
    for i in range(1000):
        yield [train_images[i:i + 1]]


# Use the same input name as the float conversion for consistency.
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
# Configure quantization *before* the only convert() call: converting earlier
# would just build (and discard) another unquantized model and overwrite the
# float `tflite_model` produced above.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()

with open('jax_mnist_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
2021-10-30 11:51:14.202412: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-10-30 11:51:14.202455: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2021-10-30 11:51:14.202461: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. 2021-10-30 11:51:14.293677: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-10-30 11:51:14.293768: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2021-10-30 11:51:14.293776: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
Optimize Edilmiş Modeli Değerlendirin
# Reference output of the (unquantized) Jax model for the first training image.
expected = serving_func(train_images[0:1])

# Run the same image through the quantized TensorFlow Lite model.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Quantization perturbs the outputs, so a tight elementwise match is not
# expected. Require that both models predict the same class, and bound the
# elementwise gap with an explicit absolute tolerance.
# Note: np.testing.assert_almost_equal's third positional parameter is
# `decimal`, not a tolerance, so it must not be given a value like 1e-5.
np.testing.assert_array_equal(np.argmax(expected, axis=-1), np.argmax(result, axis=-1))
np.testing.assert_allclose(expected, result, rtol=0, atol=1.5)
Nicelleştirilmiş Modelin Boyutunu Karşılaştırın
Nicelleştirilmiş modelin orijinal modelden yaklaşık dört kat daha küçük olduğunu görebilmeliyiz.
du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M jax_mnist.tflite 1.8M jax_mnist_quant.tflite