דף זה מספק נתיב למשתמשים שרוצים לאמן מודלים ב-JAX ולפרוס לנייד לצורך הסקת מסקנות ( למשל colab ).
השיטות במדריך זה מייצרות tflite_model
שניתן להשתמש בו ישירות עם דוגמה של קוד המתורגמן TFLite או לשמור בקובץ TFLite FlatBuffer.
תְנַאִי מוּקדָם
מומלץ לנסות תכונה זו עם החבילה החדשה ביותר של TensorFlow לילי Python.
pip install tf-nightly --upgrade
נשתמש בספריית Orbax Export כדי לייצא דגמי JAX. ודא שגרסת ה-JAX שלך היא לפחות 0.4.20 ומעלה.
pip install jax --upgrade
pip install orbax-export --upgrade
המרת דגמי JAX ל- TensorFlow Lite
אנו משתמשים ב- TensorFlow SavedModel כפורמט הביניים בין JAX ל- TensorFlow Lite. ברגע שיש לך SavedModel, ניתן להשתמש בממשקי API קיימים של TensorFlow Lite להשלמת תהליך ההמרה.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
בדוק את דגם ה-TFLite שהומר
לאחר המרת המודל ל-TFLite, אתה יכול להפעיל ממשקי API של מתורגמנים של TFLite כדי לבדוק פלטי מודל.
# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])