Esta página foi traduzida pela API Cloud Translation.
Switch to English

Melhor desempenho com tf.function

Ver em TensorFlow.org Executar no Google Colab Ver fonte no GitHub Download do caderno

No TensorFlow 2, a execução rápida é ativada por padrão. A interface do usuário é intuitiva e flexível (a execução de operações pontuais é muito mais fácil e rápida), mas isso pode prejudicar o desempenho e a capacidade de implantação.

Você pode usar tf.function para criar gráficos de seus programas. É uma ferramenta de transformação que cria gráficos de fluxo de dados independentes do Python a partir do seu código Python. Isso o ajudará a criar modelos portáteis e de SavedModel desempenho e é necessário o uso do SavedModel .

Este guia o ajudará a conceituar como o tf.function funciona sob o capô para que você possa usá-lo com eficiência.

Os principais tópicos e recomendações são:

  • Depure no modo ansioso e decore com @tf.function .
  • Não confie nos efeitos colaterais do Python, como mutação de objeto ou anexos à lista.
  • tf.function funciona melhor com as operações do TensorFlow; As chamadas NumPy e Python são convertidas em constantes.

Configuração

 import tensorflow as tf
 

Defina uma função auxiliar para demonstrar os tipos de erros que você pode encontrar:

 import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
 

Fundamentos

Uso

Uma Function você define é como uma operação principal do TensorFlow: você pode executá-la com entusiasmo; você pode calcular gradientes; e assim por diante.

 @tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
 
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
 v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Você pode usar a Function s dentro de outras Function s.

 @tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
 
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s podem ser mais rápidas que o código ansioso, especialmente para gráficos com muitas operações pequenas. Mas para gráficos com algumas operações caras (como convoluções), talvez você não veja muita aceleração.

 import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

 
Eager conv: 0.0023194860004878137
Function conv: 0.0036776439992536325
Note how there's not much difference in performance for convolutions

Rastreamento

A digitação dinâmica do Python significa que você pode chamar funções com vários tipos de argumentos, e o Python pode fazer algo diferente em cada cenário.

No entanto, para criar um gráfico TensorFlow, são necessários dtypes estáticos e dimensões de forma. tf.function preenche essa lacuna envolvendo uma função Python para criar um objeto Function . Com base nas entradas fornecidas, a Function seleciona o gráfico apropriado para as entradas fornecidas, refazendo a função Python conforme necessário. Depois de entender por que e quando o rastreamento acontece, é muito mais fácil usar o tf.function eficiência!

Você pode chamar uma Function com argumentos de diferentes tipos para ver esse comportamento polimórfico em ação.

 @tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

 
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Observe que, se você chamar repetidamente uma Function com o mesmo tipo de argumento, o TensorFlow reutilizará um gráfico rastreado anteriormente, pois o gráfico gerado seria idêntico.

 # This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
 
tf.Tensor(b'bb', shape=(), dtype=string)

(A alteração a seguir está disponível no TensorFlow todas as noites e estará disponível no TensorFlow 2.3.)

Você pode usar pretty_printed_concrete_signatures() para ver todos os rastreamentos disponíveis:

 print(double.pretty_printed_concrete_signatures())
 
double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

Até agora, você viu que o tf.function cria uma camada de despacho dinâmico em cache sobre a lógica de rastreamento de gráficos do TensorFlow. Para ser mais específico sobre a terminologia:

  • Um tf.Graph é a representação portátil, tf.Graph idioma e portátil, de sua computação.
  • Um ConcreteFunction é um invólucro que executa ansiosamente em torno de um tf.Graph .
  • Uma Function gerencia um cache de ConcreteFunction escolhe o caminho certo para suas entradas.
  • tf.function uma função Python, retornando um objeto Function .

Obtendo funções concretas

Sempre que uma função é rastreada, uma nova função concreta é criada. Você pode obter diretamente uma função concreta usando get_concrete_function .

 print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

 
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

 # You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
 
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(A alteração a seguir está disponível no TensorFlow todas as noites e estará disponível no TensorFlow 2.3.)

A impressão de uma ConcreteFunction exibe um resumo de seus argumentos de entrada (com tipos) e seu tipo de saída.

 print(double_strings)
 
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Você também pode recuperar diretamente a assinatura de uma função concreta.

 print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
 
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Usar um rastreamento concreto com tipos incompatíveis gerará um erro

 with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
 
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-15-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Você pode perceber que os argumentos do Python recebem tratamento especial na assinatura de entrada de uma função concreta. Antes do TensorFlow 2.3, os argumentos do Python eram simplesmente removidos da assinatura da função concreta. Começando com o TensorFlow 2.3, os argumentos Python permanecem na assinatura, mas são limitados a assumir o valor definido durante o rastreamento.

 @tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
 
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

 assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
 
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-17-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Obtendo gráficos

