Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Migliori prestazioni con la funzione tf

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica notebook

In TensorFlow 2, l'esecuzione desiderosa è attivata per impostazione predefinita. L'interfaccia utente è intuitiva e flessibile (l'esecuzione di operazioni una tantum è molto più semplice e veloce), ma ciò può andare a scapito delle prestazioni e della distribuzione.

Puoi usare tf.function per creare grafici dai tuoi programmi. È uno strumento di trasformazione che crea grafici di flussi di dati indipendenti da Python dal codice Python. Questo ti aiuterà a creare modelli performanti e portatili, ed è necessario utilizzare SavedModel .

Questa guida ti aiuterà a concettualizzare come funziona tf.function sotto il cofano in modo da poterlo utilizzare in modo efficace.

I principali suggerimenti e consigli sono:

  • Eseguire il debug in modalità eager, quindi decorare con la funzione @tf.function ..
  • Non fare affidamento sugli effetti collaterali di Python come la mutazione di oggetti o l'aggiunta di elenchi.
  • tf.function funziona meglio con le operazioni TensorFlow; Le chiamate NumPy e Python vengono convertite in costanti.

Impostare

import tensorflow as tf

Definisci una funzione di supporto per dimostrare i tipi di errori che potresti incontrare:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Nozioni di base

Utilizzo

Una Function che definisci è proprio come un'operazione di base di TensorFlow: puoi eseguirla con entusiasmo; puoi calcolare gradienti; e così via.

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

È possibile utilizzare le Function all'interno di altre Function .

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function possono essere più veloci del codice desideroso, specialmente per i grafici con molte piccole operazioni. Ma per i grafici con poche operazioni costose (come le convoluzioni), potresti non vedere molta velocità.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.002407395000091128
Function conv: 0.004000883000117028
Note how there's not much difference in performance for convolutions

Tracciamento

La digitazione dinamica di Python significa che puoi chiamare funzioni con una varietà di tipi di argomenti e Python può fare qualcosa di diverso in ogni scenario.

Tuttavia, per creare un grafico TensorFlow, sono necessari dtypes statici e dimensioni della forma. tf.function colma questa lacuna avvolgendo una funzione Python per creare un oggetto Function . In base agli input forniti, la Function seleziona il grafico appropriato per gli input dati, ripercorrendo la funzione Python secondo necessità. Una volta capito perché e quando viene tf.function tracciamento, è molto più facile usare tf.function modo efficace!

Puoi chiamare una Function con argomenti di diversi tipi per vedere questo comportamento polimorfico in azione.

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Nota che se chiami ripetutamente una Function con lo stesso tipo di argomento, TensorFlow riutilizzerà un grafico tracciato in precedenza, poiché il grafico generato sarebbe identico.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Puoi usare pretty_printed_concrete_signatures() per vedere tutte le tracce disponibili:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Finora, hai visto che tf.function crea un livello di invio dinamico memorizzato nella cache sulla logica di tracciamento del grafico di TensorFlow. Per essere più specifici sulla terminologia:

  • Un tf.Graph è la rappresentazione grezza, indipendente dal linguaggio e portabile del tuo calcolo.
  • Una ConcreteFunction è un wrapper che esegue con impazienza un tf.Graph .
  • Una Function gestisce una cache di ConcreteFunction e sceglie quella giusta per i tuoi input.
  • tf.function avvolge una funzione Python, restituendo un oggetto Function .

Ottenere funzioni concrete

Ogni volta che viene tracciata una funzione, viene creata una nuova funzione concreta. È possibile ottenere direttamente una funzione concreta, utilizzando get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(La seguente modifica è disponibile in TensorFlow ogni notte e sarà disponibile in TensorFlow 2.3.)

La stampa di una ConcreteFunction visualizza un riepilogo dei suoi argomenti di input (con i tipi) e il suo tipo di output.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Puoi anche recuperare direttamente la firma di una funzione concreta.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

L'utilizzo di una traccia concreta con tipi incompatibili genererà un errore

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Si può notare che gli argomenti Python ricevono un trattamento speciale nella firma di input di una funzione concreta. Prima di TensorFlow 2.3, gli argomenti Python venivano semplicemente rimossi dalla firma della funzione concreta. A partire da TensorFlow 2.3, gli argomenti Python rimangono nella firma, ma sono vincolati a prendere il valore impostato durante la traccia.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Ottenere grafici

Ogni funzione concreta è un involucro richiamabile attorno a un tf.Graph . Sebbene il recupero dell'oggetto tf.Graph effettivo non sia qualcosa che normalmente devi fare, puoi ottenerlo facilmente da qualsiasi funzione concreta.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

[] -> a
['a', 'a'] -> add
['add'] -> Identity

Debug

In generale, il codice di debug è più facile in modalità eager che all'interno di tf.function . Dovresti assicurarti che il tuo codice venga eseguito senza errori in modalità eager prima di decorare con tf.function . Per assistere nel processo di debug, puoi chiamare tf.config.run_functions_eagerly(True) per disabilitare e tf.function globalmente tf.function .

Quando si rintracciano problemi che compaiono solo all'interno di tf.function , ecco alcuni suggerimenti:

  • Le semplici vecchie chiamate di print Python vengono eseguite solo durante la traccia, aiutandoti a rintracciare quando la tua funzione viene (ri) tracciata.
  • tf.print chiamate tf.print verranno eseguite ogni volta e possono aiutarti a rintracciare i valori intermedi durante l'esecuzione.
  • tf.debugging.enable_check_numerics è un modo semplice per rintracciare dove vengono creati NaN e Inf.
  • pdb può aiutarti a capire cosa sta succedendo durante la traccia. (Avvertenza: PDB ti porterà nel codice sorgente trasformato da AutoGraph.)

Tracciare la semantica

Regole chiave della cache

Una Function determina se riutilizzare una funzione concreta tracciata calcolando una chiave di cache da args e kwargs di un input.

  • La chiave generata per un argomento tf.Tensor è la sua forma e dtype.
  • A partire da TensorFlow 2.3, la chiave generata per un argomento tf.Variable è il suo id() .
  • La chiave generata per una primitiva Python è il suo valore. La chiave generata per dict nidificati, list s, tuple s, namedtuple s e attr s è la tupla appiattita. (Come risultato di questo appiattimento, la chiamata a una funzione concreta con una struttura di annidamento diversa da quella utilizzata durante la traccia risulterà in un TypeError).
  • Per tutti gli altri tipi di Python, le chiavi sono basate sull'oggetto id() modo che i metodi vengano tracciati indipendentemente per ogni istanza di una classe.

Controllo del tracciamento

Il tracciamento aiuta a garantire che TensorFlow generi grafici corretti per ogni set di input. Tuttavia, il tracciamento è un'operazione costosa! Se la tua Function traccia un nuovo grafico per ogni chiamata, scoprirai che il tuo codice viene eseguito più lentamente rispetto a se non tf.function usato tf.function .

Per controllare il comportamento di tracciamento, puoi utilizzare le seguenti tecniche:

  • Specificare input_signature in tf.function per limitare la traccia.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Specificare una dimensione [Nessuno] in tf.TensorSpec per consentire flessibilità nel riutilizzo delle tracce.

    Poiché TensorFlow corrisponde ai tensori in base alla loro forma, l'utilizzo di una dimensione None come carattere jolly consentirà a Function s di riutilizzare le tracce per input di dimensioni variabili. L'input di dimensioni variabili può verificarsi se si dispone di sequenze di lunghezza diversa o immagini di dimensioni diverse per ciascun batch (vedere i tutorial su Transformer e Deep Dream ad esempio).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Trasmetti argomenti Python a Tensors per ridurre il ritracciamento.

    Spesso, gli argomenti Python vengono utilizzati per controllare gli iperparametri e le costruzioni di grafici, ad esempio num_layers=10 o training=True o nonlinearity='relu' . Quindi, se l'argomento Python cambia, ha senso che tu debba ripercorrere il grafico.

    Tuttavia, è possibile che un argomento Python non venga utilizzato per controllare la costruzione del grafico. In questi casi, una modifica del valore di Python può attivare un inutile ritracciamento. Prendi, ad esempio, questo ciclo di addestramento, che AutoGraph svolgerà dinamicamente. Nonostante le tracce multiple, il grafico generato è in realtà identico, quindi non è necessario risalire.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Se è necessario forzare il ritracciamento, creare una nuova Function . È garantito che gli oggetti Function separati non condividano le tracce.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Effetti collaterali di Python

Gli effetti collaterali di Python come la stampa, l'aggiunta a elenchi e la modifica di valori globali si verificano solo la prima volta che chiami una Function con un insieme di input. Successivamente, il tf.Graph tracciato viene rieseguito, senza eseguire il codice Python.

La regola generale è usare solo gli effetti collaterali di Python per eseguire il debug delle tracce. In caso contrario, le operazioni di TensorFlow come tf.Variable.assign , tf.print e tf.summary sono il modo migliore per garantire che il codice venga tracciato ed eseguito dal runtime TensorFlow a ogni chiamata.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Molte funzionalità di Python, come i generatori e gli iteratori, si basano sul runtime di Python per tenere traccia dello stato. In generale, mentre questi costrutti funzionano come previsto in modalità eager, molte cose inaspettate possono accadere all'interno di una Function .

Per fare un esempio, l'avanzamento dello stato dell'iteratore è un effetto collaterale di Python e quindi si verifica solo durante la traccia.

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Alcuni costrutti di iterazione sono supportati tramite AutoGraph. Vedere la sezione sulle trasformazioni AutoGraph per una panoramica.

Se desideri eseguire codice Python durante ogni invocazione di una Function , tf.py_function è un tratteggio di uscita. Lo svantaggio di tf.py_function è che non è portatile o particolarmente performante, né funziona bene in configurazioni distribuite (multi-GPU, TPU). Inoltre, poiché tf.py_function deve essere collegato al grafico, tf.py_function il cast di tutti gli input / output ai tensori.

API come tf.gather , tf.stack e tf.TensorArray possono aiutarti a implementare modelli di loop comuni in TensorFlow nativo.

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect

Variabili

È possibile che si tf.Variable un errore durante la creazione di una nuova tf.Variable in una funzione. Questo errore protegge dalla divergenza del comportamento su chiamate ripetute: in modalità eager, una funzione crea una nuova variabile con ogni chiamata, ma in una Function , una nuova variabile potrebbe non essere creata a causa del riutilizzo della traccia.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


È possibile creare variabili all'interno di una Function purché tali variabili vengano create solo la prima volta che la funzione viene eseguita.

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Un altro errore che potresti riscontrare è una variabile raccolta dai rifiuti. A differenza delle normali funzioni Python, le funzioni concrete mantengono i WeakRef solo alle variabili su cui si chiudono, quindi è necessario mantenere un riferimento a qualsiasi variabile.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Trasformazioni AutoGraph

AutoGraph è una libreria attiva per impostazione predefinita in tf.function e trasforma un sottoinsieme di codice desideroso Python in operazioni TensorFlow compatibili con i grafici. Ciò include il flusso di controllo come if , for , while .

Le operazioni di TensorFlow come tf.cond e tf.while_loop continuano a funzionare, ma il flusso di controllo è spesso più facile da scrivere e capire quando viene scritto in Python.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.224704742 0.895507693 0.0398198366 0.98112452 0.278468847]
[0.220997646 0.71410346 0.0397988036 0.753552318 0.271487355]
[0.217468739 0.61324358 0.0397778042 0.637263417 0.265008271]
[0.214104146 0.546406269 0.0397568382 0.563033342 0.258973926]
[0.210891485 0.497821957 0.0397359058 0.510224521 0.253335565]
[0.207819641 0.460402519 0.0397150069 0.470120102 0.248051569]
[0.204878598 0.430412233 0.0396941416 0.438296348 0.243086234]
[0.202059314 0.405665785 0.039673306 0.412231296 0.2384087]
[0.199353606 0.384786367 0.039652504 0.39036563 0.23399213]
[0.196754038 0.366856933 0.0396317355 0.371675402 0.229813099]
[0.194253832 0.351239443 0.039611 0.355456293 0.225851]
[0.191846803 0.337474287 0.0395902954 0.341205537 0.222087651]
[0.189527303 0.325220674 0.0395696238 0.3285532 0.218506947]
[0.187290132 0.314219803 0.0395489857 0.317220151 0.215094551]
[0.185130537 0.304271102 0.0395283774 0.30699119 0.211837649]
[0.183044136 0.295216352 0.0395078026 0.297697395 0.208724767]
[0.181026861 0.286928833 0.0394872613 0.289204 0.205745578]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.17907499, 0.27930567, 0.03946675, 0.281402  , 0.20289075],
      dtype=float32)>

