Introdução
Este guia demonstra como o Tensorflow Extended (TFX) pode criar e avaliar modelos de machine learning que serão implantados no dispositivo. O TFX agora oferece suporte nativo para TFLite , o que possibilita realizar inferências altamente eficientes em dispositivos móveis.
Este guia orienta você pelas alterações que podem ser feitas em qualquer pipeline para gerar e avaliar modelos TFLite. Fornecemos um exemplo completo aqui , demonstrando como o TFX pode treinar e avaliar modelos TFLite que são treinados a partir do conjunto de dados MNIST . Além disso, mostramos como o mesmo pipeline pode ser usado para exportar simultaneamente o SavedModel padrão baseado em Keras e o TFLite, permitindo que os usuários comparem a qualidade dos dois.
Presumimos que você esteja familiarizado com o TFX, nossos componentes e nossos pipelines. Se não, então por favor veja este tutorial .
Degraus
Apenas duas etapas são necessárias para criar e avaliar um modelo TFLite no TFX. A primeira etapa é invocar o reescritor TFLite no contexto do TFX Trainer para converter o modelo treinado do TensorFlow em um TFLite. A segunda etapa é configurar o Evaluator para avaliar os modelos TFLite. Agora discutimos cada um por sua vez.
Invocando o reescritor TFLite no Trainer.
O TFX Trainer espera que um run_fn
definido pelo usuário seja especificado em um arquivo de módulo. Este run_fn
define o modelo a ser treinado, treina-o para o número especificado de iterações e exporta o modelo treinado.
No restante desta seção, fornecemos trechos de código que mostram as alterações necessárias para invocar o reescritor TFLite e exportar um modelo TFLite. Todo esse código está localizado no run_fn
do módulo MNIST TFLite .
Conforme mostrado no código abaixo, primeiro devemos criar uma assinatura que receba um Tensor
para cada recurso como entrada. Observe que isso é um desvio da maioria dos modelos existentes no TFX, que usam protos tf.Example serializados como entrada.
signatures = {
'serving_default':
_get_serve_tf_examples_fn(
model, tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None, 784],
dtype=tf.float32,
name='image_floats'))
}
Em seguida, o modelo Keras é salvo como um SavedModel da mesma forma que normalmente é.
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
Por fim, criamos uma instância do reescritor TFLite ( tfrw
) e a invocamos no SavedModel para obter o modelo TFLite. Armazenamos esse modelo serving_model_dir
no servidor_modelo_dir fornecido pelo chamador do run_fn
. Dessa forma, o modelo TFLite é armazenado no local onde todos os componentes do TFX downstream esperam encontrar o modelo.
tfrw = rewriter_factory.create_rewriter(
rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
converters.rewrite_saved_model(temp_saving_model_dir,
fn_args.serving_model_dir,
tfrw,
rewriter.ModelType.TFLITE_MODEL)
Avaliação do modelo TFLite.
O TFX Evaluator oferece a capacidade de analisar modelos treinados para entender sua qualidade em uma ampla variedade de métricas. Além de analisar SavedModels, o TFX Evaluator agora também pode analisar modelos TFLite.
O trecho de código a seguir (reproduzido do pipeline MNIST ) mostra como configurar um avaliador que analisa um modelo TFLite.
# Informs the evaluator that the model is a TFLite model.
eval_config_lite.model_specs[0].model_type = 'tf_lite'
...
# Uses TFMA to compute the evaluation statistics over features of a TFLite
# model.
model_analyzer_lite = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer_lite.outputs['model'],
eval_config=eval_config_lite,
).with_id('mnist_lite')
Como mostrado acima, a única mudança que precisamos fazer é definir o campo model_type
para tf_lite
. Nenhuma outra alteração de configuração é necessária para analisar o modelo TFLite. Independentemente de um modelo TFLite ou um SavedModel ser analisado, a saída do Evaluator
terá exatamente a mesma estrutura.
No entanto, observe que o Avaliador assume que o modelo TFLite é salvo em um arquivo chamado tflite
dentro do trainer_lite.outputs['model'].