ML Community Day è il 9 novembre! Unisciti a noi per gli aggiornamenti da tensorflow, JAX, e più Per saperne di più

Introduzione ai grafici e tf.function

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

Panoramica

Questa guida va sotto la superficie di TensorFlow e Keras per dimostrare come funziona TensorFlow. Se si desidera invece immediatamente iniziare con Keras, controlla la raccolta di guide Keras .

In questa guida imparerai come TensorFlow ti consente di apportare semplici modifiche al tuo codice per ottenere grafici, come i grafici sono archiviati e rappresentati e come puoi usarli per accelerare i tuoi modelli.

Questa è una panoramica grande quadro che spiega come tf.function consente di passare dall'esecuzione desiderosi di esecuzione grafico. Per una specificazione più completa di tf.function , andare al tf.function guida .

Cosa sono i grafici?

Nei precedenti tre guide, è stato eseguito tensorflow avidamente. Ciò significa che le operazioni TensorFlow vengono eseguite da Python, operazione per operazione, e restituiscono i risultati a Python.

Sebbene l'esecuzione desiderosa abbia diversi vantaggi unici, l'esecuzione del grafico consente la portabilità al di fuori di Python e tende a offrire prestazioni migliori. Grafico mezzi di esecuzione che calcoli tensoriali vengono eseguiti come grafico tensorflow, talvolta indicato come un tf.Graph o semplicemente un "grafico".

I grafici sono strutture di dati che contengono una serie di tf.Operation oggetti, che rappresentano le unità di calcolo; e tf.Tensor oggetti, che rappresentano le unità di dati che scorrono tra le operazioni. Essi sono definiti in un tf.Graph contesto. Poiché questi grafici sono strutture di dati, possono essere salvati, eseguiti e ripristinati senza il codice Python originale.

Questo è l'aspetto di un grafico TensorFlow che rappresenta una rete neurale a due livelli quando viene visualizzato in TensorBoard.

Un semplice grafico TensorFlow

I vantaggi dei grafici

Con un grafico, hai una grande flessibilità. Puoi usare il tuo grafico TensorFlow in ambienti che non dispongono di un interprete Python, come applicazioni mobili, dispositivi incorporati e server di backend. Tensorflow utilizza grafici come formato per i modelli salvati quando li esporta da Python.

I grafici sono anche facilmente ottimizzati, consentendo al compilatore di eseguire trasformazioni come:

  • Staticamente dedurre il valore dei tensori piegando nodi costanti nel calcolo ( "folding costante").
  • Separare le sottoparti di un calcolo indipendenti e suddividerle tra thread o dispositivi.
  • Semplifica le operazioni aritmetiche eliminando le sottoespressioni comuni.

C'è un intero sistema di ottimizzazione, Grappler , per eseguire questo e altri incrementi nella velocità.

In breve, i grafici sono estremamente utili e lasciate che il vostro tensorflow correre veloce, in parallelo, ed eseguire in modo efficiente su più dispositivi.

Tuttavia, vuoi comunque definire i tuoi modelli di apprendimento automatico (o altri calcoli) in Python per comodità e quindi costruire automaticamente i grafici quando ne hai bisogno.

Impostare

import tensorflow as tf
import timeit
from datetime import datetime

Sfruttando i grafici

È possibile creare ed eseguire un grafico in tensorflow utilizzando tf.function , sia come una chiamata diretta o come un decoratore. tf.function prende una funzione regolare come input e restituisce un Function . Una Function è richiamabile Python che costruisce grafici tensorflow dalla funzione Python. Si utilizza un Function nello stesso modo come il suo equivalente Python.

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

Sulla parte esterna, una Function appare come una normale funzione si scrive utilizzando le operazioni tensorflow. Sotto , però, è molto diversa. Una Function incapsula diversi tf.Graph s dietro un'API . È così che Function è in grado di darvi i vantaggi di esecuzione grafico , come la velocità e la schierabilità.

tf.function applica a una funzione e tutte le altre funzioni chiamate:

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)

Se avete usato tensorflow 1.x, si noterà che in nessun momento è necessario definire un Placeholder o tf.Session .

Conversione di funzioni Python in grafici

