Migliori prestazioni con tf.function

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica il taccuino

In TensorFlow 2, l' esecuzione anticipata è attivata per impostazione predefinita. L'interfaccia utente è intuitiva e flessibile (l'esecuzione di operazioni una tantum è molto più semplice e veloce), ma ciò può andare a scapito delle prestazioni e della distribuzione.

Puoi usare tf.function per creare grafici dai tuoi programmi. È uno strumento di trasformazione che crea grafici di flussi di dati indipendenti da Python dal tuo codice Python. Questo ti aiuterà a creare modelli performanti e portatili ed è necessario usare SavedModel .

Questa guida ti aiuterà a concettualizzare come funziona tf.function sotto il cofano, in modo da poterlo utilizzare in modo efficace.

Le principali indicazioni e raccomandazioni sono:

  • Eseguire il debug in modalità desideroso, quindi decorare con @tf.function .
  • Non fare affidamento sugli effetti collaterali di Python come la mutazione degli oggetti o l'aggiunta di elenchi.
  • tf.function funziona meglio con le operazioni TensorFlow; Le chiamate NumPy e Python vengono convertite in costanti.

Impostare

import tensorflow as tf

Definisci una funzione di supporto per dimostrare i tipi di errori che potresti incontrare:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Nozioni di base

Utilizzo

Una Function che definisci (ad esempio applicando il decoratore @tf.function ) è proprio come un'operazione TensorFlow di base: puoi eseguirla con entusiasmo; puoi calcolare i gradienti; e così via.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

È possibile utilizzare le Function all'interno di altre Function .

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function possono essere più veloci del codice desideroso, specialmente per i grafici con molte piccole operazioni. Ma per i grafici con poche operazioni costose (come le convoluzioni), potresti non vedere molta accelerazione.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.23302616399996623
Function conv: 0.21780993200013654
Note how there's not much difference in performance for convolutions

Tracciamento

Questa sezione illustra il Function nascosto di Function , inclusi i dettagli di implementazione che potrebbero cambiare in futuro . Tuttavia, una volta compreso il motivo e il momento in cui viene tf.function traccia, è molto più semplice utilizzare tf.function modo efficace!

Che cos'è "tracciare"?

Una Function esegue il programma in un grafico TensorFlow . Tuttavia, un tf.Graph non può rappresentare tutte le cose che scriveresti in un appassionato programma TensorFlow. Ad esempio, Python supporta il polimorfismo, ma tf.Graph richiede che i suoi input abbiano un tipo di dati e una dimensione specificati. Oppure puoi eseguire attività secondarie come leggere argomenti della riga di comando, sollevare un errore o lavorare con un oggetto Python più complesso; nessuna di queste cose può essere eseguita in un tf.Graph .

Function colma questa lacuna separando il codice in due fasi:

1) Nella prima fase, denominata " tracing ", Function crea un nuovo tf.Graph . Il codice Python viene eseguito normalmente, ma tutte le operazioni di TensorFlow (come l'aggiunta di due Tensor) vengono posticipate : vengono catturate dal tf.Graph e non vengono eseguite.

2) Nella seconda fase viene tf.Graph un tf.Graph che contiene tutto ciò che è stato rinviato nella prima fase. Questa fase è molto più veloce della fase di tracciamento.

A seconda dei suoi input, Function non eseguirà sempre la prima fase quando viene chiamata. Vedere "Regole di tracciamento" di seguito per avere un'idea migliore di come determina tale determinazione. Saltare la prima fase ed eseguire solo la seconda è ciò che offre le elevate prestazioni di TensorFlow.

Quando la Function decide di tracciare, la fase di tracciamento è immediatamente seguita dalla seconda fase, quindi la chiamata alla Function crea ed esegue sia il tf.Graph . Più avanti vedrai come puoi eseguire solo la fase di tracciamento con get_concrete_function .

