Salva la data! Google I / O ritorna dal 18 al 20 maggio Registrati ora
Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Migrare il codice TensorFlow 1 a TensorFlow 2

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica taccuino

Questa guida è per gli utenti di API TensorFlow di basso livello. Se stai utilizzando le API di alto livello ( tf.keras ), potrebbe essere necessario eseguire poche o nessuna azione per rendere il tuo codice completamente compatibile con TensorFlow 2.x:

È ancora possibile eseguire codice 1.x, non modificato (ad eccezione di contrib ), in TensorFlow 2.x:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

Tuttavia, questo non ti consente di sfruttare molti dei miglioramenti apportati in TensorFlow 2.x. Questa guida ti aiuterà ad aggiornare il tuo codice, rendendolo più semplice, più performante e più facile da mantenere.

Script di conversione automatica

Il primo passaggio, prima di tentare di implementare le modifiche descritte in questa guida, è provare a eseguire lo script di aggiornamento .

Questo eseguirà un passaggio iniziale all'aggiornamento del codice a TensorFlow 2.x ma non può rendere il tuo codice idiomatico alla v2. Il codice può ancora utilizzare tf.compat.v1 endpoint tf.compat.v1 per accedere a segnaposto, sessioni, raccolte e altre funzionalità in stile 1.x.

Cambiamenti comportamentali di primo livello

Se il tuo codice funziona in TensorFlow 2.x utilizzando tf.compat.v1.disable_v2_behavior , tf.compat.v1.disable_v2_behavior esserci ancora cambiamenti comportamentali globali che potresti dover affrontare. I principali cambiamenti sono:

  • Esecuzione desiderosa, v1.enable_eager_execution() : qualsiasi codice che utilizza implicitamente un tf.Graph fallirà. Assicurati di racchiudere questo codice in un with tf.Graph().as_default() .

  • Variabili risorsa, v1.enable_resource_variables() : parte del codice può dipendere da comportamenti non deterministici abilitati dalle variabili di riferimento TensorFlow. Le variabili delle risorse sono bloccate durante la scrittura e quindi forniscono garanzie di coerenza più intuitive.

    • Questo può cambiare il comportamento nei casi limite.
    • Ciò potrebbe creare copie aggiuntive e richiedere un maggiore utilizzo della memoria.
    • Questo può essere disabilitato passando use_resource=False al costruttore tf.Variable .
  • Forme tensoriali, v1.enable_v2_tensorshape() : TensorFlow 2.x semplifica il comportamento delle forme tensoriali. Invece di t.shape[0].value puoi dire t.shape[0] . Questi cambiamenti dovrebbero essere piccoli e ha senso risolverli immediatamente. Fare riferimento alla sezione TensorShape per gli esempi.

  • Flusso di controllo, v1.enable_control_flow_v2() : l'implementazione del flusso di controllo di TensorFlow 2.x è stata semplificata e quindi produce diverse rappresentazioni grafiche. Si prega di segnalare bug per qualsiasi problema.

Crea codice per TensorFlow 2.x

Questa guida illustrerà diversi esempi di conversione del codice TensorFlow 1.x in TensorFlow 2.x. Queste modifiche consentiranno al codice di sfruttare le ottimizzazioni delle prestazioni e le chiamate API semplificate.

In ogni caso, il modello è:

1. Sostituire v1.Session.run chiamate v1.Session.run

Ogni chiamata v1.Session.run dovrebbe essere sostituita da una funzione Python.

  • I feed_dict e v1.placeholder diventano argomenti della funzione.
  • I fetches diventano il valore restituito dalla funzione.
  • Durante la conversione, l'esecuzione desiderosa consente un facile debug con strumenti Python standard come pdb .

Dopodiché, aggiungi un decoratore tf.function per farlo funzionare in modo efficiente nel grafico. Consulta la guida agli autografi per ulteriori informazioni su come funziona.

Notare che:

  • A differenza di v1.Session.run , una funzione tf.function ha una firma di ritorno fissa e restituisce sempre tutti gli output. Se ciò causa problemi di prestazioni, creare due funzioni separate.

  • Non c'è bisogno di tf.control_dependencies o operazioni simili: una funzione tf.function si comporta come se fosse eseguita nell'ordine scritto. tf.Variable assegnazioni tf.Variable e tf.assert , ad esempio, vengono eseguite automaticamente.

La sezione dei modelli di conversione contiene un esempio funzionante di questo processo di conversione.

2. Usa oggetti Python per tenere traccia di variabili e perdite

Tutto il tracciamento delle variabili basato sul nome è fortemente sconsigliato in TensorFlow 2.x. Usa gli oggetti Python per tenere traccia delle variabili.

Usa tf.Variable invece di v1.get_variable .

Ogni v1.variable_scope dovrebbe essere convertito in un oggetto Python. In genere questo sarà uno di:

Se è necessario aggregare elenchi di variabili (come tf.Graph.get_collection(tf.GraphKeys.VARIABLES) ), utilizzare gli attributi .variables e .trainable_variables degli oggetti Layer e Model .

Queste classi Layer e Model implementano diverse altre proprietà che eliminano la necessità di raccolte globali. La loro proprietà .losses può sostituire l'utilizzo della raccolta tf.GraphKeys.LOSSES .