Ogni funzione si scrive con tensorflow conterrà una miscela di built-in operazioni di TF e la logica Python, come if-then clausole, loop, break , return , continue , e altro ancora. Mentre le operazioni tensorflow sono facilmente catturati da un tf.Graph , Python-specifiche esigenze logici a subire un ulteriore passaggio per diventare parte del grafico. tf.function utilizza una libreria chiamata AutoGraph ( tf.autograph ) per convertire codice Python in codice grafico generatrici.

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
First branch, with graph: 1
Second branch, with graph: 0

Sebbene sia improbabile che sia necessario visualizzare direttamente i grafici, è possibile ispezionare gli output per verificare i risultati esatti. Questi non sono facili da leggere, quindi non c'è bisogno di guardare troppo attentamente!

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            (do_return, retval_) = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const"
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_1"
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_2"
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_3"
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity"
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_4"
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity_1"
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const"
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity"
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity_1"
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 808
  min_consumer: 12
}

La maggior parte del tempo, tf.function funzionerà senza considerazioni speciali. Tuttavia, ci sono alcuni avvertimenti, e la guida tf.function può aiutare qui, così come il riferimento AutoGraph completa

Polimorfismo: una Function , molti grafici

Un tf.Graph è specializzato per un tipo specifico di ingressi (ad esempio, tensori con una determinata dtype o oggetti con lo stesso id() ).

Ogni volta che si richiama una Function con i nuovi dtypes e forme nei suoi argomenti, Function crea un nuovo tf.Graph per i nuovi argomenti. I dtypes e le forme di una tf.Graph ingressi 's sono noti come una firma ingresso o solo una firma.

Le Function memorizza il tf.Graph corrispondente a quella firma in un ConcreteFunction . Un ConcreteFunction è un wrapper per un tf.Graph .

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

Se la Function è già stato chiamato con quella firma, Function non crea un nuovo tf.Graph .

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

Perché è sostenuta da più grafici, una Function è polimorfica. Che consente di supportare più tipi di input di un singolo tf.Graph potrebbe rappresentare, nonché per ottimizzare ogni tf.Graph per migliorare le prestazioni.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)

utilizzando tf.function

Finora, avete imparato come convertire una funzione Python in un grafico semplicemente utilizzando tf.function come un decoratore o wrapper. Ma in pratica, ottenendo tf.function per funzionare correttamente può essere difficile! Nelle sezioni seguenti, imparerete come si può fare il vostro lavoro di codice come previsto con tf.function .

Esecuzione del grafico vs. esecuzione ansiosa

Il codice in una Function può essere eseguita sia avidamente e come grafico. Per impostazione predefinita, Function esegue il suo codice sotto forma di grafico:

@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
tf.Tensor([6 4 3 1 5], shape=(5,), dtype=int32)
tf.Tensor([6 7 9 9 1], shape=(5,), dtype=int32)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=25>

Per verificare che la Function grafico s' sta facendo lo stesso calcolo come la sua funzione Python equivalente, si può fare eseguire con entusiasmo con tf.config.run_functions_eagerly(True) . Questo è un interruttore che spegne Function capacità s' per creare ed eseguire grafici, invece l'esecuzione del codice normalmente.

tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=25>
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

Tuttavia, Function può comportarsi diversamente, nell'ambito grafico ed esecuzione ansioso. Il pitone print funzione è un esempio di come questi due modi diversi. Assegno di far uscire quello che succede quando si inserisce una print dichiarazione per la funzione e chiamare ripetutamente.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

Osserva cosa viene stampato:

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!

L'output è sorprendente? get_MSE stampato solo una volta, anche se è stato chiamato tre volte.

Per spiegare, la print istruzione viene eseguita quando Function viene eseguito il codice originale al fine di creare il grafico in un processo noto come "tracciabilità" . Tracciare acquisisce le operazioni tensorflow in un grafico e print non viene catturata nel grafico. Questo grafico viene quindi eseguito per tutte e tre le chiamate senza mai eseguire nuovamente il codice Python.

Come controllo di integrità, disattiviamo l'esecuzione del grafico per confrontare:

# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.config.run_functions_eagerly(False)

print è un effetto collaterale di Python, e ci sono altre differenze che si dovrebbe essere a conoscenza di quando la conversione di una funzione in una Function .