Quando passiamo argomenti di tipo diverso in una Function , vengono eseguite entrambe le fasi:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Nota che se chiami ripetutamente una Function con lo stesso tipo di argomento, TensorFlow salterà la fase di tracciamento e riutilizzerà un grafico tracciato in precedenza, poiché il grafico generato sarebbe identico.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Puoi usare pretty_printed_concrete_signatures() per vedere tutte le tracce disponibili:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Finora, hai visto che tf.function crea un livello di invio dinamico memorizzato nella cache sulla logica di tracciamento del grafico di TensorFlow. Per essere più precisi sulla terminologia:

  • Un tf.Graph è la rappresentazione grezza, indipendente dal linguaggio e portatile di un calcolo TensorFlow.
  • Una ConcreteFunction avvolge un tf.Graph .
  • Una Function gestisce una cache di ConcreteFunction sceglie quella giusta per i tuoi input.
  • tf.function avvolge una funzione Python, restituendo un oggetto Function .
  • Tracing crea un tf.Graph e lo avvolge in una ConcreteFunction , nota anche come trace.

Regole di tracciamento

Una Function determina se riutilizzare una ConcreteFunction tracciata calcolando una chiave di cache da args e kwargs di un input. Una chiave cache è una chiave che identifica una ConcreteFunction base agli args e kwargs di input della chiamata alla Function , secondo le seguenti regole (che possono cambiare):

  • La chiave generata per un tf.Tensor è la sua forma e il suo dtype.
  • La chiave generata per un tf.Variable è un ID variabile univoco.
  • La chiave generata per una primitiva Python (come int , float , str ) è il suo valore.
  • La chiave generata per dict s, list s, tuple , namedtuple s e attr s nidificati è la tupla appiattita delle chiavi foglia (vederenest.flatten ). (Come risultato di questo appiattimento, la chiamata a una funzione concreta con una struttura di annidamento diversa da quella utilizzata durante la traccia risulterà in un TypeError).
  • Per tutti gli altri tipi Python la chiave è univoca per l'oggetto. In questo modo una funzione o un metodo viene tracciato indipendentemente per ogni istanza con cui viene chiamato.

Controllo del ritracciamento

Il ritracciamento, ovvero quando la Function crea più di una traccia, aiuta a garantire che TensorFlow generi grafici corretti per ogni set di input. Tuttavia, il tracciamento è un'operazione costosa! Se la tua Function ripercorre un nuovo grafico per ogni chiamata, scoprirai che il tuo codice viene eseguito più lentamente che se non tf.function usato tf.function .

