Weź udział w sympozjum Women in ML 7 grudnia Zarejestruj się teraz

Lepsza wydajność dzięki funkcji tf.

Zadbaj o dobrą organizację dzięki kolekcji Zapisuj i kategoryzuj treści zgodnie ze swoimi preferencjami.

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHubPobierz notatnik

W TensorFlow 2 szybkie wykonywanie jest domyślnie włączone. Interfejs użytkownika jest intuicyjny i elastyczny (wykonywanie jednorazowych operacji jest znacznie łatwiejsze i szybsze), ale może to nastąpić kosztem wydajności i możliwości wdrożenia.

Możesz użyć tf.function do tworzenia wykresów ze swoich programów. Jest to narzędzie do transformacji, które tworzy niezależne od Pythona wykresy przepływu danych z kodu Pythona. Pomoże to w tworzeniu wydajnych i przenośnych modeli i jest wymagane do korzystania z SavedModel .

Ten przewodnik pomoże Ci zrozumieć, jak tf.function działa pod maską, abyś mógł z niego efektywnie korzystać.

Główne dania na wynos i rekomendacje to:

  • Debuguj w trybie przyspieszonym, a następnie udekoruj za pomocą @tf.function .
  • Nie polegaj na efektach ubocznych Pythona, takich jak mutacje obiektów lub dołączanie list.
  • tf.function działa najlepiej z operacjami TensorFlow; Wywołania NumPy i Python są konwertowane na stałe.

Ustawiać

import tensorflow as tf

Zdefiniuj funkcję pomocniczą, aby zademonstrować rodzaje błędów, które możesz napotkać:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Podstawy

Stosowanie

Function , którą definiujesz (na przykład przez zastosowanie dekoratora @tf.function ) jest jak podstawowa operacja TensorFlow: możesz ją wykonać z zapałem; możesz obliczyć gradienty; i tak dalej.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Możesz używać Function wewnątrz innych Function .

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s może być szybsza niż gorliwy kod, szczególnie w przypadku wykresów z wieloma małymi operacjami. Ale w przypadku wykresów z kilkoma kosztownymi operacjami (takich jak sploty) możesz nie zauważyć dużego przyspieszenia.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

Rysunek kalkowy

W tej sekcji opisano, jak Function działa pod maską, w tym szczegóły implementacji, które mogą ulec zmianie w przyszłości . Jednak po zrozumieniu, dlaczego i kiedy dzieje się śledzenie, znacznie łatwiej jest efektywnie korzystać z tf.function !

Co to jest „śledzenie”?

Function uruchamia program na wykresie TensorFlow . Jednak tf.Graph nie może reprezentować wszystkich rzeczy, które napisałbyś w chętnym programie TensorFlow. Na przykład Python obsługuje polimorfizm, ale tf.Graph wymaga, aby dane wejściowe miały określony typ danych i wymiar. Możesz też wykonywać zadania poboczne, takie jak odczytywanie argumentów wiersza poleceń, zgłaszanie błędów lub praca z bardziej złożonym obiektem Pythona; żadna z tych rzeczy nie może działać w tf.Graph .

Function wypełnia tę lukę, dzieląc kod na dwa etapy:

1) W pierwszym etapie, określanym jako „ śledzenie ”, Function tworzy nowy tf.Graph . Kod Pythona działa normalnie, ale wszystkie operacje TensorFlow (takie jak dodawanie dwóch Tensorów) są odroczone : są przechwytywane przez tf.Graph i nie są uruchamiane.

2) W drugim etapie uruchamiany jest tf.Graph , który zawiera wszystko, co zostało odroczone w pierwszym etapie. Ten etap jest znacznie szybszy niż etap śledzenia.

W zależności od danych wejściowych Function nie zawsze uruchomi pierwszy etap po wywołaniu. Zobacz „Zasady śledzenia” poniżej, aby lepiej zrozumieć, w jaki sposób dokonuje tego ustalenia. Pominięcie pierwszego etapu i wykonanie tylko drugiego etapu zapewnia wysoką wydajność TensorFlow.

Gdy Function zdecyduje się na śledzenie, po etapie śledzenia następuje natychmiast drugi etap, więc wywołanie Function zarówno tworzy, jak i uruchamia tf.Graph . Później zobaczysz, jak możesz uruchomić tylko etap śledzenia za pomocą get_concrete_function .

