Modelo de puntos de control

La capacidad de guardar y restaurar el estado de un modelo es vital para varias aplicaciones, como el aprendizaje por transferencia o la realización de inferencias utilizando modelos previamente entrenados. Guardar los parámetros de un modelo (pesos, sesgos, etc.) en un archivo o directorio de puntos de control es una forma de lograrlo.

Este módulo proporciona una interfaz de alto nivel para cargar y guardar puntos de control del formato TensorFlow v2 , así como componentes de nivel inferior que escriben y leen en este formato de archivo.

Cargar y guardar modelos simples

Al cumplir con el protocolo Checkpointable , muchos modelos simples se pueden serializar en puntos de control sin ningún código adicional:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

y luego ese mismo punto de control se puede leer usando:

try model.readCheckpoint(from: directory, name: "LeNet")

Esta implementación predeterminada para cargar y guardar modelos utilizará un esquema de nomenclatura basado en rutas para cada tensor en el modelo que se basa en los nombres de las propiedades dentro de las estructuras del modelo. Por ejemplo, los pesos y sesgos dentro de la primera convolución en el modelo LeNet-5 se guardarán con los nombres conv1/filter y conv1/bias , respectivamente. Al cargar, el lector de puntos de control buscará tensores con estos nombres.

Personalización de la carga y guardado del modelo

Si desea tener un mayor control sobre qué tensores se guardan y cargan, o el nombre de esos tensores, el protocolo Checkpointable ofrece algunos puntos de personalización.

Para ignorar propiedades en ciertos tipos, puede proporcionar una implementación de ignoredTensorPaths en su modelo que devuelva un conjunto de cadenas en el formato Type.property . Por ejemplo, para ignorar la propiedad scale en cada capa de Atención, puede devolver ["Attention.scale"] .

De forma predeterminada, se utiliza una barra diagonal para separar cada nivel más profundo de un modelo. Esto se puede personalizar implementando checkpointSeparator en su modelo y proporcionando una nueva cadena para usar con este separador.

Finalmente, para lograr el mayor grado de personalización en la denominación de tensores, puede implementar tensorNameMap y proporcionar una función que asigne el nombre de cadena predeterminado generado para un tensor en el modelo a un nombre de cadena deseado en el punto de control. Por lo general, esto se utilizará para interoperar con puntos de control generados con otros marcos, cada uno de los cuales tiene sus propias convenciones de nomenclatura y estructuras de modelo. Una función de mapeo personalizado brinda el mayor grado de personalización sobre cómo se nombran estos tensores.

Se proporcionan algunas funciones auxiliares estándar, como la CheckpointWriter.identityMap predeterminada (que simplemente usa el nombre de ruta del tensor generado automáticamente para los puntos de control) o la función CheckpointWriter.lookupMap(table:) , que puede crear un mapeo a partir de un diccionario.

Para ver un ejemplo de cómo se puede lograr un mapeo personalizado, consulte el modelo GPT-2 , que utiliza una función de mapeo para coincidir con el esquema de nombres exacto utilizado para los puntos de control de OpenAI.

Los componentes CheckpointReader y CheckpointWriter

Para la escritura de puntos de control, la extensión proporcionada por el protocolo Checkpointable utiliza reflexión y rutas clave para iterar sobre las propiedades de un modelo y generar un diccionario que asigna rutas de tensor de cadena a valores de tensor. Este diccionario se proporciona a un CheckpointWriter subyacente, junto con un directorio en el que escribir el punto de control. Ese CheckpointWriter maneja la tarea de generar el punto de control en el disco a partir de ese diccionario.

Lo contrario de este proceso es la lectura, donde a un CheckpointReader se le proporciona la ubicación de un directorio de puntos de control en el disco. Luego lee desde ese punto de control y forma un diccionario que asigna los nombres de los tensores dentro del punto de control con sus valores guardados. Este diccionario se utiliza para reemplazar los tensores actuales en un modelo por los de este diccionario.

Tanto para cargar como para guardar, el protocolo Checkpointable asigna las rutas de cadena a los tensores a los nombres de tensores correspondientes en el disco utilizando la función de mapeo descrita anteriormente.

Si el protocolo Checkpointable carece de la funcionalidad necesaria, o si se desea tener más control sobre el proceso de carga y guardado de puntos de control, las clases CheckpointReader y CheckpointWriter se pueden usar por sí mismas.

El formato de punto de control de TensorFlow v2

El formato de punto de control de TensorFlow v2, como se describe brevemente en este encabezado , es el formato de segunda generación para los puntos de control del modelo TensorFlow. Este formato de segunda generación se utiliza desde finales de 2016 y tiene una serie de mejoras con respecto al formato de punto de control v1. Los TensorFlow SavedModels utilizan puntos de control v2 dentro de ellos para guardar los parámetros del modelo.

Un punto de control de TensorFlow v2 consta de un directorio con una estructura como la siguiente:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

donde el primer archivo almacena los metadatos del punto de control y los archivos restantes son fragmentos binarios que contienen los parámetros serializados del modelo.

El archivo de metadatos del índice contiene los tipos, tamaños, ubicaciones y nombres de cadenas de todos los tensores serializados contenidos en los fragmentos. Ese archivo de índice es la parte estructuralmente más compleja del punto de control y se basa en tensorflow::table , que a su vez se basa en SSTable/LevelDB. Este archivo de índice se compone de una serie de pares clave-valor, donde las claves son cadenas y los valores son búferes de protocolo. Las cadenas están ordenadas y comprimidas por prefijo. Por ejemplo: si la primera entrada es conv1/weight y la siguiente conv1/bias , la segunda entrada solo usa la parte bias .

Este archivo de índice general a veces se comprime usando la compresión Snappy . El archivo SnappyDecompression.swift proporciona una implementación Swift nativa de la descompresión Snappy desde una instancia de datos comprimida.

Los metadatos del encabezado del índice y los metadatos del tensor se codifican como buffers de protocolo y se codifican/decodifican directamente a través de Swift Protobuf .

Las clases CheckpointIndexReader y CheckpointIndexWriter manejan la carga y el guardado de estos archivos de índice como parte de las clases generales CheckpointReader y CheckpointWriter . Estos últimos utilizan los archivos de índice como base para determinar qué leer y escribir en los fragmentos binarios estructuralmente más simples que contienen los datos del tensor.