Se sei curioso puoi controllare il codice generato dall'autografo.

print(tf.autograph.to_code(f.python_function))
​​
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Condizionali

AutoGraph convertirà alcune istruzioni if <condition> nelle chiamate tf.cond equivalenti. Questa sostituzione viene eseguita se <condition> è un tensore. Altrimenti, l'istruzione if viene eseguita come condizionale Python.

Un condizionale Python viene eseguito durante la traccia, quindi al grafico verrà aggiunto esattamente un ramo del condizionale. Senza AutoGraph, questo grafico tracciato non sarebbe in grado di prendere il ramo alternativo se c'è un flusso di controllo dipendente dai dati.

tf.cond traccia e aggiunge entrambi i rami del condizionale al grafico, selezionando dinamicamente un ramo al momento dell'esecuzione. Il tracciamento può avere effetti collaterali indesiderati; vedere Effetti di traccia AutoGraph per ulteriori informazioni.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Vedere la documentazione di riferimento per ulteriori restrizioni sulle istruzioni if ​​convertite da AutoGraph.

Loop

AutoGraph convertirà alcune istruzioni for e while nelle operazioni di ciclo TensorFlow equivalenti, come tf.while_loop . Se non convertito, il ciclo for o while viene eseguito come un ciclo Python.

Questa sostituzione viene effettuata nelle seguenti situazioni:

Un ciclo Python viene eseguito durante la traccia, aggiungendo operazioni aggiuntive a tf.Graph per ogni iterazione del ciclo.

Un ciclo TensorFlow traccia il corpo del ciclo e seleziona dinamicamente quante iterazioni eseguire al momento dell'esecuzione. Il corpo del ciclo appare solo una volta nel tf.Graph generato.

Vedere la documentazione di riferimento per le ulteriori restrizioni su AutoGraph convertiti for e while le dichiarazioni.

Looping sui dati Python

Una trappola comune è eseguire il loop dei dati Python / Numpy all'interno di una funzione tf.function .. Questo ciclo verrà eseguito durante il processo di tracciamento, aggiungendo una copia del modello a tf.Graph per ogni iterazione del ciclo.

Se si desidera avvolgere l'intero ciclo di addestramento in tf.function , il modo più sicuro per farlo è avvolgere i dati come tf.data.Dataset modo che AutoGraph srotoli dinamicamente il ciclo di addestramento.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Quando si avvolgono i dati Python / Numpy in un set di dati, prestare tf.data.Dataset.from_generator rispetto a tf.data.Dataset.from_tensors . Il primo manterrà i dati in Python e lo tf.py_function tramite tf.py_function che può avere implicazioni sulle prestazioni, mentre il secondo raggrupperà una copia dei dati come un grande nodo tf.constant() nel grafico, che può avere implicazioni sulla memoria.