Cada função concreta é um invólucro que pode ser tf.Graph em torno de um tf.Graph . Embora recuperar o objeto tf.Graph real não seja algo que você normalmente precise fazer, você pode obtê-lo facilmente a partir de qualquer função concreta.

 graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

 
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Depuração

Em geral, o código de depuração é mais fácil no modo ansioso do que no tf.function . Você deve garantir que seu código seja executado sem erros no modo ansioso antes de decorar com tf.function . Para ajudar no processo de depuração, você pode chamar tf.config.run_functions_eagerly(True) para desativar e reativar globalmente a função tf.function .

Ao rastrear problemas que só aparecem no tf.function , aqui estão algumas dicas:

  • Chamadas simples de print antiga do Python são executadas apenas durante o rastreamento, ajudando a rastrear quando sua função é (re) rastreada.
  • tf.print chamadas tf.print são executadas sempre e podem ajudá-lo a rastrear valores intermediários durante a execução.
  • tf.debugging.enable_check_numerics é uma maneira fácil de rastrear onde os NaNs e Inf são criados.
  • pdb pode ajudar você a entender o que está acontecendo durante o rastreamento. (Advertência: o PDB o levará ao código-fonte transformado pelo AutoGraph.)

Rastreando semântica

Regras de chave de cache

Uma Function determina se deve reutilizar uma função concreta rastreada computando uma chave de cache dos argumentos e kwargs de uma entrada.

  • A chave gerada para um argumento tf.Tensor é sua forma e tipo.
  • Iniciando no TensorFlow 2.3, a chave gerada para um argumento tf.Variable é seu id() .
  • A chave gerada para uma primitiva Python é seu valor. A chave gerada para dict aninhados, list s, tuple s, namedtuple s e attr s é a tupla achatada. (Como resultado desse achatamento, chamar uma função concreta com uma estrutura de aninhamento diferente daquela usada durante o rastreamento resultará em um TypeError).
  • Para todos os outros tipos de Python, as chaves são baseadas no id() do objeto id() para que os métodos sejam rastreados independentemente para cada instância de uma classe.

Controlando a retraçagem

A retração ajuda a garantir que o TensorFlow gere gráficos corretos para cada conjunto de entradas. No entanto, o rastreamento é uma operação cara! Se sua Function refazer um novo gráfico para cada chamada, você descobrirá que seu código é executado mais lentamente do que se você não usasse tf.function .

Para controlar o comportamento de rastreamento, você pode usar as seguintes técnicas:

  • Especifique input_signature em tf.function para limitar o rastreio.
 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Especifique uma dimensão [Nenhuma] em tf.TensorSpec para permitir flexibilidade na reutilização de rastreio.

    Como o TensorFlow corresponde aos tensores com base em sua forma, o uso de uma dimensão None como curinga permitirá que a Function s reutilize traços para entrada de tamanho variável. A entrada de tamanho variável pode ocorrer se você tiver seqüências de diferentes tamanhos ou imagens de tamanhos diferentes para cada lote (consulte os tutoriais do Transformer e Deep Dream, por exemplo).

 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Lance argumentos do Python para os tensores para reduzir a refazer.

    Geralmente, os argumentos do Python são usados ​​para controlar hiperparâmetros e construções de gráficos - por exemplo, num_layers=10 ou training=True ou nonlinearity='relu' - nonlinearity='relu' . Portanto, se o argumento Python mudar, faz sentido que você precise refazer o gráfico.

    No entanto, é possível que um argumento Python não esteja sendo usado para controlar a construção do gráfico. Nesses casos, uma alteração no valor do Python pode desencadear retrocessos desnecessários. Veja, por exemplo, esse ciclo de treinamento, que o AutoGraph desenrolará dinamicamente. Apesar dos vários rastreamentos, o gráfico gerado é realmente idêntico, portanto, o retrocesso é desnecessário.

 def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
 
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Se você precisar forçar a refazer, crie uma nova Function . Os objetos Function separados têm garantia de não compartilhar traços.

 def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
 
Tracing!
Executing
Tracing!
Executing

Efeitos colaterais do Python

Os efeitos colaterais do Python, como impressão, anexação a listas e mutação de globais, só acontecem na primeira vez em que você chama uma Function com um conjunto de entradas. Posteriormente, o tf.Graph rastreado é reexecutado, sem executar o código Python.

