TFX para móvil

Organiza tus páginas con colecciones Guarda y categoriza el contenido según tus preferencias.

Introducción

Esta guía demuestra cómo Tensorflow Extended (TFX) puede crear y evaluar modelos de aprendizaje automático que se implementarán en el dispositivo. TFX ahora brinda soporte nativo para TFLite , lo que hace posible realizar inferencias altamente eficientes en dispositivos móviles.

Esta guía lo guía a través de los cambios que se pueden realizar en cualquier tubería para generar y evaluar modelos TFLite. Proporcionamos un ejemplo completo aquí , que demuestra cómo TFX puede entrenar y evaluar modelos TFLite que se entrenan a partir del conjunto de datos MNIST . Además, mostramos cómo se puede usar la misma canalización para exportar simultáneamente tanto el modelo guardado estándar basado en Keras como el TFLite , lo que permite a los usuarios comparar la calidad de los dos.

Asumimos que está familiarizado con TFX, nuestros componentes y nuestras canalizaciones. Si no es así, consulte este tutorial .

Pasos

Solo se requieren dos pasos para crear y evaluar un modelo TFLite en TFX. El primer paso es invocar el reescritor TFLite dentro del contexto de TFX Trainer para convertir el modelo TensorFlow entrenado en uno TFLite. El segundo paso es configurar el Evaluador para evaluar modelos TFLite. Ahora discutimos cada uno a su vez.

Invocando el reescritor TFLite dentro del Entrenador.

TFX Trainer espera que se especifique un run_fn definido por el usuario en un archivo de módulo. Este run_fn define el modelo que se va a entrenar, lo entrena para el número especificado de iteraciones y exporta el modelo entrenado.

En el resto de esta sección, proporcionamos fragmentos de código que muestran los cambios necesarios para invocar la reescritura de TFLite y exportar un modelo de TFLite. Todo este código se encuentra en run_fn del módulo MNIST TFLite .

Como se muestra en el código a continuación, primero debemos crear una firma que tome un Tensor para cada función como entrada. Tenga en cuenta que esto es una desviación de la mayoría de los modelos existentes en TFX, que toman prototipos tf.Example serializados como entrada.

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

Luego, el modelo de Keras se guarda como un modelo guardado de la misma manera que lo hace normalmente.

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

Finalmente, creamos una instancia del reescritor TFLite ( tfrw ) y lo invocamos en el modelo guardado para obtener el modelo TFLite. Almacenamos este modelo serving_model_dir en el servidor_modelo_dir proporcionado por la persona que llama a run_fn . De esta manera, el modelo TFLite se almacena en la ubicación donde todos los componentes TFX posteriores esperan encontrar el modelo.

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

Evaluación del modelo TFLite.

El evaluador TFX brinda la capacidad de analizar modelos entrenados para comprender su calidad en una amplia gama de métricas. Además de analizar modelos guardados, TFX Evaluator ahora también puede analizar modelos TFLite.

El siguiente fragmento de código (reproducido de la canalización MNIST ) muestra cómo configurar un evaluador que analiza un modelo TFLite.

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

Como se muestra arriba, el único cambio que debemos hacer es establecer el campo model_type en tf_lite . No se requieren otros cambios de configuración para analizar el modelo TFLite. Independientemente de si se analiza un modelo TFLite o un modelo guardado, la salida del Evaluator tendrá exactamente la misma estructura.

Sin embargo, tenga en cuenta que el Evaluador asume que el modelo TFLite se guarda en un archivo llamado tflite dentro de trainer_lite.outputs['model'].