Lettura dei dati dai file tramite TFRecordDataset / CsvDataset / ecc. è il modo più efficace per consumare i dati, poiché TensorFlow stesso può gestire il caricamento asincrono e il precaricamento dei dati, senza dover coinvolgere Python. Per saperne di più, consulta la guida tf.data .

Accumulo di valori in un ciclo

Un modello comune consiste nell'accumulare valori intermedi da un ciclo. Normalmente, ciò si ottiene aggiungendo a un elenco Python o aggiungendo voci a un dizionario Python. Tuttavia, poiché questi sono effetti collaterali di Python, non funzioneranno come previsto in un ciclo srotolato dinamicamente. Usa tf.TensorArray per accumulare i risultati da un ciclo srotolato dinamicamente.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.9854791 , 0.5162524 , 0.14062047, 0.04950547],
        [1.8820469 , 0.67421603, 0.40786874, 0.7679055 ],
        [2.8815444 , 1.1567757 , 1.0627073 , 0.8880433 ]],

       [[0.94119024, 0.19776726, 0.24890792, 0.4663092 ],
        [1.4591933 , 1.123581  , 0.35438073, 1.4392309 ],
        [2.0026946 , 1.9165647 , 0.37988353, 1.8128917 ]]], dtype=float32)>

Ulteriore lettura

Per informazioni su come esportare e caricare una Function , vedere la guida SavedModel . Per ulteriori informazioni sulle ottimizzazioni dei grafici eseguite dopo la traccia, vedere la guida Grappler . Per informazioni su come ottimizzare la pipeline di dati e profilare il tuo modello, consulta la guida Profiler .