Modelli di checkpoint

La capacità di salvare e ripristinare lo stato di un modello è vitale per numerose applicazioni, ad esempio nel trasferimento di apprendimento o per eseguire inferenze utilizzando modelli preaddestrati. Salvare i parametri di un modello (pesi, bias, ecc.) in un file o in una directory di checkpoint è un modo per raggiungere questo obiettivo.

Questo modulo fornisce un'interfaccia di alto livello per caricare e salvare checkpoint in formato TensorFlow v2 , nonché componenti di livello inferiore che scrivono e leggono da questo formato di file.

Caricamento e salvataggio di modelli semplici

Conformandosi al protocollo Checkpointable , molti modelli semplici possono essere serializzati su checkpoint senza alcun codice aggiuntivo:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

e quindi lo stesso checkpoint può essere letto utilizzando:

try model.readCheckpoint(from: directory, name: "LeNet")

Questa implementazione predefinita per il caricamento e il salvataggio del modello utilizzerà uno schema di denominazione basato sul percorso per ogni tensore nel modello basato sui nomi delle proprietà all'interno delle strutture del modello. Ad esempio, i pesi e i bias all'interno della prima convoluzione nel modello LeNet-5 verranno salvati rispettivamente con i nomi conv1/filter e conv1/bias . Durante il caricamento, il lettore del checkpoint cercherà tensori con questi nomi.

Personalizzazione del caricamento e del salvataggio del modello

Se desideri avere un maggiore controllo su quali tensori vengono salvati e caricati o sulla denominazione di tali tensori, il protocollo Checkpointable offre alcuni punti di personalizzazione.

Per ignorare le proprietà su determinati tipi, puoi fornire un'implementazione di ignoredTensorPaths sul tuo modello che restituisce un set di stringhe sotto forma di Type.property . Ad esempio, per ignorare la proprietà scale su ogni livello Attenzione, potresti restituire ["Attention.scale"] .

Per impostazione predefinita, viene utilizzata una barra per separare ciascun livello più profondo in un modello. Questo può essere personalizzato implementando checkpointSeparator sul tuo modello e fornendo una nuova stringa da utilizzare per questo separatore.

Infine, per il massimo grado di personalizzazione nella denominazione dei tensori, è possibile implementare tensorNameMap e fornire una funzione che esegua la mappatura dal nome di stringa predefinito generato per un tensore nel modello al nome di stringa desiderato nel checkpoint. Più comunemente, questo verrà utilizzato per interagire con checkpoint generati con altri framework, ognuno dei quali ha le proprie convenzioni di denominazione e strutture di modello. Una funzione di mappatura personalizzata offre il massimo grado di personalizzazione per il modo in cui vengono denominati questi tensori.

Vengono fornite alcune funzioni di supporto standard, come la CheckpointWriter.identityMap predefinita (che utilizza semplicemente il nome del percorso tensore generato automaticamente per i checkpoint) o la funzione CheckpointWriter.lookupMap(table:) , che può creare una mappatura da un dizionario.

Per un esempio di come è possibile realizzare la mappatura personalizzata, consulta il modello GPT-2 , che utilizza una funzione di mappatura per corrispondere all'esatto schema di denominazione utilizzato per i checkpoint di OpenAI.

I componenti CheckpointReader e CheckpointWriter

Per la scrittura del checkpoint, l'estensione fornita dal protocollo Checkpointable utilizza la riflessione e i percorsi chiave per eseguire l'iterazione sulle proprietà di un modello e generare un dizionario che mappa i percorsi dei tensori delle stringhe sui valori dei tensori. Questo dizionario viene fornito a un CheckpointWriter sottostante, insieme a una directory in cui scrivere il checkpoint. Quel CheckpointWriter gestisce l'attività di generare il checkpoint su disco da quel dizionario.

Il processo inverso è la lettura, in cui a CheckpointReader viene assegnata la posizione di una directory di checkpoint su disco. Quindi legge da quel checkpoint e forma un dizionario che mappa i nomi dei tensori all'interno del checkpoint con i loro valori salvati. Questo dizionario viene utilizzato per sostituire i tensori correnti in un modello con quelli presenti in questo dizionario.

Sia per il caricamento che per il salvataggio, il protocollo Checkpointable mappa i percorsi delle stringhe ai tensori sui corrispondenti nomi dei tensori su disco utilizzando la funzione di mappatura sopra descritta.