Per controllare il comportamento di traccia, è possibile utilizzare le seguenti tecniche:

  • Specifica input_signature in tf.function per limitare la traccia.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-14ebce7b7ee8>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-14ebce7b7ee8>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
  • Specificare una dimensione [Nessuno] in tf.TensorSpec per consentire flessibilità nel riutilizzo della traccia.

    Poiché TensorFlow corrisponde ai tensori in base alla loro forma, l'uso di una dimensione None come carattere jolly consentirà a Function s di riutilizzare le tracce per input di dimensioni variabili. L'input di dimensioni variabili può verificarsi se si dispone di sequenze di lunghezza diversa o immagini di dimensioni diverse per ciascun batch (vedere ad esempio i tutorial Transformer e Deep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • Trasmetti argomenti Python ai tensori per ridurre il ritracciamento.

    Spesso, gli argomenti Python vengono utilizzati per controllare iperparametri e costruzioni di grafici, ad esempio num_layers=10 o training=True o nonlinearity='relu' . Quindi, se l'argomento Python cambia, ha senso che tu debba ripercorrere il grafico.

    Tuttavia, è possibile che un argomento Python non venga utilizzato per controllare la costruzione del grafico. In questi casi, un cambiamento nel valore Python può innescare inutili ritracciamenti. Prendi, ad esempio, questo ciclo di allenamento, che AutoGraph srotolerà dinamicamente. Nonostante le tracce multiple, il grafico generato è in realtà identico, quindi non è necessario ritracciare.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Se è necessario forzare il ritracciamento, creare una nuova Function . È garantito che gli oggetti Function separati non condividano le tracce.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Ottenere funzioni concrete

Ogni volta che viene tracciata una funzione, viene creata una nuova funzione concreta. Puoi ottenere direttamente una funzione concreta, usando get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

La stampa di una ConcreteFunction visualizza un riepilogo dei suoi argomenti di input (con tipi) e del suo tipo di output.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Puoi anche recuperare direttamente la firma di una funzione concreta.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

L'uso di una traccia concreta con tipi incompatibili genererà un errore

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Potresti notare che gli argomenti Python ricevono un trattamento speciale nella firma di input di una funzione concreta. Prima di TensorFlow 2.3, gli argomenti Python venivano semplicemente rimossi dalla firma della funzione concreta. A partire da TensorFlow 2.3, gli argomenti Python rimangono nella firma, ma sono vincolati a prendere il valore impostato durante la traccia.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1725, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1770, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Ottenere grafici

Ogni funzione concreta è un wrapper richiamabile attorno a un tf.Graph . Sebbene recuperare l'oggetto tf.Graph effettivo non sia qualcosa che normalmente devi fare, puoi ottenerlo facilmente da qualsiasi funzione concreta.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Debug

In generale, il debug del codice è più semplice in modalità desideroso che all'interno di tf.function . Dovresti assicurarti che il tuo codice venga eseguito senza errori in modalità desideroso prima di decorare con tf.function . Per assistere nel processo di debug, puoi chiamare tf.config.run_functions_eagerly(True) per disabilitare e tf.function globalmente tf.function .

Quando si rintracciano problemi che appaiono solo all'interno di tf.function , ecco alcuni suggerimenti:

  • Le vecchie chiamate di print Python vengono eseguite solo durante la traccia, aiutandoti a rintracciare quando la tua funzione viene (ri)tracciata.
  • tf.print chiamate tf.print verranno eseguite ogni volta e possono aiutarti a rintracciare i valori intermedi durante l'esecuzione.
  • tf.debugging.enable_check_numerics è un modo semplice per rintracciare dove vengono creati NaN e Inf.
  • pdb (il debugger Python ) può aiutarti a capire cosa sta succedendo durante la traccia. (Avvertenza: pdb ti lascerà cadere nel codice sorgente trasformato da AutoGraph.)

Trasformazioni di AutoGraph

AutoGraph è una libreria attiva per impostazione predefinita in tf.function e trasforma un sottoinsieme di codice avido di Python in operazioni TensorFlow compatibili con i grafici. Ciò include il flusso di controllo come if , for , while .

Le operazioni di TensorFlow come tf.cond e tf.while_loop continuano a funzionare, ma il flusso di controllo è spesso più facile da scrivere e comprendere se scritto in Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.710546374 0.327660799 0.393230557 0.545059443 0.666661739]
[0.611019373 0.316417336 0.374141902 0.496808201 0.582779706]
[0.54484427 0.306263864 0.357609242 0.45960331 0.52468282]
[0.496646136 0.297034383 0.343106419 0.429760844 0.481306106]
[0.459475428 0.288596332 0.330247819 0.405121386 0.44728902]
[0.429656595 0.280842364 0.318743408 0.384322464 0.419668049]
[0.405034214 0.273684502 0.308370233 0.366455346 0.396650732]
[0.384248167 0.267049909 0.298953682 0.350887358 0.377079546]
[0.366391063 0.260877609 0.290354759 0.337162226 0.360168517]
[0.350830942 0.255116194 0.282461286 0.324941576 0.345362455]
[0.337112248 0.2497219 0.275181532 0.313968241 0.332256317]
[0.324896872 0.244657204 0.268439621 0.304042816 0.320546716]
[0.313927948 0.239889801 0.262172282 0.295007944 0.310001194]
[0.304006279 0.235391632 0.256326199 0.286737591 0.300438195]
[0.294974595 0.231138244 0.250856102 0.279129326 0.291713566]
[0.286706954 0.227108166 0.245723218 0.272099048 0.283711195]
[0.279101074 0.223282441 0.240894228 0.265576899 0.276336372]
[0.272072881 0.219644368 0.23634018 0.259504348 0.269510925]
[0.26555258 0.216179058 0.232035935 0.253831863 0.263169676]
[0.259481668 0.212873235 0.227959365 0.24851726 0.257257849]
[0.253810644 0.209715009 0.224091038 0.243524343 0.251728892]
[0.248497337 0.206693679 0.220413819 0.238821834 0.246543139]
[0.243505597 0.20379965 0.216912434 0.2343826 0.241666421]
[0.238804176 0.20102416 0.213573262 0.230182901 0.237069115]
[0.234365895 0.198359385 0.210384145 0.226201892 0.232725471]
[0.230167076 0.195798069 0.207334161 0.222421184 0.228612974]
[0.226186857 0.19333373 0.204413429 0.218824506 0.224711776]
[0.222406894 0.190960392 0.201613098 0.215397388 0.221004337]
[0.218810901 0.188672557 0.198925063 0.212126866 0.217475086]
[0.215384394 0.186465234 0.196342021 0.209001362 0.214110211]
[0.212114424 0.184333771 0.193857312 0.206010461 0.210897282]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20898944, 0.18227392, 0.19146483, 0.2031447 , 0.20782518],
      dtype=float32)>

