Model JAX dengan TensorFlow Lite

Halaman ini menyediakan jalur bagi pengguna yang ingin melatih model di JAX dan menerapkannya ke perangkat seluler untuk inferensi ( contoh colab ).

Metode dalam panduan ini menghasilkan tflite_model yang dapat digunakan langsung dengan contoh kode penerjemah TFLite atau disimpan ke file TFLite FlatBuffer.

Prasyarat

Disarankan untuk mencoba fitur ini dengan paket Python nightly TensorFlow terbaru.

pip install tf-nightly --upgrade

Kami akan menggunakan perpustakaan Orbax Ekspor untuk mengekspor model JAX. Pastikan versi JAX Anda setidaknya 0.4.20 atau lebih tinggi.

pip install jax --upgrade
pip install orbax-export --upgrade

Konversikan model JAX ke TensorFlow Lite

Kami menggunakan TensorFlow SavedModel sebagai format perantara antara JAX dan TensorFlow Lite. Setelah Anda memiliki SavedModel, API TensorFlow Lite yang ada dapat digunakan untuk menyelesaikan proses konversi.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Periksa model TFLite yang dikonversi

Setelah model dikonversi ke TFLite, Anda dapat menjalankan API interpreter TFLite untuk memeriksa keluaran model.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])