Se il protocollo Checkpointable non dispone delle funzionalità necessarie o si desidera un maggiore controllo sul processo di caricamento e salvataggio del checkpoint, le classi CheckpointReader e CheckpointWriter possono essere utilizzate da sole.

Il formato del checkpoint TensorFlow v2

Il formato checkpoint TensorFlow v2, come brevemente descritto in questa intestazione , è il formato di seconda generazione per i checkpoint del modello TensorFlow. Questo formato di seconda generazione è in uso dalla fine del 2016 e presenta numerosi miglioramenti rispetto al formato checkpoint v1. I TensorFlow SavedModels utilizzano checkpoint v2 al loro interno per salvare i parametri del modello.

Un checkpoint TensorFlow v2 è costituito da una directory con una struttura simile alla seguente:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

dove il primo file memorizza i metadati per il checkpoint e i file rimanenti sono frammenti binari che contengono i parametri serializzati per il modello.

Il file di metadati dell'indice contiene i tipi, le dimensioni, le posizioni e i nomi delle stringhe di tutti i tensori serializzati contenuti negli shard. Quel file di indice è la parte strutturalmente più complessa del checkpoint ed è basato su tensorflow::table , che a sua volta è basato su SSTable/LevelDB. Questo file di indice è composto da una serie di coppie chiave-valore, dove le chiavi sono stringhe e i valori sono buffer di protocollo. Le stringhe vengono ordinate e compresse con prefisso. Ad esempio: se la prima voce è conv1/weight e la successiva conv1/bias , la seconda voce utilizza solo la parte bias .

Questo file di indice generale viene talvolta compresso utilizzando la compressione Snappy . Il file SnappyDecompression.swift fornisce un'implementazione Swift nativa della decompressione Snappy da un'istanza Data compressa.

I metadati dell'intestazione dell'indice e i metadati del tensore sono codificati come buffer di protocollo e codificati/decodificati direttamente tramite Swift Protobuf .

Le classi CheckpointIndexReader e CheckpointIndexWriter gestiscono il caricamento e il salvataggio di questi file di indice come parte delle classi generali CheckpointReader e CheckpointWriter . Questi ultimi utilizzano i file indice come base per determinare cosa leggere e scrivere sui frammenti binari strutturalmente più semplici che contengono i dati tensoriali.

,

La capacità di salvare e ripristinare lo stato di un modello è vitale per numerose applicazioni, ad esempio nel trasferimento di apprendimento o per eseguire inferenze utilizzando modelli preaddestrati. Salvare i parametri di un modello (pesi, bias, ecc.) in un file o in una directory di checkpoint è un modo per raggiungere questo obiettivo.

Questo modulo fornisce un'interfaccia di alto livello per caricare e salvare checkpoint in formato TensorFlow v2 , nonché componenti di livello inferiore che scrivono e leggono da questo formato di file.

Caricamento e salvataggio di modelli semplici

Conformandosi al protocollo Checkpointable , molti modelli semplici possono essere serializzati su checkpoint senza alcun codice aggiuntivo:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

e quindi lo stesso checkpoint può essere letto utilizzando:

try model.readCheckpoint(from: directory, name: "LeNet")

Questa implementazione predefinita per il caricamento e il salvataggio del modello utilizzerà uno schema di denominazione basato sul percorso per ogni tensore nel modello basato sui nomi delle proprietà all'interno delle strutture del modello. Ad esempio, i pesi e i bias all'interno della prima convoluzione nel modello LeNet-5 verranno salvati rispettivamente con i nomi conv1/filter e conv1/bias . Durante il caricamento, il lettore del checkpoint cercherà tensori con questi nomi.

Personalizzazione del caricamento e del salvataggio del modello

Se desideri avere un maggiore controllo su quali tensori vengono salvati e caricati o sulla denominazione di tali tensori, il protocollo Checkpointable offre alcuni punti di personalizzazione.

Per ignorare le proprietà su determinati tipi, puoi fornire un'implementazione di ignoredTensorPaths sul tuo modello che restituisce un set di stringhe sotto forma di Type.property . Ad esempio, per ignorare la proprietà scale su ogni livello Attenzione, potresti restituire ["Attention.scale"] .

Per impostazione predefinita, viene utilizzata una barra per separare ciascun livello più profondo in un modello. Questo può essere personalizzato implementando checkpointSeparator sul tuo modello e fornendo una nuova stringa da utilizzare per questo separatore.