Se sei curioso puoi controllare il codice generato dall'autografo.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

Condizionali

AutoGraph convertirà alcune istruzioni if <condition> nelle chiamate equivalenti tf.cond . Questa sostituzione viene eseguita se <condition> è un tensore. In caso contrario, l'istruzione if viene eseguita come condizionale Python.

Un condizionale Python viene eseguito durante la traccia, quindi al grafico verrà aggiunto esattamente un ramo del condizionale. Senza AutoGraph, questo grafico tracciato non sarebbe in grado di prendere il ramo alternativo se è presente un flusso di controllo dipendente dai dati.

tf.cond traccia e aggiunge entrambi i rami del condizionale al grafico, selezionando dinamicamente un ramo al momento dell'esecuzione. Il tracciamento può avere effetti collaterali indesiderati; controlla gli effetti di traccia di AutoGraph per ulteriori informazioni.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Consultare la documentazione di riferimento per ulteriori restrizioni sulle istruzioni if ​​convertite in AutoGraph.

loop

AutoGraph convertirà alcune istruzioni for e while nelle operazioni di looping TensorFlow equivalenti, come tf.while_loop . Se non viene convertito, il ciclo for o while viene eseguito come ciclo Python.

Questa sostituzione viene effettuata nelle seguenti situazioni:

Un ciclo Python viene eseguito durante la traccia, aggiungendo ulteriori operazioni al tf.Graph per ogni iterazione del ciclo.

Un ciclo TensorFlow traccia il corpo del ciclo e seleziona dinamicamente il numero di iterazioni da eseguire in fase di esecuzione. Il corpo del ciclo appare solo una volta nel tf.Graph generato.

Consultare la documentazione di riferimento per ulteriori restrizioni sulle istruzioni for e while convertite in AutoGraph.

Ciclo su dati Python

Un errore comune è quello di eseguire il loop sui dati Python/NumPy all'interno di un tf.function . Questo ciclo verrà eseguito durante il processo di tracciamento, aggiungendo una copia del modello al tf.Graph per ogni iterazione del ciclo.

Se vuoi avvolgere l'intero ciclo di addestramento in tf.function , il modo più sicuro per farlo è avvolgere i tuoi dati come untf.data.Dataset modo che AutoGraph srotoli dinamicamente il ciclo di addestramento.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph

Quando si avvolgono i dati Python/NumPy in un tf.data.Dataset.from_generator di dati, fare tf.data.Dataset.from_generator rispetto a tf.data.Dataset.from_tensors . Il primo manterrà i dati in Python e li tf.py_function tramite tf.py_function che può avere implicazioni sulle prestazioni, mentre il secondo raggrupperà una copia dei dati come un grande nodo tf.constant() nel grafico, che può avere implicazioni sulla memoria.

La lettura dei dati dai file tramite TFRecordDataset , CsvDataset , ecc. è il modo più efficace per consumare i dati, poiché TensorFlow stesso può gestire il caricamento asincrono e il precaricamento dei dati, senza dover coinvolgere Python. Per ulteriori informazioni, vedere la guida alle pipeline di input tf.data : Build TensorFlow .

Accumulare valori in un ciclo

