Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

tf.function

Wersja TensorFlow 1 Wyświetl źródło na GitHub

Kompiluje funkcję w wywoływalny wykres TensorFlow.

Używany w notebookach

Używany w przewodniku Używany w samouczkach

tf.function konstruuje wywoływaną funkcję, która wykonuje wykres TensorFlow ( tf.Graph ) utworzony przez kompilację śledzenia operacji TensorFlow w func , skutecznie wykonując func jako wykres TensorFlow.

Przykładowe użycie:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>

cechy

func może używać przepływu sterowania zależnego od danych, w tym instrukcji if , for , while break , continue i return :

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>

Zamknięcie func może zawierać obiekty tf.Tensor i tf.Variable :

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func może również używać operacji z efektami ubocznymi, takimi jak tf.print , tf.Variable i inne:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

Zamiast tego użyj kolekcji tf.TensorArray takich jak tf.TensorArray :

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

tf.function jest polimorficzny

Wewnętrznie tf.function może budować więcej niż jeden wykres, aby wspierać argumenty z różnymi typami danych lub kształtami, ponieważ TensorFlow może tworzyć bardziej wydajne wykresy, które są wyspecjalizowane w kształtach i dtypach. tf.function również traktuje każdą czystą wartość Pythona jako nieprzezroczyste obiekty i tworzy osobny wykres dla każdego napotkanego zestawu argumentów Pythona.

Aby uzyskać indywidualny wykres, użyj metody get_concrete_function obiektu wywoływanego utworzonego przez tf.function . Można go wywołać z tymi samymi argumentami, co func i zwraca specjalny obiekt tf.Graph :

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
@tf.function
def f(x):
  return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)  # Slow - builds new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
f1 is f2
True

Argumenty numeryczne Pythona powinny być używane tylko wtedy, gdy przyjmują kilka różnych wartości, takich jak hiperparametry, takie jak liczba warstw w sieci neuronowej.

Podpisy wejściowe

W przypadku argumentów Tensor tf.function instancję oddzielnego wykresu dla każdego unikalnego zestawu kształtów wejściowych i typów danych. Poniższy przykład tworzy dwa oddzielne wykresy, z których każdy ma inny kształt:

@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(ma