דגמי JAX עם TensorFlow Lite

דף זה מספק נתיב למשתמשים שרוצים לאמן מודלים ב-JAX ולפרוס לנייד לצורך הסקת מסקנות ( למשל colab ).

השיטות במדריך זה מייצרות tflite_model שניתן להשתמש בו ישירות עם דוגמה של קוד המתורגמן TFLite או לשמור בקובץ TFLite FlatBuffer.

תְנַאִי מוּקדָם

מומלץ לנסות תכונה זו עם החבילה החדשה ביותר של TensorFlow לילי Python.

pip install tf-nightly --upgrade

נשתמש בספריית Orbax Export כדי לייצא דגמי JAX. ודא שגרסת ה-JAX שלך היא לפחות 0.4.20 ומעלה.

pip install jax --upgrade
pip install orbax-export --upgrade

המרת דגמי JAX ל- TensorFlow Lite

אנו משתמשים ב- TensorFlow SavedModel כפורמט הביניים בין JAX ל- TensorFlow Lite. ברגע שיש לך SavedModel, ניתן להשתמש בממשקי API קיימים של TensorFlow Lite להשלמת תהליך ההמרה.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

בדוק את דגם ה-TFLite שהומר

לאחר המרת המודל ל-TFLite, אתה יכול להפעיל ממשקי API של מתורגמנים של TFLite כדי לבדוק פלטי מודל.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])