Un modello comune consiste nell'accumulare valori intermedi da un ciclo. Normalmente, ciò si ottiene aggiungendo a un elenco Python o aggiungendo voci a un dizionario Python. Tuttavia, poiché si tratta di effetti collaterali di Python, non funzioneranno come previsto in un ciclo srotolato dinamicamente. Utilizzare tf.TensorArray per accumulare risultati da un ciclo svolto dinamicamente.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.60458577, 0.3308612 , 0.7878152 , 0.3223114 ],
        [0.9110272 , 1.0819752 , 1.7657743 , 1.2409766 ],
        [1.7235098 , 1.5416101 , 2.2929285 , 1.9181627 ]],

       [[0.89487076, 0.22811687, 0.342862  , 0.5752872 ],
        [1.0133923 , 0.28650808, 0.9558767 , 1.0829899 ],
        [1.9280962 , 1.1437279 , 0.9857702 , 1.4834155 ]]], dtype=float32)>

Limitazioni

La Function TensorFlow ha alcune limitazioni di progettazione di cui dovresti essere a conoscenza quando converti una funzione Python in una Function .

Esecuzione degli effetti collaterali di Python

Gli effetti collaterali, come la stampa, l'aggiunta agli elenchi e la modifica dei globali, possono comportarsi in modo imprevisto all'interno di una Function , a volte eseguendo due volte o non tutti. Si verificano solo la prima volta che chiami una Function con un insieme di input. Successivamente, il tf.Graph tracciato viene rieseguito, senza eseguire il codice Python.

La regola generale è evitare di fare affidamento sugli effetti collaterali di Python nella logica e usarli solo per eseguire il debug delle tracce. Altrimenti, le API di TensorFlow come tf.data , tf.print , tf.summary , tf.Variable.assign e tf.TensorArray sono il modo migliore per garantire che il codice venga eseguito dal runtime di TensorFlow con ogni chiamata.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Se desideri eseguire il codice Python durante ogni tf.py_function di una Function , tf.py_function è un tratteggio di uscita. Lo svantaggio di tf.py_function è che non è portatile o particolarmente performante, non può essere salvato con SavedModel e non funziona bene nelle configurazioni distribuite (multi-GPU, TPU). Inoltre, poiché tf.py_function deve essere collegato al grafico, tf.py_function tutti gli input/output ai tensori.

Modifica delle variabili globali e libere di Python

La modifica delle variabili globali e gratuite di Python conta come un effetto collaterale di Python, quindi avviene solo durante la traccia.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Dovresti evitare di modificare contenitori come elenchi, dicts, altri oggetti che vivono al di fuori di Function . Utilizzare invece argomenti e oggetti TF. Ad esempio, la sezione "Accumulare valori in un ciclo" ha un esempio di come possono essere implementate operazioni di tipo elenco.

È possibile, in alcuni casi, acquisire e manipolare lo stato se è un tf.Variable . Ecco come vengono aggiornati i pesi dei modelli Keras con ripetute chiamate alla stessa ConcreteFunction .

Utilizzo di iteratori e generatori Python

Molte funzionalità di Python, come generatori e iteratori, si basano sul runtime di Python per tenere traccia dello stato. In generale, sebbene questi costrutti funzionino come previsto in modalità desideroso, sono esempi di effetti collaterali di Python e quindi si verificano solo durante la traccia.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Proprio come TensorFlow ha un tf.TensorArray specializzato per i costrutti di elenco, ha un tf.data.Iterator specializzato per i costrutti di iterazione. Vedere la sezione sulle trasformazioni di AutoGraph per una panoramica. Inoltre, l'API tf.data può aiutare a implementare modelli di generatori:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

Eliminazione di tf.Variables tra chiamate di Function

Un altro errore che potresti riscontrare è una variabile raccolta nella spazzatura. ConcreteFunction conservano i WeakRefs solo per le variabili su cui si chiudono, quindi è necessario mantenere un riferimento a qualsiasi variabile.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

# The original variable object gets garbage collected, since there are no more
# references to it.
external_var = tf.Variable(4)
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-9a93d2e07632>", line 16, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError:  Could not find variable _AnonymousVar3. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]] [Op:__inference_f_782]

Function call stack:
f

problemi conosciuti

Se la tua Function non viene valutata correttamente, l'errore potrebbe essere spiegato da questi problemi noti che dovrebbero essere risolti in futuro.

A seconda delle variabili globali e libere di Python