Kiedy przekazujesz argumenty różnych typów do Function , uruchamiane są oba etapy:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Zauważ, że jeśli wielokrotnie wywołasz Function z tym samym typem argumentu, TensorFlow pominie etap śledzenia i ponownie użyje wcześniej wyśledzonego wykresu, ponieważ wygenerowany wykres byłby identyczny.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Możesz użyć pretty_printed_concrete_signatures() , aby zobaczyć wszystkie dostępne ślady:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Jak dotąd widziałeś, że tf.function tworzy buforowaną, dynamiczną warstwę wysyłania nad logiką śledzenia wykresów TensorFlow. Aby być bardziej szczegółowym o terminologii:

  • tf.Graph to surowa, niezależna od języka, przenośna reprezentacja obliczeń TensorFlow.
  • ConcreteFunction otacza tf.Graph .
  • Function zarządza pamięcią podręczną ConcreteFunction s i wybiera właściwą dla danych wejściowych.
  • tf.function otacza funkcję Pythona, zwracając obiekt Function .
  • Tracing tworzy tf.Graph i zawija go w ConcreteFunction , znany również jako ślad.

Zasady śledzenia

Function określa, czy ponownie użyć śledzonej funkcji ConcreteFunction , obliczając klucz pamięci podręcznej z args i kwargs danych wejściowych. Klucz pamięci podręcznej to klucz, który identyfikuje funkcję ConcreteFunction na podstawie argumentów wejściowych i kwarg wywołania Function , zgodnie z następującymi regułami (które mogą się zmieniać):

  • Kluczem wygenerowanym dla tf.Tensor jest jego kształt i dtype.
  • Klucz wygenerowany dla tf.Variable to unikalny identyfikator zmiennej.
  • Kluczem wygenerowanym dla prymitywu Pythona (takiego jak int , float , str ) jest jego wartość.
  • Klucz wygenerowany dla zagnieżdżonych dict s, list s, tuple s, namedtuple s i attr s jest spłaszczoną krotką kluczy-liści (patrz nest.flatten ). (W wyniku tego spłaszczenia wywołanie konkretnej funkcji z inną strukturą zagnieżdżenia niż ta użyta podczas śledzenia spowoduje wystąpienie TypeError).
  • Dla wszystkich innych typów Pythona klucz jest unikalny dla obiektu. W ten sposób funkcja lub metoda jest śledzona niezależnie dla każdej instancji, z którą jest wywoływana.

Kontrolowanie cofania

Retracing, czyli gdy Function tworzy więcej niż jeden ślad, pomaga zapewnić, że TensorFlow generuje poprawne wykresy dla każdego zestawu danych wejściowych. Jednak śledzenie to kosztowna operacja! Jeśli Function odtworzy nowy wykres dla każdego wywołania, przekonasz się, że Twój kod jest wykonywany wolniej, niż gdybyś nie używał tf.function .