A regra geral é usar apenas efeitos colaterais do Python para depurar seus rastreamentos. Caso contrário, as operações do TensorFlow como tf.Variable.assign , tf.print e tf.summary são a melhor maneira de garantir que seu código seja rastreado e executado pelo tempo de execução do TensorFlow a cada chamada.

 @tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

 
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Muitos recursos do Python, como geradores e iteradores, dependem do tempo de execução do Python para acompanhar o estado. Em geral, enquanto essas construções funcionam como esperado no modo ansioso, muitas coisas inesperadas podem acontecer dentro de uma Function .

Para dar um exemplo, o avanço do estado do iterador é um efeito colateral do Python e, portanto, ocorre apenas durante o rastreamento.

 external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

 
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Algumas construções de iteração são suportadas pelo AutoGraph. Veja a seção Transformações de AutoGraph para uma visão geral.

Se você deseja executar o código Python durante cada chamada de uma Function , tf.py_function é uma saída hachurada. A desvantagem da função tf.py_function é que ela não é portátil ou particularmente tf.py_function , nem funciona bem em configurações distribuídas (multi-GPU, TPU). Além disso, como a função tf.py_function deve ser conectada ao gráfico, ela lança todas as entradas / saídas em tensores.

APIs como tf.gather , tf.stack e tf.TensorArray podem ajudá-lo a implementar padrões de loop comuns no TensorFlow nativo.

 external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

 
Python side effect
Python side effect
Python side effect

Variáveis

Você pode encontrar um erro ao criar um novo tf.Variable em uma função. Este erro protege contra divergências de comportamento em chamadas repetidas: No modo rápido, uma função cria uma nova variável a cada chamada, mas em uma Function , uma nova variável pode não ser criada devido à reutilização do rastreamento.

 @tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
 
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-26-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-26-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


Você pode criar variáveis ​​dentro de uma Function , desde que essas variáveis ​​sejam criadas apenas na primeira vez em que a função for executada.

 class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
 
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Outro erro que você pode encontrar é uma variável coletada pelo lixo. Diferentemente das funções normais do Python, as funções concretas retêm WeakRefs apenas para as variáveis ​​fechadas, portanto, você deve manter uma referência a todas as variáveis.

 external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
 
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-28-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Transformações de AutoGraph

O AutoGraph é uma biblioteca que está tf.function por padrão no tf.function e transforma um subconjunto de código ansioso do Python em operações TensorFlow compatíveis com gráficos. Isso inclui o fluxo de controle como if , for while .

Operações de TensorFlow como tf.cond e tf.while_loop continuam a funcionar, mas o fluxo de controle geralmente é mais fácil de escrever e entender quando escrito em Python.

 # Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
 
[0.448926926 0.896036148 0.703306437 0.446930766 0.20440042]
[0.421016544 0.714362323 0.6064623 0.419372857 0.201600626]
[0.397786468 0.613405049 0.541632056 0.396401972 0.198913112]
[0.378053397 0.546519518 0.494222373 0.376866162 0.196330562]
[0.361015767 0.497907132 0.457561225 0.359982818 0.1938463]
[0.346108437 0.460469633 0.428094476 0.3451989 0.191454232]
[0.332919776 0.43046692 0.403727621 0.332110822 0.189148799]
[0.321141869 0.405711472 0.383133948 0.320416152 0.18692489]
[0.310539037 0.384825289 0.365426034 0.309883147 0.184777796]
[0.300927401 0.366890609 0.349984437 0.300330788 0.182703182]
[0.292161077 0.351268977 0.336361736 0.291615278 0.180697069]
[0.284122646 0.337500453 0.324225426 0.283620834 0.178755745]
[0.276716352 0.325244069 0.313322544 0.276252925 0.176875815]
[0.269863278 0.314240903 0.303456694 0.269433528 0.175054088]
[0.263497591 0.304290265 0.294472754 0.263097644 0.17328763]
[0.257564 0.295233846 0.2862463 0.257190555 0.171573699]
[0.25201565 0.286944896 0.278676242 0.25166589 0.169909731]
[0.246812463 0.279320478 0.271679461 0.246483982 0.168293342]
[0.24192 0.272276044 0.265186876 0.241610721 0.166722313]
[0.237308443 0.265741408 0.259140551 0.237016559 0.165194541]
[0.23295185 0.25965777 0.253491491 0.232675791 0.163708091]
[0.228827521 0.253975391 0.248197898 0.228565902 0.162261128]
[0.224915475 0.248651937 0.243223906 0.224667087 0.160851941]
[0.221198082 0.243651047 0.238538548 0.220961839 0.159478888]
[0.217659682 0.238941342 0.23411487 0.217434615 0.158140466]
[0.214286327 0.23449555 0.229929343 0.214071587 0.156835243]
[0.211065561 0.230289876 0.225961298 0.210860386 0.155561864]
[0.207986191 0.226303399 0.222192511 0.207789883 0.154319063]
[0.20503816 0.222517684 0.2186068 0.204850093 0.153105617]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20221236, 0.2189164 , 0.21518978, 0.20203198, 0.15192041],
      dtype=float32)>