Function crea una nuova ConcreteFunction quando viene chiamata con un nuovo valore di un argomento Python. Tuttavia, non lo fa per la chiusura Python, i globali o i non locali di quella Function . Se il loro valore cambia tra le chiamate alla Function , la Function utilizzerà ancora i valori che aveva quando è stata tracciata. Questo è diverso da come funzionano le normali funzioni Python.

Per questo motivo, raccomandiamo uno stile di programmazione funzionale che utilizzi argomenti invece di chiudere su nomi esterni.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Puoi chiudere sui nomi esterni, purché non aggiorni i loro valori.

A seconda degli oggetti Python

La raccomandazione di passare oggetti Python come argomenti in tf.function ha una serie di problemi noti, che dovrebbero essere risolti in futuro. In generale, puoi fare affidamento su una traccia coerente se usi una primitiva Python o una struttura tf.nest tf.nest come argomento o passi un'istanza diversa di un oggetto in una Function . Tuttavia, Function non creerà una nuova traccia quando passerai lo stesso oggetto e ne modificherà solo gli attributi .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

L'utilizzo della stessa Function per valutare l'istanza aggiornata del modello sarà difettoso poiché il modello aggiornato ha la stessa chiave di cache del modello originale.

Per questo motivo, ti consigliamo di scrivere la tua Function per evitare di dipendere dagli attributi degli oggetti modificabili o creare nuovi oggetti.

Se ciò non è possibile, una soluzione alternativa consiste nel creare nuove Function ogni volta che si modifica l'oggetto per forzare il ritracciamento:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Poiché il retrace può essere costoso , puoi usare tf.Variable s come attributi dell'oggetto, che possono essere mutati (ma non modificati, attenzione!) per un effetto simile senza bisogno di un retrace.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Creazione di tf.Variables

Function supporta solo la creazione di variabili una volta, quando viene chiamata per la prima volta, e quindi riutilizzandole. Non è possibile creare tf.Variables in nuove tracce. La creazione di nuove variabili nelle chiamate successive non è attualmente consentita, ma lo sarà in futuro.

Esempio:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-8a0913e250e0>", line 7, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-8a0913e250e0>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:769 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

È possibile creare variabili all'interno di una Function purché tali variabili vengano create solo la prima volta che la funzione viene eseguita.

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Utilizzo con più ottimizzatori Keras

Potresti incontrare ValueError: tf.function-decorated function tried to create variables on non-first call. quando si utilizza più di un ottimizzatore Keras con tf.function . Questo errore si verifica perché gli ottimizzatori creano internamente tf.Variables quando applicano i gradienti per la prima volta.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d3d3937dbf1a>", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    <ipython-input-1-d3d3937dbf1a>:9 train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:636 apply_gradients  **
        self._create_all_weights(var_list)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:821 _create_all_weights
        _ = self.iterations
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:828 __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:988 iterations
        aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:1194 add_weight
        aggregation=aggregation)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py:815 _add_variable_with_custom_getter
        **kwargs_for_getter)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:139 make_variable
        shape=variable_shape if variable_shape else None)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:221 _variable_v1_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:769 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

Se è necessario modificare l'ottimizzatore durante l'addestramento, una soluzione alternativa consiste nel creare una nuova Function per ogni ottimizzatore, chiamando direttamente ConcreteFunction .

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

Utilizzo con più modelli Keras

Potresti anche incontrare ValueError: tf.function-decorated function tried to create variables on non-first call. quando si passano istanze di modello diverse alla stessa Function .

Questo errore si verifica perché i modelli Keras (che non hanno la forma di input definita ) e i livelli Keras creano tf.Variables quando vengono chiamati per la prima volta. Potresti tentare di inizializzare quelle variabili all'interno di una Function , che è già stata chiamata. Per evitare questo errore, prova a chiamare model.build(input_shape) per inizializzare tutti i pesi prima di addestrare il modello.

Ulteriori letture

Per informazioni su come esportare e caricare una Function , vedere la guida SavedModel . Per ulteriori informazioni sulle ottimizzazioni del grafico eseguite dopo la traccia, consultare la guida Grappler . Per sapere come ottimizzare la tua pipeline di dati e profilare il tuo modello, consulta la guida di Profiler .