TFX para móviles

Introducción

Esta guía demuestra cómo Tensorflow Extended (TFX) puede crear y evaluar modelos de aprendizaje automático que se implementarán en el dispositivo. TFX ahora proporciona soporte nativo para TFLite , lo que permite realizar inferencias altamente eficientes en dispositivos móviles.

Esta guía lo guía a través de los cambios que se pueden realizar en cualquier canal para generar y evaluar modelos TFLite. Aquí proporcionamos un ejemplo completo que demuestra cómo TFX puede entrenar y evaluar modelos TFLite que se entrenan a partir del conjunto de datos MNIST . Además, mostramos cómo se puede utilizar el mismo canal para exportar simultáneamente tanto el SavedModel estándar basado en Keras como el TFLite, lo que permite a los usuarios comparar la calidad de los dos.

Asumimos que está familiarizado con TFX, nuestros componentes y nuestras canalizaciones. Si no es así, consulte este tutorial .

Pasos

Solo se requieren dos pasos para crear y evaluar un modelo TFLite en TFX. El primer paso es invocar el reescritor TFLite dentro del contexto del TFX Trainer para convertir el modelo TensorFlow entrenado en uno TFLite. El segundo paso es configurar el Evaluador para evaluar los modelos TFLite. Ahora analizaremos cada uno de ellos por turno.

Invocando el reescritor TFLite dentro del Entrenador.

TFX Trainer espera que se especifique un run_fn definido por el usuario en un archivo de módulo. Este run_fn define el modelo que se va a entrenar, lo entrena para el número especificado de iteraciones y exporta el modelo entrenado.

En el resto de esta sección, proporcionamos fragmentos de código que muestran los cambios necesarios para invocar la reescritura de TFLite y exportar un modelo de TFLite. Todo este código se encuentra en run_fn del módulo MNIST TFLite .

Como se muestra en el código siguiente, primero debemos crear una firma que tome un Tensor para cada característica como entrada. Tenga en cuenta que esto es una desviación de la mayoría de los modelos existentes en TFX, que toman protos tf.Example serializados como entrada.

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

Luego, el modelo de Keras se guarda como SavedModel de la misma forma que lo hace normalmente.

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

Finalmente, creamos una instancia de TFLite rewriter ( tfrw ) y la invocamos en SavedModel para obtener el modelo TFLite. Almacenamos este modelo TFLite en el serving_model_dir proporcionado por la persona que llama a run_fn . De esta manera, el modelo TFLite se almacena en la ubicación donde todos los componentes TFX posteriores esperarán encontrar el modelo.

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

Evaluación del modelo TFLite.

TFX Evaluator brinda la capacidad de analizar modelos entrenados para comprender su calidad en una amplia gama de métricas. Además de analizar SavedModels, TFX Evaluator ahora también puede analizar modelos TFLite.

El siguiente fragmento de código (reproducido de la canalización MNIST ) muestra cómo configurar un evaluador que analiza un modelo TFLite.

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

Como se muestra arriba, el único cambio que debemos hacer es establecer el campo model_type en tf_lite . No se requieren otros cambios de configuración para analizar el modelo TFLite. Independientemente de si se analiza un modelo TFLite o un SavedModel, la salida del Evaluator tendrá exactamente la misma estructura.

Sin embargo, tenga en cuenta que el evaluador asume que el modelo TFLite está guardado en un archivo llamado tflite dentro de trainer_lite.outputs['modelo'].