Recortar pesos insignificantes

Este documento proporciona una descripción general sobre la poda de modelos para ayudarle a determinar cómo se adapta a su caso de uso.

Descripción general

La poda de peso basada en la magnitud reduce gradualmente a cero los pesos del modelo durante el proceso de entrenamiento para lograr la escasez del modelo. Los modelos dispersos son más fáciles de comprimir y podemos omitir los ceros durante la inferencia para mejorar la latencia.

Esta técnica aporta mejoras mediante la compresión del modelo. En el futuro, el soporte del marco para esta técnica proporcionará mejoras en la latencia. Hemos visto mejoras de hasta 6 veces en la compresión del modelo con una pérdida mínima de precisión.

La técnica se está evaluando en varias aplicaciones de voz, como el reconocimiento de voz y la conversión de texto a voz, y se ha experimentado con varios modelos de visión y traducción.

Matriz de compatibilidad de API

Los usuarios pueden aplicar la poda con las siguientes API:

  • Construcción de modelos: keras solo con modelos Secuenciales y Funcionales
  • Versiones de TensorFlow: TF 1.x para las versiones 1.14+ y 2.x.
    • No se admiten tf.compat.v1 con un paquete TF 2.X y tf.compat.v2 con un paquete TF 1.X.
  • Modo de ejecución de TensorFlow: gráfico y ansioso
  • Entrenamiento distribuido: tf.distribute con solo ejecución de gráficos

Está en nuestra hoja de ruta agregar soporte en las siguientes áreas:

Resultados

Clasificación de imágenes

Modelo Precisión Top-1 no escasa Precisión escasa aleatoria Escasez aleatoria Precisión escasa estructurada Escasez estructurada
InicioV3 78,1% 78,0% 50% 75,8% 2 por 4
76,1% 75%
74,6% 87,5%
MóvilnetV1 224 71,04% 70,84% 50% 67,35% 2 por 4
MóvilnetV2 224 71,77% 69,64% 50% 66,75% 2 por 4

Los modelos fueron probados en Imagenet.

Traducción

Modelo BLEU no disperso BLEU escaso Escasez
GNMT EN-DE 26,77 26,86 80%
26,52 85%
26.19 90%
GNMT DE-ES 29,47 29,50 80%
29.24 85%
28,81 90%

Los modelos utilizan el conjunto de datos WMT16 en alemán e inglés con news-test2013 como conjunto de desarrollo y news-test2015 como conjunto de prueba.

Modelo de detección de palabras clave

DS-CNN-L es un modelo de detección de palabras clave creado para dispositivos perimetrales. Se puede encontrar en el repositorio de ejemplos del software ARM.

Modelo Precisión no escasa Precisión dispersa estructurada (patrón de 2 por 4) Precisión dispersa aleatoria (escasez objetivo 50%)
DS-CNN-L 95,23 94,33 94,84

Ejemplos

Además del tutorial Podar con Keras , consulte los siguientes ejemplos:

  • Entrene un modelo CNN en la tarea de clasificación de dígitos escritos a mano MNIST con poda: código
  • Entrene un LSTM en la tarea de clasificación de sentimientos de IMDB con poda: código

Para obtener información general, consulte Podar o no podar: exploración de la eficacia de la poda para la compresión de modelos [ artículo ].