Aby kontrolować zachowanie śledzenia, możesz użyć następujących technik:

  • Określ input_signature w tf.function , aby ograniczyć śledzenie.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • Określ wymiar [Brak] w tf.TensorSpec , aby zapewnić elastyczność ponownego wykorzystania śledzenia.

    Ponieważ TensorFlow dopasowuje tensory na podstawie ich kształtu, użycie wymiaru None jako symbolu wieloznacznego umożliwi Function s ponowne wykorzystanie śladów do wprowadzania danych o zmiennej wielkości. Dane wejściowe o zmiennej wielkości mogą wystąpić, jeśli masz sekwencje o różnej długości lub obrazy o różnych rozmiarach dla każdej partii (zobacz na przykład samouczki Transformer i Deep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • Przesyłaj argumenty Pythona do tensorów, aby ograniczyć cofanie się.

    Często argumenty Pythona są używane do kontrolowania hiperparametrów i konstrukcji wykresów — na przykład num_layers=10 lub training=True lub nonlinearity='relu' . Tak więc, jeśli argument Pythona ulegnie zmianie, sensowne jest ponowne prześledzenie wykresu.

    Jednak jest możliwe, że argument Pythona nie jest używany do kontrolowania konstrukcji grafu. W takich przypadkach zmiana wartości Pythona może wywołać niepotrzebne cofanie. Weźmy na przykład tę pętlę treningową, którą AutoGraph będzie dynamicznie rozwijał. Pomimo wielu śladów wygenerowany wykres jest w rzeczywistości identyczny, więc ponowne śledzenie nie jest konieczne.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Jeśli chcesz wymusić odtworzenie, utwórz nową Function . Oddzielne obiekty Function gwarantują, że nie będą udostępniać śladów.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Uzyskanie konkretnych funkcji

Za każdym razem, gdy funkcja jest śledzona, tworzona jest nowa konkretna funkcja. Możesz bezpośrednio uzyskać konkretną funkcję, używając get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

Wydrukowanie ConcreteFunction wyświetla podsumowanie jej argumentów wejściowych (z typami) i typu wyjściowego.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Możesz również bezpośrednio pobrać sygnaturę konkretnej funkcji.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Użycie konkretnego śladu z niekompatybilnymi typami spowoduje błąd

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Możesz zauważyć, że argumenty Pythona są traktowane w specjalny sposób w sygnaturze wejściowej konkretnej funkcji. Przed TensorFlow 2.3 argumenty Pythona były po prostu usuwane z podpisu konkretnej funkcji. Począwszy od TensorFlow 2.3, argumenty Pythona pozostają w sygnaturze, ale są ograniczone do przyjęcia wartości ustawionej podczas śledzenia.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

Uzyskiwanie wykresów

Każda konkretna funkcja jest wywoływanym opakowaniem wokół tf.Graph . Chociaż pobieranie rzeczywistego obiektu tf.Graph nie jest czymś, co normalnie trzeba robić, można go łatwo uzyskać z dowolnej konkretnej funkcji.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Debugowanie

Ogólnie rzecz biorąc, debugowanie kodu jest łatwiejsze w trybie przyspieszonym niż w tf.function . Powinieneś upewnić się, że twój kod wykonuje się bez błędów w trybie przyspieszonym przed dekorowaniem za pomocą tf.function . Aby pomóc w procesie debugowania, możesz wywołać tf.config.run_functions_eagerly(True) , aby globalnie wyłączyć i ponownie tf.function .

Oto kilka wskazówek dotyczących śledzenia problemów, które pojawiają się tylko w tf.function :

  • Zwykłe, stare wywołania print w Pythonie są wykonywane tylko podczas śledzenia, pomagając ci wyśledzić, kiedy twoja funkcja zostanie (ponownie) śledzona.
  • tf.print będą wykonywane za każdym razem i mogą pomóc w śledzeniu wartości pośrednich podczas wykonywania.
  • tf.debugging.enable_check_numerics to łatwy sposób na wyśledzenie, gdzie są tworzone NaN i Inf.
  • pdb ( debugger Pythona ) może pomóc w zrozumieniu, co się dzieje podczas śledzenia. (Zastrzeżenie: pdb wrzuci Cię do kodu źródłowego przekształconego w AutoGraph).

Transformacje autografu

AutoGraph to biblioteka, która jest domyślnie włączona w tf.function i przekształca podzbiór gorliwego kodu Pythona w operacje TensorFlow zgodne z wykresami. Obejmuje to sterowanie przepływem, takie jak if , for , while .

Operacje TensorFlow, takie jak tf.cond i tf.while_loop , nadal działają, ale przepływ sterowania jest często łatwiejszy do napisania i zrozumienia w Pythonie.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

Jeśli jesteś ciekawy, możesz sprawdzić kod generowany przez autograf.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

Warunkowe

AutoGraph przekonwertuje niektóre instrukcje if <condition> na równoważne wywołania tf.cond . To podstawienie jest wykonywane, jeśli <condition> jest tensorem. W przeciwnym razie instrukcja if jest wykonywana jako warunek warunkowy Pythona.

Warunek Pythona jest wykonywany podczas śledzenia, więc do wykresu zostanie dodana dokładnie jedna gałąź warunku. Bez AutoGraph ten śledzony wykres nie byłby w stanie wykonać alternatywnej gałęzi, jeśli istnieje przepływ sterowania zależny od danych.

tf.cond śledzi i dodaje obie gałęzie warunku do grafu, dynamicznie wybierając gałąź w czasie wykonania. Śledzenie może mieć niezamierzone skutki uboczne; sprawdź efekty śledzenia Autografu, aby uzyskać więcej informacji.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Zapoznaj się z dokumentacją referencyjną, aby uzyskać dodatkowe ograniczenia dotyczące instrukcji if konwertowanych za pomocą funkcji AutoGraph.

Pętle

AutoGraph przekonwertuje niektóre instrukcje for i while na równoważne operacje pętli TensorFlow, takie jak tf.while_loop . Jeśli nie zostanie przekonwertowana, pętla for lub while jest wykonywana jako pętla Pythona.

Ta zamiana jest dokonywana w następujących sytuacjach:

  • for x in y : jeśli y jest tensorem, przekonwertuj na tf.while_loop . W szczególnym przypadku, gdy y jest tf.data.Dataset , generowana jest kombinacja operacji tf.data.Dataset .
  • while <condition> : jeśli <condition> jest tensorem, przekonwertuj na tf.while_loop .

Pętla Pythona jest wykonywana podczas śledzenia, dodając dodatkowe operacje do tf.Graph dla każdej iteracji pętli.

Pętla TensorFlow śledzi treść pętli i dynamicznie wybiera liczbę iteracji do uruchomienia w czasie wykonywania. Treść pętli pojawia się tylko raz w wygenerowanym tf.Graph .

Zapoznaj się z dokumentacją referencyjną, aby uzyskać dodatkowe ograniczenia dotyczące instrukcji konwertowanych przez AutoGraph for i while .

Zapętlanie danych Pythona

Częstą pułapką jest zapętlenie danych Pythona/NumPy w tf.function . Ta pętla zostanie wykonana podczas procesu śledzenia, dodając kopię modelu do tf.Graph dla każdej iteracji pętli.

Jeśli chcesz opakować całą pętlę uczącą w tf.function , najbezpieczniejszym sposobem jest zawinięcie danych jako tf.data.Dataset , tak aby AutoGraph dynamicznie rozwijał pętlę uczącą.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

Podczas pakowania danych Pythona/NumPy w zestaw danych należy pamiętać o tf.data.Dataset.from_generator i tf.data.Dataset.from_tensors . Pierwsza z nich zachowa dane w Pythonie i pobierze je za pomocą tf.py_function , co może mieć wpływ na wydajność, podczas gdy druga utworzy kopię danych jako jeden duży węzeł tf.constant() na wykresie, co może mieć wpływ na pamięć.

Odczytywanie danych z plików za pośrednictwem TFRecordDataset , CsvDataset itp. jest najskuteczniejszym sposobem na wykorzystanie danych, ponieważ sam TensorFlow może zarządzać asynchronicznym ładowaniem i wstępnym pobieraniem danych, bez konieczności angażowania Pythona. Aby dowiedzieć się więcej, zobacz przewodnik tf.data : Build TensorFlow input pipelines guide.

Kumulacja wartości w pętli

Powszechnym wzorcem jest akumulowanie wartości pośrednich z pętli. Zwykle osiąga się to poprzez dodawanie do listy Pythona lub dodawanie wpisów do słownika Pythona. Jednakże, ponieważ są to efekty uboczne Pythona, nie będą działać zgodnie z oczekiwaniami w dynamicznie rozwijanej pętli. Użyj tf.TensorArray , aby zgromadzić wyniki z dynamicznie rozwijanej pętli.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

Ograniczenia

Function TensorFlow ma kilka ograniczeń projektowych, o których należy pamiętać podczas konwertowania funkcji Pythona na Function .

Wywoływanie efektów ubocznych Pythona

Efekty uboczne, takie jak drukowanie, dołączanie do list i mutowanie globalnych, mogą zachowywać się nieoczekiwanie wewnątrz Function , czasami wykonując dwa razy lub nie wszystkie. Zdarzają się tylko przy pierwszym wywołaniu Function z zestawem danych wejściowych. Następnie wyśledzony tf.Graph jest ponownie wykonywany, bez wykonywania kodu Pythona.

Ogólną zasadą jest unikanie polegania na efektach ubocznych Pythona w swojej logice i używanie ich tylko do debugowania śladów. W przeciwnym razie interfejsy API TensorFlow, takie jak tf.data , tf.print , tf.summary , tf.Variable.assign i tf.TensorArray , to najlepszy sposób na zapewnienie wykonania kodu przez środowisko wykonawcze TensorFlow przy każdym wywołaniu.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Jeśli chcesz wykonać kod Pythona podczas każdego wywołania Function , tf.py_function jest kreskowaniem wyjścia. Wadą tf.py_function jest to, że nie jest przenośny ani szczególnie wydajny, nie można go zapisać za pomocą SavedModel i nie działa dobrze w konfiguracjach rozproszonych (wiele GPU, TPU). Ponadto, ponieważ tf.py_function musi być podłączone do grafu, rzutuje wszystkie wejścia/wyjścia na tensory.

Zmiana globalnych i wolnych zmiennych Pythona

Zmiana globalnych i wolnych zmiennych Pythona liczy się jako efekt uboczny Pythona, więc dzieje się to tylko podczas śledzenia.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Czasami nieoczekiwane zachowania są bardzo trudne do zauważenia. W poniższym przykładzie counter ma na celu zabezpieczenie przyrostu zmiennej. Jednak ponieważ jest to liczba całkowita Pythona, a nie obiekt TensorFlow, jego wartość jest przechwytywana podczas pierwszego śledzenia. Gdy tf.function jest funkcja tf., parametr assign_add zostanie zarejestrowany bezwarunkowo na wykresie bazowym. Dlatego v będzie wzrastać o 1 za każdym razem, gdy tf.function jest funkcja tf. Ten problem jest powszechny wśród użytkowników, którzy próbują migrować swój kod Tensorflow w trybie Grpah do Tensorflow 2 za pomocą dekoratorów tf.function , gdy efekty uboczne Pythona ( counter w przykładzie) są używane do określenia, które operacje mają zostać uruchomione (w przykładzie assign_add ). Zwykle użytkownicy zdają sobie z tego sprawę dopiero po zobaczeniu podejrzanych wyników liczbowych lub znacznie niższej wydajności niż oczekiwano (np. jeśli strzeżona operacja jest bardzo kosztowna).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

Obejściem pozwalającym na osiągnięcie oczekiwanego zachowania jest użycie tf.init_scope do usunięcia operacji poza wykres funkcji. Gwarantuje to, że przyrost zmiennej jest wykonywany tylko raz w czasie śledzenia. Należy zauważyć, że init_scope ma inne skutki uboczne, w tym wyczyszczony przepływ sterowania i taśmę gradientu. Czasami użycie init_scope może stać się zbyt skomplikowane, aby zarządzać realistycznie.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

Podsumowując, należy unikać mutowania obiektów Pythona, takich jak liczby całkowite lub kontenery, takie jak listy, które znajdują się poza Function . Zamiast tego użyj argumentów i obiektów TF. Na przykład sekcja „Akumulacja wartości w pętli” zawiera jeden przykład implementacji operacji podobnych do listy.

W niektórych przypadkach możesz przechwytywać i manipulować stanem, jeśli jest to tf.Variable . W ten sposób wagi modeli Keras są aktualizowane za pomocą powtarzających się wywołań tej samej funkcji ConcreteFunction .

Korzystanie z iteratorów i generatorów Pythona

Wiele funkcji Pythona, takich jak generatory i iteratory, opiera się na środowisku wykonawczym Pythona do śledzenia stanu. Ogólnie rzecz biorąc, chociaż te konstrukcje działają zgodnie z oczekiwaniami w trybie przyspieszonym, są przykładami efektów ubocznych Pythona i dlatego występują tylko podczas śledzenia.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Podobnie jak TensorFlow ma wyspecjalizowany tf.TensorArray dla konstrukcji list, ma wyspecjalizowany tf.data.Iterator dla konstrukcji iteracyjnych. Zapoznaj się z sekcją dotyczącą przekształceń Autografu , aby zapoznać się z omówieniem. Ponadto tf.data API może pomóc w implementacji wzorców generatora:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

Wszystkie dane wyjściowe funkcji tf.function muszą być wartościami zwracanymi

Z wyjątkiem tf.Variable s, tf.function musi zwrócić wszystkie swoje wyjścia. Próba bezpośredniego dostępu do dowolnych tensorów z funkcji bez przechodzenia przez zwracane wartości powoduje „przecieki”.

Na przykład poniższa funkcja "przecieka" tensor a przez globalny x Pythona:

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

Dzieje się tak, nawet jeśli zwracana jest również wartość, która wyciekła:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

Zwykle takie przecieki występują, gdy używasz instrukcji lub struktur danych Pythona. Oprócz wycieku niedostępnych tensorów, takie instrukcje są również prawdopodobnie błędne, ponieważ liczą się jako efekty uboczne Pythona i nie są gwarantowane do wykonania przy każdym wywołaniu funkcji.

Typowe sposoby wyciekania lokalnych tensorów obejmują również mutowanie zewnętrznej kolekcji Pythona lub obiektu:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

Rekurencyjne funkcje tf. nie są obsługiwane

Function rekurencyjne nie są obsługiwane i mogą powodować nieskończone pętle. Na przykład,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

Nawet jeśli Function rekurencyjna wydaje się działać, funkcja Pythona będzie śledzona wiele razy i może mieć wpływ na wydajność. Na przykład,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

Znane problemy

Jeśli twoja Function nie ocenia poprawnie, błąd może być wyjaśniony przez te znane problemy, które planuje się naprawić w przyszłości.

W zależności od zmiennych globalnych i wolnych Pythona

Function tworzy nową funkcję ConcreteFunction po wywołaniu z nową wartością argumentu Pythona. Jednak nie robi tego w przypadku zamknięcia Pythona, zmiennych globalnych lub nielokalnych tej Function . Jeśli ich wartość zmienia się między wywołaniami Function , Function nadal będzie używać wartości, które miały podczas śledzenia. Różni się to od działania zwykłych funkcji Pythona.

Z tego powodu powinieneś postępować zgodnie z funkcjonalnym stylem programowania, który używa argumentów zamiast zamykać nazwy zewnętrzne.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Innym sposobem aktualizacji wartości globalnej jest uczynienie jej tf.Variable i użycie zamiast niej metody Variable.assign .

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

W zależności od obiektów Pythona

Rekomendacja przekazywania obiektów Pythona jako argumentów do tf.function zawiera szereg znanych problemów, które powinny zostać naprawione w przyszłości. Ogólnie rzecz biorąc, możesz polegać na spójnym śledzeniu, jeśli jako argument lub przekażesz inną instancję obiektu do Function , użyj struktury podstawowej Pythona lub struktury zgodnej z tf.nest . Jednak Function nie utworzy nowego śladu, gdy przekażesz ten sam obiekt i zmienisz tylko jego atrybuty .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

Użycie tej samej Function do oceny zaktualizowanego wystąpienia modelu będzie błędne, ponieważ zaktualizowany model ma ten sam klucz pamięci podręcznej co model oryginalny.

Z tego powodu zaleca się napisanie Function , aby uniknąć zależności od zmiennych atrybutów obiektu lub tworzenia nowych obiektów.

Jeśli nie jest to możliwe, jednym obejściem jest tworzenie nowych Function za każdym razem, gdy modyfikujesz swój obiekt, aby wymusić odtworzenie:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Ponieważ powrót może być kosztowny , możesz użyć tf.Variable s jako atrybutów obiektu, które można mutować (ale nie zmieniać, uważaj!) w celu uzyskania podobnego efektu bez konieczności powrotu.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Tworzenie tf.Zmiennych

Function obsługuje tylko pojedyncze tf.Variable s utworzone raz przy pierwszym wywołaniu i ponownie użyte w kolejnych wywołaniach funkcji. Poniższy fragment kodu spowoduje utworzenie nowej tf.Variable w każdym wywołaniu funkcji, co spowoduje wystąpienie wyjątku ValueError .

Przykład:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Typowym wzorcem używanym do obejścia tego ograniczenia jest rozpoczęcie od wartości None w Pythonie, a następnie warunkowe utworzenie tf.Variable , jeśli wartość to None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Używanie z wieloma optymalizatorami Keras

Możesz napotkać ValueError: tf.function only supports singleton tf.Variables created on the first call. gdy używasz więcej niż jednego optymalizatora Keras z tf.function . Ten błąd występuje, ponieważ optymalizatory wewnętrznie tworzą tf.Variables , gdy stosują gradienty po raz pierwszy.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Jeśli musisz zmienić optymalizator podczas uczenia, obejściem tego problemu jest utworzenie nowej Function dla każdego optymalizatora, wywołując bezpośrednio funkcję ConcreteFunction .

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

Używanie z wieloma modelami Keras

Możesz również napotkać ValueError: tf.function only supports singleton tf.Variables created on the first call. podczas przekazywania różnych instancji modelu do tego samego Function .

Ten błąd występuje, ponieważ modele Keras (które nie mają zdefiniowanego kształtu wejściowego ) i warstwy Keras tworzą tf.Variables przy pierwszym wywołaniu. Być może próbujesz zainicjować te zmienne wewnątrz Function , który został już wywołany. Aby uniknąć tego błędu, spróbuj wywołać model.build(input_shape) , aby zainicjować wszystkie wagi przed uczeniem modelu.

Dalsza lektura

Aby dowiedzieć się, jak eksportować i ładować Function , zobacz przewodnik SavedModel . Aby dowiedzieć się więcej o optymalizacjach wykresów wykonywanych po śledzeniu, zapoznaj się z przewodnikiem Grappler . Aby dowiedzieć się, jak zoptymalizować potok danych i profilować model, zapoznaj się z przewodnikiem Profiler .