توفر هذه الصفحة مسارًا للمستخدمين الذين يرغبون في تدريب النماذج في JAX ونشرها على الهاتف المحمول للاستدلال ( مثال colab ).
تنتج الطرق الموجودة في هذا الدليل نموذج tflite_model
الذي يمكن استخدامه مباشرة مع مثال رمز مترجم TFLite أو حفظه في ملف TFLite FlatBuffer.
المتطلبات المسبقة
يوصى بتجربة هذه الميزة مع أحدث حزمة TensorFlow الليلية لـ Python.
pip install tf-nightly --upgrade
سوف نستخدم مكتبة Orbax Export لتصدير نماذج JAX. تأكد من أن إصدار JAX الخاص بك هو 0.4.20 على الأقل أو أعلى.
pip install jax --upgrade
pip install orbax-export --upgrade
تحويل نماذج JAX إلى TensorFlow Lite
نستخدم TensorFlow SavedModel كتنسيق وسيط بين JAX وTensorFlow Lite. بمجرد حصولك على SavedModel، يمكن استخدام واجهات برمجة تطبيقات TensorFlow Lite الموجودة لإكمال عملية التحويل.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
تحقق من طراز TFLite المحول
بعد تحويل النموذج إلى TFLite، يمكنك تشغيل واجهات برمجة تطبيقات مترجم TFLite للتحقق من مخرجات النموذج.
# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])