Fare riferimento alle guide di Keras per maggiori dettagli.

3. Aggiorna i tuoi cicli di allenamento

Utilizza l'API di livello più alto che funziona per il tuo caso d'uso. Preferisci tf.keras.Model.fit costruire i tuoi loop di allenamento.

Queste funzioni di alto livello gestiscono molti dettagli di basso livello che potrebbero essere facili da perdere se scrivi il tuo ciclo di allenamento. Ad esempio, raccolgono automaticamente le perdite di regolarizzazione e impostano l'argomento training=True quando si chiama il modello.

4. Aggiornare le pipeline di input dei dati

Utilizzare i set di dati tf.data per l'immissione dei dati. Questi oggetti sono efficienti, espressivi e si integrano bene con tensorflow.

Possono essere passati direttamente al metodo tf.keras.Model.fit .

model.fit(dataset, epochs=5)

Possono essere iterati direttamente su Python standard:

for example_batch, label_batch in dataset:
    break

5. Migrare i simboli compat.v1

Il modulo tf.compat.v1 contiene l'API TensorFlow 1.x completa, con la sua semantica originale.

Lo script di aggiornamento di TensorFlow 2.x convertirà i simboli nei loro equivalenti v2 se tale conversione è sicura, ovvero se può determinare che il comportamento della versione di TensorFlow 2.x è esattamente equivalente (ad esempio, rinominerà v1.arg_max a tf.argmax , poiché sono la stessa funzione).

Dopo che lo script di aggiornamento è stato eseguito con un pezzo di codice, è probabile che ci siano molte menzioni di compat.v1 . Vale la pena leggere il codice e convertirli manualmente nell'equivalente v2 (dovrebbe essere menzionato nel registro se ce n'è uno).

Conversione di modelli

Variabili di basso livello ed esecuzione dell'operatore

Esempi di utilizzo dell'API di basso livello includono:

Prima della conversione

Ecco come possono apparire questi modelli nel codice utilizzando TensorFlow 1.x.

import tensorflow as tf
import tensorflow.compat.v1 as v1

import tensorflow_datasets as tfds
g = v1.Graph()

with g.as_default():
  in_a = v1.placeholder(dtype=v1.float32, shape=(2))
  in_b = v1.placeholder(dtype=v1.float32, shape=(2))

  def forward(x):
    with v1.variable_scope("matmul", reuse=v1.AUTO_REUSE):
      W = v1.get_variable("W", initializer=v1.ones(shape=(2,2)),
                          regularizer=lambda x:tf.reduce_mean(x**2))
      b = v1.get_variable("b", initializer=v1.zeros(shape=(2)))
      return W * x + b

  out_a = forward(in_a)
  out_b = forward(in_b)
  reg_loss=v1.losses.get_regularization_loss(scope="matmul")

with v1.Session(graph=g) as sess:
  sess.run(v1.global_variables_initializer())
  outs = sess.run([out_a, out_b, reg_loss],
                feed_dict={in_a: [1, 0], in_b: [0, 1]})

print(outs[0])
print()
print(outs[1])
print()
print(outs[2])
[[1. 0.]
 [1. 0.]]

[[0. 1.]
 [0. 1.]]

1.0

Dopo la conversione

Nel codice convertito:

  • Le variabili sono oggetti Python locali.
  • La funzione forward definisce ancora il calcolo.
  • La chiamata Session.run viene sostituita da una chiamata da forward .
  • Il decoratore tf.function opzionale può essere aggiunto per le prestazioni.
  • Le regolarizzazioni vengono calcolate manualmente, senza fare riferimento ad alcuna raccolta globale.
  • Nessun utilizzo di sessioni o segnaposto .
W = tf.Variable(tf.ones(shape=(2,2)), name="W")
b = tf.Variable(tf.zeros(shape=(2)), name="b")

@tf.function
def forward(x):
  return W * x + b

out_a = forward([1,0])
print(out_a)
tf.Tensor(
[[1. 0.]
 [1. 0.]], shape=(2, 2), dtype=float32)
out_b = forward([0,1])

regularizer = tf.keras.regularizers.l2(0.04)
reg_loss=regularizer(W)

Modelli basati su tf.layers

Il modulo v1.layers viene utilizzato per contenere funzioni di livello che si basavano su v1.variable_scope per definire e riutilizzare le variabili.

Prima della conversione

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    x = v1.layers.conv2d(x, 32, 3, activation=v1.nn.relu,
          kernel_regularizer=lambda x:0.004*tf.reduce_mean(x**2))
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    x = v1.layers.dropout(x, 0.1, training=training)
    x = v1.layers.dense(x, 64, activation=v1.nn.relu)
    x = v1.layers.batch_normalization(x, training=training)
    x = v1.layers.dense(x, 10)
    return x
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

