TensorFlow Lite সহ JAX মডেল

এই পৃষ্ঠাটি এমন ব্যবহারকারীদের জন্য একটি পথ প্রদান করে যারা JAX-এ মডেল প্রশিক্ষণ দিতে চান এবং অনুমানের জন্য মোবাইলে স্থাপন করতে চান ( উদাহরণস্বরূপ colab )।

এই গাইডের পদ্ধতিগুলি একটি tflite_model তৈরি করে যা সরাসরি TFLite ইন্টারপ্রেটার কোড উদাহরণের সাথে ব্যবহার করা যেতে পারে বা একটি TFLite FlatBuffer ফাইলে সংরক্ষণ করা যেতে পারে।

পূর্বশর্ত

নতুন TensorFlow রাতের পাইথন প্যাকেজের সাথে এই বৈশিষ্ট্যটি ব্যবহার করার পরামর্শ দেওয়া হচ্ছে।

pip install tf-nightly --upgrade

আমরা JAX মডেল রপ্তানি করতে Orbax এক্সপোর্ট লাইব্রেরি ব্যবহার করব। নিশ্চিত করুন যে আপনার JAX সংস্করণটি কমপক্ষে 0.4.20 বা তার বেশি।

pip install jax --upgrade
pip install orbax-export --upgrade

JAX মডেলগুলিকে টেনসরফ্লো লাইটে রূপান্তর করুন

আমরা JAX এবং TensorFlow Lite এর মধ্যবর্তী বিন্যাস হিসাবে TensorFlow SavedModel ব্যবহার করি। একবার আপনার একটি SavedModel হয়ে গেলে বর্তমান TensorFlow Lite APIগুলি রূপান্তর প্রক্রিয়া সম্পূর্ণ করতে ব্যবহার করা যেতে পারে।

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

রূপান্তরিত TFLite মডেল পরীক্ষা করুন

মডেলটি TFLite-এ রূপান্তরিত হওয়ার পরে, আপনি মডেল আউটপুট পরীক্ষা করতে TFLite দোভাষী API চালাতে পারেন।

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])