Esta página foi traduzida pela API Cloud Translation.
Switch to English

tf.function

Versão TensorFlow 1 Ver fonte no GitHub

Compila uma função em um gráfico TensorFlow que pode ser chamado.

Usado nos notebooks

Utilizado no guia Usado nos tutoriais

tf.function constrói uma chamada que executa um gráfico TensorFlow ( tf.Graph ) criado pela compilação de rastreamento das operações TensorFlow em func , executando efetivamente func como um gráfico TensorFlow.

Exemplo de uso:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>

Recursos

func pode usar o fluxo de controle dependente de dados, incluindo instruções if , for , while break , continue e return :

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>

func encerramento 's podem incluir tf.Tensor e tf.Variable objetos:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func também pode usar ops com efeitos colaterais, como tf.print , tf.Variable e outros:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

Em vez disso, use coleções tf.TensorArray como tf.TensorArray :

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

tf.function é polimórfica

Internamente, o tf.function pode criar mais de um gráfico, para suportar argumentos com diferentes tipos ou formas de dados, pois o TensorFlow pode criar gráficos mais eficientes, especializados em formas e tipos. tf.function também trata qualquer valor puro do Python como objetos opacos e cria um gráfico separado para cada conjunto de argumentos do Python que encontrar.

Para obter um gráfico individual, use o método get_concrete_function da chamada criada por tf.function . Ele pode ser chamado com os mesmos argumentos que func e retorna um objeto tf.Graph especial:

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
@tf.function
def f(x):
  return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)  # Slow - builds new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
f1 is f2
True

Os argumentos numéricos do Python devem ser usados ​​apenas quando eles recebem poucos valores distintos, como hiperparâmetros, como o número de camadas em uma rede neural.

Assinaturas de entrada

Para argumentos do tensor, tf.function instancia um gráfico separado para cada conjunto exclusivo de formas e tipos de dados de entrada. O exemplo abaixo cria dois gráficos separados, cada um especializado em uma forma diferente:

@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False

Uma "assinatura de entrada" pode ser opcionalmente fornecida para tf.function para controlar os gráficos rastreados. A assinatura de entrada especifica a forma e o tipo de cada argumento Tensor da função usando um objeto tf.TensorSpec . Formas mais gerais podem ser usadas. Isso é útil para evitar a criação de vários gráficos quando os tensores tiverem formas dinâmicas. Também restringe a forma e o tipo de dados dos tensores que podem ser usados:

@tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True

Variáveis ​​podem ser criadas apenas uma vez

tf.function só permite criar novos objetos