print(train_out)
print()
print(test_out)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/convolutional.py:414: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  warnings.warn('`tf.layers.conv2d` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:2273: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)

tf.Tensor(
[[ 0.379358   -0.55901194  0.48704922  0.11619566  0.23902717  0.01691487
   0.07227738  0.14556988  0.2459927   0.2501198 ]], shape=(1, 10), dtype=float32)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/pooling.py:310: UserWarning: `tf.layers.max_pooling2d` is deprecated and will be removed in a future version. Please use `tf.keras.layers.MaxPooling2D` instead.
  warnings.warn('`tf.layers.max_pooling2d` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:329: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  warnings.warn('`tf.layers.flatten` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:268: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  warnings.warn('`tf.layers.dropout` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:171: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  warnings.warn('`tf.layers.dense` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/normalization.py:308: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
  '`tf.layers.batch_normalization` is deprecated and '

Dopo la conversione

La maggior parte degli argomenti è rimasta la stessa. Ma nota le differenze:

  • L'argomento di training viene passato a ogni livello dal modello quando viene eseguito.
  • Il primo argomento della funzione del model originale (l'input x ) è sparito. Questo perché i livelli degli oggetti separano la creazione del modello dalla chiamata al modello.

Notare inoltre che:

  • Se stai usando regolarizzatori o inizializzatori da tf.contrib , questi hanno più modifiche agli argomenti di altri.
  • Il codice non scrive più nelle raccolte, quindi funzioni come v1.losses.get_regularization_loss non restituiranno più questi valori, interrompendo potenzialmente i cicli di addestramento.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.04),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
train_out = model(train_data, training=True)
print(train_out)
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)
test_out = model(test_data, training=False)
print(test_out)
tf.Tensor(
[[-0.2145557  -0.22979769 -0.14968733  0.01208701 -0.07569927  0.3475932
   0.10718458  0.03482988 -0.04309493 -0.10469118]], shape=(1, 10), dtype=float32)
# Here are all the trainable variables
len(model.trainable_variables)
8
# Here is the regularization loss
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.08174552>]

Variabili miste e v1.layers

Il codice esistente spesso combina variabili e operazioni di TensorFlow 1.x di livello inferiore con v1.layers livello v1.layers .

Prima della conversione

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    W = v1.get_variable(
      "W", dtype=v1.float32,
      initializer=v1.ones(shape=x.shape),
      regularizer=lambda x:0.004*tf.reduce_mean(x**2),
      trainable=True)
    if training:
      x = x + W
    else:
      x = x + W * 0.5
    x = v1.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    return x

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

Dopo la conversione

Per convertire questo codice, seguire lo schema di mappatura dei livelli sui livelli come nell'esempio precedente.

Lo schema generale è:

  • Raccogli i parametri del livello in __init__ .
  • Costruisci le variabili in build .
  • Eseguire i calcoli nella call e restituire il risultato.

v1.variable_scope è essenzialmente un livello a sé stante. Quindi riscrivilo come tf.keras.layers.Layer . Consulta la guida Creazione di nuovi livelli e modelli tramite sottoclassi per i dettagli.

# Create a custom layer for part of the model
class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, *args, **kwargs):
    super(CustomLayer, self).__init__(*args, **kwargs)

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=input_shape[1:],
        dtype=tf.float32,
        initializer=tf.keras.initializers.ones(),
        regularizer=tf.keras.regularizers.l2(0.02),
        trainable=True)

  # Call method will sometimes get used in graph mode,
  # training will get turned into a tensor
  @tf.function
  def call(self, inputs, training=None):
    if training:
      return inputs + self.w
    else:
      return inputs + self.w * 0.5
custom_layer = CustomLayer()
print(custom_layer([1]).numpy())
print(custom_layer([1], training=True).numpy())
[1.5]
[2.]
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

