مدل های JAX با TensorFlow Lite

این صفحه مسیری را برای کاربرانی فراهم می‌کند که می‌خواهند مدل‌هایی را در JAX آموزش دهند و برای استنباط در تلفن همراه مستقر کنند ( مثال colab ).

روش‌های موجود در این راهنما یک tflite_model را تولید می‌کنند که می‌تواند مستقیماً با مثال کد مفسر TFLite استفاده شود یا در یک فایل TFLite FlatBuffer ذخیره شود.

پيش نياز

توصیه می شود این ویژگی را با جدیدترین بسته پایتون شبانه TensorFlow امتحان کنید.

pip install tf-nightly --upgrade

ما از کتابخانه Orbax Export برای صادرات مدل های JAX استفاده خواهیم کرد. مطمئن شوید که نسخه JAX شما حداقل 0.4.20 یا بالاتر باشد.

pip install jax --upgrade
pip install orbax-export --upgrade

مدل های JAX را به TensorFlow Lite تبدیل کنید

ما از TensorFlow SavedModel به عنوان قالب میانی بین JAX و TensorFlow Lite استفاده می کنیم. هنگامی که SavedModel دارید، API های موجود TensorFlow Lite می توانند برای تکمیل فرآیند تبدیل استفاده شوند.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

مدل TFLite تبدیل شده را بررسی کنید

پس از تبدیل مدل به TFLite، می توانید API های مفسر TFLite را برای بررسی خروجی های مدل اجرا کنید.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])