نماذج JAX مع TensorFlow Lite

توفر هذه الصفحة مسارًا للمستخدمين الذين يرغبون في تدريب النماذج في JAX ونشرها على الهاتف المحمول للاستدلال ( مثال colab ).

تنتج الطرق الموجودة في هذا الدليل نموذج tflite_model الذي يمكن استخدامه مباشرة مع مثال رمز مترجم TFLite أو حفظه في ملف TFLite FlatBuffer.

المتطلبات المسبقة

يوصى بتجربة هذه الميزة مع أحدث حزمة TensorFlow الليلية لـ Python.

pip install tf-nightly --upgrade

سوف نستخدم مكتبة Orbax Export لتصدير نماذج JAX. تأكد من أن إصدار JAX الخاص بك هو 0.4.20 على الأقل أو أعلى.

pip install jax --upgrade
pip install orbax-export --upgrade

تحويل نماذج JAX إلى TensorFlow Lite

نستخدم TensorFlow SavedModel كتنسيق وسيط بين JAX وTensorFlow Lite. بمجرد حصولك على SavedModel، يمكن استخدام واجهات برمجة تطبيقات TensorFlow Lite الموجودة لإكمال عملية التحويل.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

تحقق من طراز TFLite المحول

بعد تحويل النموذج إلى TFLite، يمكنك تشغيل واجهات برمجة تطبيقات مترجم TFLite للتحقق من مخرجات النموذج.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])