# Build the model including the custom layer
model = tf.keras.Sequential([
    CustomLayer(input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
])

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

Alcune cose da notare:

  • I modelli e i livelli di Keras sottoclasse devono essere eseguiti in entrambi i grafici v1 (senza dipendenze di controllo automatico) e in modalità eager:

    • Avvolgi la call in una funzione tf.function per ottenere autografi e dipendenze di controllo automatico.
  • Non dimenticare di accettare un argomento di training per call :

    • A volte è un tf.Tensor
    • A volte è un booleano Python
  • Crea le variabili del modello nel costruttore o Model.build usando `self.add_weight:

    • In Model.build hai accesso alla forma di input, quindi puoi creare pesi con forma corrispondente
    • L'utilizzo di tf.keras.layers.Layer.add_weight consente a Keras di tenere traccia delle variabili e delle perdite di regolarizzazione
  • Non tenere i tf.Tensors nei tuoi oggetti:

    • Potrebbero essere creati in una funzione tf.function o nel contesto desideroso, e questi tensori si comportano in modo diverso
    • Usa tf.Variable s per lo stato, sono sempre utilizzabili da entrambi i contesti
    • tf.Tensors I tf.Tensors sono solo per valori intermedi

Una nota su Slim e contrib.layers

Una grande quantità di codice precedente di TensorFlow 1.x utilizza la libreria Slim , che è stata confezionata con TensorFlow 1.x come tf.contrib.layers . Come modulo contrib , questo non è più disponibile in TensorFlow 2.x, anche in tf.compat.v1 . La conversione del codice utilizzando Slim in TensorFlow 2.x è più complessa della conversione di archivi che utilizzano v1.layers . In effetti, potrebbe avere senso convertire prima il tuo codice Slim in v1.layers , quindi convertirlo in Keras.

  • Rimuovi arg_scopes , tutti gli argomenti devono essere espliciti.
  • Se li usi, dividi normalizer_fn e activation_fn nei rispettivi livelli.
  • I layer conv separabili mappati a uno o più layer Keras diversi (layer Keras in profondità, in punti e separabili).
  • Slim e v1.layers hanno nomi di argomenti e valori predefiniti diversi.
  • Alcuni argomenti hanno scale diverse.
  • Se utilizzi modelli Slim pre-addestrati, prova i modelli pre-traimed di Keras da tf.keras.applications o TensorFlow 2.x SavedModels di TF Hub esportati dal codice Slim originale.

Alcuni layer tf.contrib potrebbero non essere stati spostati nel core TensorFlow, ma sono stati invece spostati nel pacchetto TensorFlow Addons .

Formazione

Esistono molti modi per fornire dati a un modello tf.keras . Accetteranno generatori Python e array Numpy come input.

Il modo consigliato per fornire dati a un modello è utilizzare il pacchetto tf.data , che contiene una raccolta di classi ad alte prestazioni per la manipolazione dei dati.

Se stai ancora usando tf.queue , questi sono ora supportati solo come strutture di dati, non come pipeline di input.

Utilizzo di set di dati TensorFlow

Il pacchetto TensorFlow Datasets ( tfds ) contiene utilità per il caricamento di set di dati predefiniti come oggettitf.data.Dataset .

Per questo esempio, puoi caricare il set di dati tfds utilizzando tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Quindi prepara i dati per l'addestramento:

  • Ridimensiona ogni immagine.
  • Mescola l'ordine degli esempi.
  • Raccogli lotti di immagini ed etichette.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Per mantenere l'esempio breve, ritaglia il set di dati per restituire solo 5 batch:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))

Usa i cicli di addestramento di Keras

Se non è necessario un controllo di basso livello del processo di allenamento, si consiglia di utilizzare i metodi di fit , evaluate e predict integrati di Keras. Questi metodi forniscono un'interfaccia uniforme per addestrare il modello indipendentemente dall'implementazione (sequenziale, funzionale o sottoclasse).

I vantaggi di questi metodi includono:

  • Accettano array Numpy, generatori Python e tf.data.Datasets .
  • Applicano automaticamente la regolarizzazione e le perdite di attivazione.
  • Supportano tf.distribute per l'addestramento multi-dispositivo .
  • Supportano chiamate arbitrarie come perdite e metriche.
  • Supportano callback come tf.keras.callbacks.TensorBoard e callback personalizzati.
  • Sono performanti, utilizzando automaticamente i grafici TensorFlow.

Di seguito è riportato un esempio di addestramento di un modello utilizzando un Dataset . (Per i dettagli su come funziona, controlla la sezione dei tutorial .)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 1s 9ms/step - loss: 2.0191 - accuracy: 0.3608
Epoch 2/5
5/5 [==============================] - 0s 9ms/step - loss: 0.4736 - accuracy: 0.9059
Epoch 3/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2973 - accuracy: 0.9626
Epoch 4/5
5/5 [==============================] - 0s 9ms/step - loss: 0.2108 - accuracy: 0.9911
Epoch 5/5
5/5 [==============================] - 0s 8ms/step - loss: 0.1791 - accuracy: 0.9874
5/5 [==============================] - 0s 6ms/step - loss: 1.5504 - accuracy: 0.7500
Loss 1.5504140853881836, Accuracy 0.75

Scrivi il tuo loop

Se la fase di addestramento del modello Keras funziona per te, ma hai bisogno di più controllo al di fuori di quella fase, considera l'utilizzo del metodo tf.keras.Model.train_on_batch , nel tuo ciclo di iterazione dei dati.

Ricorda: molte cose possono essere implementate come tf.keras.callbacks.Callback .

Questo metodo presenta molti dei vantaggi dei metodi menzionati nella sezione precedente, ma fornisce all'utente il controllo del ciclo esterno.

Puoi anche utilizzare tf.keras.Model.test_on_batch o tf.keras.Model.evaluate per controllare le prestazioni durante l'allenamento.

Per continuare ad addestrare il modello precedente:

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

for epoch in range(NUM_EPOCHS):
  # Reset the metric accumulators
  model.reset_metrics()

  for image_batch, label_batch in train_data:
    result = model.train_on_batch(image_batch, label_batch)
    metrics_names = model.metrics_names
    print("train: ",
          "{}: {:.3f}".format(metrics_names[0], result[0]),
          "{}: {:.3f}".format(metrics_names[1], result[1]))
  for image_batch, label_batch in test_data:
    result = model.test_on_batch(image_batch, label_batch,
                                 # Return accumulated metrics
                                 reset_metrics=False)
  metrics_names = model.metrics_names
  print("\neval: ",
        "{}: {:.3f}".format(metrics_names[0], result[0]),
        "{}: {:.3f}".format(metrics_names[1], result[1]))
train:  loss: 0.138 accuracy: 1.000
train:  loss: 0.161 accuracy: 1.000
train:  loss: 0.159 accuracy: 0.969
train:  loss: 0.241 accuracy: 0.953
train:  loss: 0.172 accuracy: 0.969

eval:  loss: 1.550 accuracy: 0.800
train:  loss: 0.086 accuracy: 1.000
train:  loss: 0.094 accuracy: 1.000
train:  loss: 0.090 accuracy: 1.000
train:  loss: 0.119 accuracy: 0.984
train:  loss: 0.099 accuracy: 1.000

eval:  loss: 1.558 accuracy: 0.841
train:  loss: 0.076 accuracy: 1.000
train:  loss: 0.068 accuracy: 1.000
train:  loss: 0.061 accuracy: 1.000
train:  loss: 0.076 accuracy: 1.000
train:  loss: 0.076 accuracy: 1.000

eval:  loss: 1.536 accuracy: 0.841
train:  loss: 0.059 accuracy: 1.000
train:  loss: 0.056 accuracy: 1.000
train:  loss: 0.058 accuracy: 1.000
train:  loss: 0.054 accuracy: 1.000
train:  loss: 0.055 accuracy: 1.000

eval:  loss: 1.497 accuracy: 0.863
train:  loss: 0.053 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.044 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.045 accuracy: 1.000

eval:  loss: 1.463 accuracy: 0.878

Personalizza la fase di formazione

Se hai bisogno di maggiore flessibilità e controllo, puoi ottenerlo implementando il tuo ciclo di allenamento. Ci sono tre passaggi:

  1. Itera su un generatore Python otf.data.Dataset per ottenere batch di esempi.
  2. Usatf.GradientTape per raccogliere i gradienti.
  3. Utilizza uno dei tf.keras.optimizers per applicare gli aggiornamenti del peso alle variabili del modello.

Ricorda:

  • Includere sempre un argomento di training sul metodo di call di livelli e modelli sottoclasse.
  • Assicurati di chiamare il modello con l'argomento di training impostato correttamente.
  • A seconda dell'utilizzo, le variabili del modello potrebbero non esistere fino a quando il modello non viene eseguito su un batch di dati.
  • È necessario gestire manualmente cose come le perdite di regolarizzazione per il modello.

Notare le semplificazioni relative alla v1:

  • Non è necessario eseguire inizializzatori di variabili. Le variabili vengono inizializzate al momento della creazione.
  • Non è necessario aggiungere dipendenze di controllo manuale. Anche nella funzione tf.function operazioni agiscono come in modalità eager.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4

Metriche e perdite di nuovo stile

In TensorFlow 2.x, le metriche e le perdite sono oggetti. Questi funzionano sia con entusiasmo che in tf.function . tf.function .

Un oggetto di perdita è richiamabile e si aspetta (y_true, y_pred) come argomenti:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Un oggetto metrica ha i seguenti metodi:

L'oggetto stesso è richiamabile. La chiamata aggiorna lo stato con nuove osservazioni, come con update_state , e restituisce il nuovo risultato della metrica.

Non è necessario inizializzare manualmente le variabili di una metrica e, poiché TensorFlow 2.x ha dipendenze di controllo automatico, non è necessario preoccuparsi nemmeno di quelle.

Il codice seguente utilizza una metrica per tenere traccia della perdita media osservata all'interno di un ciclo di addestramento personalizzato.

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.139
  accuracy: 0.997
Epoch:  1
  loss:     0.116
  accuracy: 1.000
Epoch:  2
  loss:     0.105
  accuracy: 0.997
Epoch:  3
  loss:     0.089
  accuracy: 1.000
Epoch:  4
  loss:     0.078
  accuracy: 1.000

Nomi delle metriche Keras

In TensorFlow 2.x, i modelli Keras sono più coerenti nella gestione dei nomi delle metriche.

Ora, quando si passa una stringa nell'elenco delle metriche, quella stringa esatta viene utilizzata come name della metrica. Questi nomi sono visibili nell'oggetto della cronologia restituito da model.fit e nei log passati a keras.callbacks . è impostato sulla stringa che hai passato nell'elenco delle metriche.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 8ms/step - loss: 0.0901 - acc: 0.9923 - accuracy: 0.9923 - my_accuracy: 0.9923
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Ciò differisce dalle versioni precedenti in cui il passaggio di metrics=["accuracy"] comporterebbe dict_keys(['loss', 'acc'])

Ottimizzatori Keras

Gli ottimizzatori in v1.train , come v1.train.AdamOptimizer e v1.train.GradientDescentOptimizer , hanno equivalenti in tf.keras.optimizers .

Converti v1.train in keras.optimizers

Ecco alcune cose da tenere a mente quando converti i tuoi ottimizzatori:

Nuove impostazioni predefinite per alcuni tf.keras.optimizers

Non ci sono modifiche per optimizers.SGD , optimizers.Adam o optimizers.RMSprop .

I seguenti tassi di apprendimento predefiniti sono cambiati:

TensorBoard

TensorFlow 2.x include modifiche significative all'API tf.summary utilizzata per scrivere dati di riepilogo per la visualizzazione in TensorBoard. Per un'introduzione generale al nuovo tf.summary , sono disponibili diversi tutorial che utilizzano l'API TensorFlow 2.x. Ciò include una guida alla migrazione di TensorBoard TensorFlow 2.x.

Salvataggio e caricamento

Compatibilità checkpoint

TensorFlow 2.x utilizza checkpoint basati su oggetti .

Se stai attento, i checkpoint basati sul nome vecchio stile possono ancora essere caricati. Il processo di conversione del codice può comportare modifiche al nome della variabile, ma esistono soluzioni alternative.

L'approccio più semplice è allineare i nomi del nuovo modello con i nomi nel checkpoint:

  • Tutte le variabili hanno ancora un argomento name che puoi impostare.
  • Modelli Keras anche prendere un name argomento come che si misero come prefisso per le loro variabili.
  • La funzione v1.name_scope può essere utilizzata per impostare i prefissi dei nomi delle variabili. Questo è molto diverso da tf.variable_scope . Ha effetto solo sui nomi e non tiene traccia delle variabili e del riutilizzo.

Se questo non funziona per il tuo caso d'uso, prova la funzione v1.train.init_from_checkpoint . Accetta un argomento assignment_map , che specifica la mappatura dai vecchi nomi ai nuovi nomi.

Il repository di TensorFlow Estimator include uno strumento di conversione per aggiornare i punti di controllo per gli stimatori premade da TensorFlow 1.xa 2.0. Può servire come esempio di come costruire uno strumento per un caso d'uso simile.

Compatibilità dei modelli salvati

Non ci sono problemi di compatibilità significativi per i modelli salvati.

  • TensorFlow 1.x saved_models funziona in TensorFlow 2.x.
  • TensorFlow 2.x saved_models funziona in TensorFlow 1.x se tutte le operazioni sono supportate.

Un Graph.pb o Graph.pbtxt

Non esiste un modo semplice per aggiornare un file Graph.pb a TensorFlow 2.x. La soluzione migliore è aggiornare il codice che ha generato il file.

Ma, se hai un "grafico congelato" (un tf.Graph cui le variabili sono state trasformate in costanti), allora è possibile convertirlo in una concrete_function usando v1.wrap_function :

def wrap_frozen_graph(graph_def, inputs, outputs):
  def _imports_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")
  wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      tf.nest.map_structure(import_graph.as_graph_element, inputs),
      tf.nest.map_structure(import_graph.as_graph_element, outputs))

Ad esempio, ecco un grafico congelato per Inception v1, del 2016:

path = tf.keras.utils.get_file(
    'inception_v1_2016_08_28_frozen.pb',
    'http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
    untar=True)
Downloading data from http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz
24698880/24695710 [==============================] - 1s 0us/step

Carica tf.GraphDef :

graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(open(path,'rb').read())

Avvolgilo in una concrete_function :

inception_func = wrap_frozen_graph(
    graph_def, inputs='input:0',
    outputs='InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu:0')

Passa un tensore come input:

input_img = tf.ones([1,224,224,3], dtype=tf.float32)
inception_func(input_img).shape
TensorShape([1, 28, 28, 96])

Stimatori

Formazione con stimatori

Gli stimatori sono supportati in TensorFlow 2.x.

Quando si utilizzano gli stimatori, è possibile utilizzare input_fn , tf.estimator.TrainSpec e tf.estimator.EvalSpec da TensorFlow 1.x.

Ecco un esempio che utilizza input_fn con le specifiche train e valuta.

Creazione delle specifiche input_fn e train / eval

# Define the estimator's input_fn
def input_fn():
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000
  BATCH_SIZE = 64

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label[..., tf.newaxis]

  train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  return train_data.repeat()

# Define train and eval specs
train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                    max_steps=STEPS_PER_EPOCH * NUM_EPOCHS)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                  steps=STEPS_PER_EPOCH)

Utilizzando una definizione del modello Keras

Esistono alcune differenze nel modo in cui costruire gli stimatori in TensorFlow 2.x.

Si consiglia di definire il modello utilizzando Keras, quindi utilizzare l'utilità tf.keras.estimator.model_to_estimator per trasformare il modello in uno stimatore. Il codice seguente mostra come utilizzare questa utilità durante la creazione e l'addestramento di un estimatore.

def make_model():
  return tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
  ])
model = make_model()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(
  keras_model = model
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp0erq3im2
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp0erq3im2
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp0erq3im2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp0erq3im2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp0erq3im2/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp0erq3im2/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp0erq3im2/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmp0erq3im2/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.4717796, step = 0
INFO:tensorflow:loss = 2.4717796, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:17Z
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:17Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.86556s
INFO:tensorflow:Inference Time : 0.86556s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:18
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:18
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.6, global_step = 25, loss = 1.6160676
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.6, global_step = 25, loss = 1.6160676
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.37597787.
INFO:tensorflow:Loss for final step: 0.37597787.
({'accuracy': 0.6, 'loss': 1.6160676, 'global_step': 25}, [])

Utilizzando un model_fn personalizzato

Se hai uno stimatore personalizzato esistente model_fn che devi mantenere, puoi convertire il tuo model_fn per utilizzare un modello Keras.

Tuttavia, per motivi di compatibilità, un model_fn personalizzato verrà comunque eseguito in modalità grafico in stile 1.x. Ciò significa che non c'è un'esecuzione impaziente e nessuna dipendenza dal controllo automatico.

Model_fn personalizzato con modifiche minime

Per far funzionare il tuo model_fn personalizzato in TensorFlow 2.x, se preferisci modifiche minime al codice esistente, tf.compat.v1 simboli tf.compat.v1 come optimizers e metrics .

L'utilizzo di un modello Keras in un model_fn personalizzato è simile all'utilizzo in un ciclo di addestramento personalizzato:

  • Impostare la fase di training appropriato, in base all'argomento mode .
  • Passa esplicitamente le trainable_variables del modello all'ottimizzatore.

Ma ci sono differenze importanti, relative a un ciclo personalizzato :

  • Invece di utilizzare Model.losses , estrai le perdite utilizzando Model.get_losses_for .
  • Estrai gli aggiornamenti del modello utilizzando Model.get_updates_for .

Il codice seguente crea uno stimatore da un model_fn personalizzato, illustrando tutte queste preoccupazioni.

def my_model_fn(features, labels, mode):
  model = make_model()

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  predictions = model(features, training=training)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss=loss_fn(labels, predictions) + tf.math.add_n(reg_losses)

  accuracy = tf.compat.v1.metrics.accuracy(labels=labels,
                                           predictions=tf.math.argmax(predictions, axis=1),
                                           name='acc_op')

  update_ops = model.get_updates_for(None) + model.get_updates_for(features)
  minimize_op = optimizer.minimize(
      total_loss,
      var_list=model.trainable_variables,
      global_step=tf.compat.v1.train.get_or_create_global_step())
  train_op = tf.group(minimize_op, update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op, eval_metric_ops={'accuracy': accuracy})

# Create the Estimator & Train
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpifj8mysl
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpifj8mysl
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpifj8mysl', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpifj8mysl', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 3.0136237, step = 0
INFO:tensorflow:loss = 3.0136237, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:20Z
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:20Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.97406s
INFO:tensorflow:Inference Time : 0.97406s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:21
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:21
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.59375, global_step = 25, loss = 1.6248872
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.59375, global_step = 25, loss = 1.6248872
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.35726172.
INFO:tensorflow:Loss for final step: 0.35726172.
({'accuracy': 0.59375, 'loss': 1.6248872, 'global_step': 25}, [])

model_fn personalizzato con simboli TensorFlow 2.x.

Se desideri eliminare tutti i simboli di TensorFlow 1.x e aggiornare il tuo model_fn personalizzato a TensorFlow 2.x, devi aggiornare l'ottimizzatore e le metriche a tf.keras.optimizers e tf.keras.metrics .

Nella custom model_fn , oltre alle modifiche di cui sopra, è necessario effettuare altri aggiornamenti:

Nell'esempio precedente di my_model_fn , il codice migrato con i simboli TensorFlow 2.x viene mostrato come:

def my_model_fn(features, labels, mode):
  model = make_model()

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  predictions = model(features, training=training)

  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss=loss_obj(labels, predictions) + tf.math.add_n(reg_losses)

  # Upgrade to tf.keras.metrics.
  accuracy_obj = tf.keras.metrics.Accuracy(name='acc_obj')
  accuracy = accuracy_obj.update_state(
      y_true=labels, y_pred=tf.math.argmax(predictions, axis=1))

  train_op = None
  if training:
    # Upgrade to tf.keras.optimizers.
    optimizer = tf.keras.optimizers.Adam()
    # Manually assign tf.compat.v1.global_step variable to optimizer.iterations
    # to make tf.compat.v1.train.global_step increased correctly.
    # This assignment is a must for any `tf.train.SessionRunHook` specified in
    # estimator, as SessionRunHooks rely on global step.
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    update_ops = model.get_updates_for(None) + model.get_updates_for(features)
    # Compute the minimize_op.
    minimize_op = optimizer.get_updates(
        total_loss,
        model.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op,
    eval_metric_ops={'Accuracy': accuracy_obj})

# Create the Estimator and train.
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc93qfnv6
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc93qfnv6
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc93qfnv6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc93qfnv6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.5293791, step = 0
INFO:tensorflow:loss = 2.5293791, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:24Z
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:24Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.86534s
INFO:tensorflow:Inference Time : 0.86534s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:25
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:25
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.59375, global_step = 25, loss = 1.7570661
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.59375, global_step = 25, loss = 1.7570661
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.47094986.
INFO:tensorflow:Loss for final step: 0.47094986.
({'Accuracy': 0.59375, 'loss': 1.7570661, 'global_step': 25}, [])

Stimatori premade

Gli stimatori premade della famiglia tf.estimator.DNN* , tf.estimator.Linear* e tf.estimator.DNNLinearCombined* sono ancora supportati nell'API TensorFlow 2.x. Tuttavia, alcuni argomenti sono cambiati:

  1. input_layer_partitioner : rimosso in v2.
  2. loss_reduction : aggiornato a tf.keras.losses.Reduction invece che a tf.compat.v1.losses.Reduction . Anche il suo valore predefinito viene modificato in tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE da tf.compat.v1.losses.Reduction.SUM .
  3. optimizer , dnn_optimizer e linear_optimizer : questo argomento è stato aggiornato a tf.keras.optimizers invece che a tf.compat.v1.train.Optimizer .

Per migrare le modifiche di cui sopra:

  1. Non è necessaria alcuna migrazione per input_layer_partitioner poiché Distribution Strategy lo gestirà automaticamente in TensorFlow 2.x.
  2. Per loss_reduction , controlla tf.keras.losses.Reduction per le opzioni supportate.
  3. Per gli argomenti optimizer :
    • Se non: 1) passi optimizer , l'argomento dnn_optimizer o linear_optimizer , o 2) specifichi l'argomento optimizer come una string nel tuo codice, allora non devi cambiare nulla perché tf.keras.optimizers è usato per impostazione predefinita .
    • Altrimenti, è necessario aggiornarlo da tf.compat.v1.train.Optimizer al suo tf.keras.optimizers corrispondente.

Checkpoint Converter

La migrazione a keras.optimizers interromperà i checkpoint salvati utilizzando TensorFlow 1.x, poiché tf.keras.optimizers genera un diverso insieme di variabili da salvare nei checkpoint. Per rendere riutilizzabile il vecchio checkpoint dopo la migrazione a TensorFlow 2.x, prova lo strumento di conversione del checkpoint .

 curl -O https://raw.githubusercontent.com/tensorflow/estimator/master/tensorflow_estimator/python/estimator/tools/checkpoint_converter.py
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 15165  100 15165    0     0  40656      0 --:--:-- --:--:-- --:--:-- 40656

Lo strumento ha una guida integrata:

 python checkpoint_converter.py -h
2021-01-06 02:31:26.297951: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
usage: checkpoint_converter.py [-h]
                               {dnn,linear,combined} source_checkpoint
                               source_graph target_checkpoint

positional arguments:
  {dnn,linear,combined}
                        The type of estimator to be converted. So far, the
                        checkpoint converter only supports Canned Estimator.
                        So the allowed types include linear, dnn and combined.
  source_checkpoint     Path to source checkpoint file to be read in.
  source_graph          Path to source graph file to be read in.
  target_checkpoint     Path to checkpoint file to be written out.

optional arguments:
  -h, --help            show this help message and exit

TensorShape

Questa classe è stata semplificata per contenere oggetti int s, invece di tf.compat.v1.Dimension . Quindi non è necessario chiamare .value per ottenere un int .

I tf.compat.v1.Dimension oggetti tf.compat.v1.Dimension sono ancora accessibili da tf.TensorShape.dims .

Di seguito vengono illustrate le differenze tra TensorFlow 1.x e TensorFlow 2.x.

# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape
TensorShape([16, None, 256])

Se avevi questo in TensorFlow 1.x:

value = shape[i].value

Quindi fallo in TensorFlow 2.x:

value = shape[i]
value
16

Se avevi questo in TensorFlow 1.x:

for dim in shape:
    value = dim.value
    print(value)

Quindi fallo in TensorFlow 2.x:

for value in shape:
  print(value)
16
None
256

Se lo avevi in ​​TensorFlow 1.x (o hai utilizzato qualsiasi altro metodo di dimensione):

dim = shape[i]
dim.assert_is_compatible_with(other_dim)

Quindi fallo in TensorFlow 2.x:

other_dim = 16
Dimension = tf.compat.v1.Dimension

if shape.rank is None:
  dim = Dimension(None)
else:
  dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
True
shape = tf.TensorShape(None)

if shape:
  dim = shape.dims[i]
  dim.is_compatible_with(other_dim) # or any other dimension method

Il valore booleano di un tf.TensorShape è True se il rango è noto, False contrario.

print(bool(tf.TensorShape([])))      # Scalar
print(bool(tf.TensorShape([0])))     # 0-length vector
print(bool(tf.TensorShape([1])))     # 1-length vector
print(bool(tf.TensorShape([None])))  # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100])))       # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None)))  # A tensor with unknown rank.
True
True
True
True
True
True

False

Altre modifiche

  • Rimuovi tf.colocate_with : gli algoritmi di posizionamento dei dispositivi di TensorFlow sono migliorati in modo significativo. Questo non dovrebbe più essere necessario. Se rimuoverlo causa una riduzione delle prestazioni, segnala un bug .

  • Sostituisci v1.ConfigProto utilizzo di v1.ConfigProto con le funzioni equivalenti da tf.config .

Conclusioni

Il processo complessivo è:

  1. Esegui lo script di aggiornamento.
  2. Rimuovi i simboli contrib.
  3. Cambia i tuoi modelli in uno stile orientato agli oggetti (Keras).
  4. Usa i tf.keras formazione e valutazione tf.keras o tf.estimator dove puoi.
  5. Altrimenti, utilizza loop personalizzati, ma assicurati di evitare sessioni e raccolte.

Ci vuole un po 'di lavoro per convertire il codice in TensorFlow 2.x idiomatico, ma ogni modifica si traduce in:

  • Meno righe di codice.
  • Maggiore chiarezza e semplicità.
  • Debug più semplice.