Infine, per il massimo grado di personalizzazione nella denominazione dei tensori, è possibile implementare tensorNameMap e fornire una funzione che esegua la mappatura dal nome di stringa predefinito generato per un tensore nel modello al nome di stringa desiderato nel checkpoint. Più comunemente, questo verrà utilizzato per interagire con checkpoint generati con altri framework, ognuno dei quali ha le proprie convenzioni di denominazione e strutture di modello. Una funzione di mappatura personalizzata offre il massimo grado di personalizzazione per il modo in cui vengono denominati questi tensori.

Vengono fornite alcune funzioni di supporto standard, come la CheckpointWriter.identityMap predefinita (che utilizza semplicemente il nome del percorso tensore generato automaticamente per i checkpoint) o la funzione CheckpointWriter.lookupMap(table:) , che può creare una mappatura da un dizionario.

Per un esempio di come è possibile realizzare la mappatura personalizzata, consulta il modello GPT-2 , che utilizza una funzione di mappatura per corrispondere all'esatto schema di denominazione utilizzato per i checkpoint di OpenAI.

I componenti CheckpointReader e CheckpointWriter

Per la scrittura del checkpoint, l'estensione fornita dal protocollo Checkpointable utilizza la riflessione e i percorsi chiave per eseguire l'iterazione sulle proprietà di un modello e generare un dizionario che mappa i percorsi dei tensori delle stringhe sui valori dei tensori. Questo dizionario viene fornito a un CheckpointWriter sottostante, insieme a una directory in cui scrivere il checkpoint. Quel CheckpointWriter gestisce l'attività di generare il checkpoint su disco da quel dizionario.

Il processo inverso è la lettura, in cui a CheckpointReader viene assegnata la posizione di una directory di checkpoint su disco. Quindi legge da quel checkpoint e forma un dizionario che mappa i nomi dei tensori all'interno del checkpoint con i loro valori salvati. Questo dizionario viene utilizzato per sostituire i tensori correnti in un modello con quelli presenti in questo dizionario.

Sia per il caricamento che per il salvataggio, il protocollo Checkpointable mappa i percorsi delle stringhe ai tensori sui corrispondenti nomi dei tensori su disco utilizzando la funzione di mappatura sopra descritta.

Se il protocollo Checkpointable non dispone delle funzionalità necessarie o si desidera un maggiore controllo sul processo di caricamento e salvataggio del checkpoint, le classi CheckpointReader e CheckpointWriter possono essere utilizzate da sole.

Il formato del checkpoint TensorFlow v2

Il formato checkpoint TensorFlow v2, come brevemente descritto in questa intestazione , è il formato di seconda generazione per i checkpoint del modello TensorFlow. Questo formato di seconda generazione è in uso dalla fine del 2016 e presenta numerosi miglioramenti rispetto al formato checkpoint v1. I TensorFlow SavedModels utilizzano checkpoint v2 al loro interno per salvare i parametri del modello.

Un checkpoint TensorFlow v2 è costituito da una directory con una struttura simile alla seguente:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

dove il primo file memorizza i metadati per il checkpoint e i file rimanenti sono frammenti binari che contengono i parametri serializzati per il modello.

Il file di metadati dell'indice contiene i tipi, le dimensioni, le posizioni e i nomi delle stringhe di tutti i tensori serializzati contenuti negli shard. Quel file di indice è la parte strutturalmente più complessa del checkpoint ed è basato su tensorflow::table , che a sua volta è basato su SSTable/LevelDB. Questo file di indice è composto da una serie di coppie chiave-valore, dove le chiavi sono stringhe e i valori sono buffer di protocollo. Le stringhe vengono ordinate e compresse con prefisso. Ad esempio: se la prima voce è conv1/weight e la successiva conv1/bias , la seconda voce utilizza solo la parte bias .

Questo file di indice generale viene talvolta compresso utilizzando la compressione Snappy . Il file SnappyDecompression.swift fornisce un'implementazione Swift nativa della decompressione Snappy da un'istanza Data compressa.

I metadati dell'intestazione dell'indice e i metadati del tensore sono codificati come buffer di protocollo e codificati/decodificati direttamente tramite Swift Protobuf .

Le classi CheckpointIndexReader e CheckpointIndexWriter gestiscono il caricamento e il salvataggio di questi file di indice come parte delle classi generali CheckpointReader e CheckpointWriter . Questi ultimi utilizzano i file indice come base per determinare cosa leggere e scrivere sui frammenti binari strutturalmente più semplici che contengono i dati tensoriali.