Se você estiver curioso, pode inspecionar o código que o autógrafo gera.

 print(tf.autograph.to_code(f.python_function))
 
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Condicionais

O AutoGraph converterá algumas instruções if <condition> nas chamadas tf.cond equivalentes. Essa substituição é feita se <condition> for um tensor. Caso contrário, a instrução if é executada como uma condicional Python.

Uma condicional Python é executada durante o rastreamento, portanto, exatamente uma ramificação da condicional será adicionada ao gráfico. Sem o AutoGraph, este gráfico rastreado seria incapaz de assumir a ramificação alternativa se houver um fluxo de controle dependente de dados.

tf.cond rastreia e adiciona as duas ramificações da condicional ao gráfico, selecionando dinamicamente uma ramificação no tempo de execução. O rastreamento pode ter efeitos colaterais não intencionais; consulte Efeitos de rastreamento do AutoGraph para obter mais.

 @tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
 
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Consulte a documentação de referência para obter restrições adicionais sobre instruções if convertidas em AutoGraph.

rotações

O AutoGraph converterá algumas instruções for e while nas operações de loop equivalente do TensorFlow, como tf.while_loop . tf.while_loop . Se não for convertido, o for ou while laço é executado como um loop de Python.

Essa substituição é feita nas seguintes situações:

Um loop Python é executado durante o rastreamento, adicionando ops adicionais ao tf.Graph para cada iteração do loop.

Um loop TensorFlow rastreia o corpo do loop e seleciona dinamicamente quantas iterações executar no tempo de execução. O corpo do loop aparece apenas uma vez no tf.Graph . gerado.

Consulte a documentação de referência para obter restrições adicionais sobre as instruções for e while convertidas em AutoGraph.

Loop em dados Python

Uma armadilha comum é fazer um loop sobre dados Python / Numpy em uma função tf.function Esse loop será executado durante o processo de rastreamento, adicionando uma cópia do seu modelo ao tf.Graph para cada iteração do loop.

Se você deseja tf.function todo o loop de treinamento na função tf.function , a maneira mais segura de fazer isso é tf.data.Dataset seus dados como um tf.data.Dataset para que o AutoGraph desenrole dinamicamente o loop de treinamento.

 def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
 
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Ao agrupar dados Python / Numpy em um conjunto de dados, tf.data.Dataset.from_generator de tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensors . O primeiro manterá os dados em Python e os buscará através da função tf.py_function que pode ter implicações no desempenho, enquanto o último tf.constant() uma cópia dos dados como um nó tf.constant() grande no gráfico, o que pode ter implicações na memória.

Lendo dados de arquivos via TFRecordDataset / CsvDataset / etc. é a maneira mais eficaz de consumir dados, pois o próprio TensorFlow pode gerenciar o carregamento e a pré-busca assíncrona de dados, sem precisar envolver o Python. Para saber mais, consulte o guia tf.data .

Acumulando valores em um loop

Um padrão comum é acumular valores intermediários de um loop. Normalmente, isso é feito anexando-se a uma lista Python ou adicionando entradas a um dicionário Python. No entanto, como esses são efeitos colaterais do Python, eles não funcionarão conforme o esperado em um loop desenrolado dinamicamente. Use tf.TensorArray para acumular resultados de um loop desenrolado dinamicamente.

 batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
 
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.2486304 , 0.0612042 , 0.69624186, 0.28587592],
        [1.2193475 , 0.2389338 , 1.5216837 , 0.38649392],
        [1.7640524 , 1.1970762 , 2.3265643 , 0.81419575]],

       [[0.36599267, 0.41830885, 0.73540664, 0.63987565],
        [0.48354673, 1.1808103 , 1.7210082 , 0.8333106 ],
        [0.7138835 , 1.2030114 , 1.8544207 , 1.1647347 ]]], dtype=float32)>

Leitura adicional

Para saber como exportar e carregar uma Function , consulte o guia SavedModel . Para saber mais sobre otimizações de gráficos que são executadas após o rastreamento, consulte o guia Grappler . Para saber como otimizar seu pipeline de dados e criar um perfil de seu modelo, consulte o guia Profiler .