Esecuzione non rigorosa

L'esecuzione del grafico esegue solo le operazioni necessarie per produrre gli effetti osservabili, che includono:

  • Il valore di ritorno della funzione
  • Effetti collaterali noti e documentati come:

Questo comportamento è generalmente noto come "esecuzione non rigorosa" e differisce dall'esecuzione ansiosa, che passa attraverso tutte le operazioni del programma, necessarie o meno.

In particolare, il controllo degli errori di runtime non conta come un effetto osservabile. Se un'operazione viene ignorata perché non necessaria, non può generare errori di runtime.

Nel seguente esempio, l'operazione "inutile" tf.gather è ignorato durante l'esecuzione grafico, quindi l'errore runtime InvalidArgumentError non viene generato come sarebbe in esecuzione ansiosi. Non fare affidamento su un errore generato durante l'esecuzione di un grafico.

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
tf.Tensor([0.], shape=(1,), dtype=float32)
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph exection. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
tf.Tensor([0.], shape=(1,), dtype=float32)

tf.function migliori pratiche

Si può richiedere un certo tempo per abituarsi al comportamento della Function . Per iniziare rapidamente, prima volta gli utenti devono giocare con la decorazione funzioni giocattolo con @tf.function per fare esperienza con l'andare da desiderosi di esecuzione grafico.

Progettare per tf.function potrebbe essere la soluzione migliore per la scrittura di programmi tensorflow grafico-compatibili. Ecco alcuni suggerimenti:

  • Commutazione tra esecuzione ansioso e grafico presto e spesso con tf.config.run_functions_eagerly per individuare se / quando le due modalità divergono.
  • Creare tf.Variable s fuori della funzione Python e modificarli al suo interno. Lo stesso vale per gli oggetti che utilizzano tf.Variable , come keras.layers , keras.Model S e tf.optimizers .
  • Evitare di funzioni che la scrittura dipendono da variabili esterne Python , esclusi tf.Variable s ed oggetti Keras.
  • Preferisci scrivere funzioni che accettano tensori e altri tipi di TensorFlow come input. È possibile passare in altri tipi di oggetti, ma state attenti !
  • Includere il maggior numero possibile di calcolo sotto un tf.function per massimizzare il guadagno di prestazioni. Ad esempio, decora un'intera fase di allenamento o l'intero ciclo di allenamento.

Vedendo l'accelerazione

tf.function solito migliora le prestazioni del vostro codice, ma la quantità di speed-up dipende dal tipo di calcolo si esegue. I piccoli calcoli possono essere dominati dal sovraccarico di chiamare un grafico. Puoi misurare la differenza di prestazioni in questo modo:

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
Eager execution: 2.286959676999686
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
Graph execution: 0.6817171659999985

tf.function è comunemente utilizzato per accelerare i cicli di formazione, e si può imparare di più su di esso in scrittura un ciclo di formazione da zero con Keras.

Prestazioni e compromessi

I grafici possono velocizzare il codice, ma il processo di creazione ha un sovraccarico. Per alcune funzioni, la creazione del grafico richiede più tempo dell'esecuzione del grafico. Questo investimento viene in genere rapidamente ripagato con l'aumento delle prestazioni delle esecuzioni successive, ma è importante essere consapevoli che i primi passaggi di qualsiasi addestramento di modelli di grandi dimensioni possono essere più lenti a causa della traccia.

Non importa quanto sia grande il tuo modello, vuoi evitare di tracciare frequentemente. I tf.function guida mostra come le specifiche di ingresso fissi e argomenti uso tensore per evitare ripercorrendo. Se ti accorgi di ottenere prestazioni insolitamente scarse, è una buona idea controllare se stai ritracciando accidentalmente.

Quando una Function tracciando?

Per capire quando la vostra Function sta tracciando, aggiungere una print dichiarazione al suo codice. Come regola generale, Function eseguirà la print dichiarazione ogni volta che tracce.

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

I nuovi argomenti Python attivano sempre la creazione di un nuovo grafico, quindi la traccia aggiuntiva.

Prossimi passi

È possibile saperne di più su tf.function sulla pagina di riferimento API e seguendo la prestazione migliore con tf.function guida.