Bessere Leistung mit tf.function

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

In TensorFlow 2 ist die Eager-Ausführung standardmäßig aktiviert. Die Benutzeroberfläche ist intuitiv und flexibel (die Ausführung einmaliger Vorgänge ist viel einfacher und schneller), aber dies kann zu Lasten der Leistung und Bereitstellung gehen.

Mit tf.function können tf.function aus Ihren Programmen Grafiken tf.function . Es ist ein Transformationstool, das Python-unabhängige Datenflussdiagramme aus Ihrem Python-Code erstellt. Dies hilft Ihnen beim Erstellen leistungsfähiger und tragbarer Modelle, und es ist erforderlich, SavedModel zu verwenden.

Dieser Leitfaden wird Ihnen helfen, sich zu tf.function wie tf.function unter der Haube funktioniert, damit Sie es effektiv nutzen können.

Die wichtigsten Erkenntnisse und Empfehlungen sind:

  • Debuggen Sie im Eager-Modus und dekorieren Sie dann mit @tf.function .
  • Verlassen Sie sich nicht auf Python-Nebeneffekte wie Objektmutation oder Listenanhänge.
  • tf.function funktioniert am besten mit TensorFlow-Operationen; NumPy- und Python-Aufrufe werden in Konstanten umgewandelt.

Einrichten

import tensorflow as tf

Definieren Sie eine Hilfsfunktion, um die möglichen Fehlerarten zu demonstrieren:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Grundlagen

Verwendung

Eine Function Sie definieren (zum Beispiel durch Anwenden des Dekorators @tf.function ) ist wie eine @tf.function TensorFlow: Sie können sie eifrig ausführen; Sie können Gradienten berechnen; und so weiter.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Sie können Function s innerhalb anderer Function s verwenden.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s kann schneller sein als Eager-Code, insbesondere bei Graphen mit vielen kleinen Operationen. Aber bei Graphen mit ein paar teuren Operationen (wie Faltungen) sehen Sie möglicherweise nicht viel Geschwindigkeit.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.23302616399996623
Function conv: 0.21780993200013654
Note how there's not much difference in performance for convolutions

Verfolgung

In diesem Abschnitt wird erläutert, wie Function unter der Haube Function , einschließlich Implementierungsdetails, die sich in Zukunft ändern können . Wenn Sie jedoch verstehen, warum und wann das Tracing stattfindet, ist es viel einfacher, tf.function effektiv zu verwenden!

Was ist "Nachverfolgung"?

Eine Function führt Ihr Programm in einem TensorFlow-Graph aus . Ein tf.Graph kann jedoch nicht alle Dinge darstellen, die Sie in ein eifriges TensorFlow-Programm schreiben würden. Python unterstützt beispielsweise Polymorphismus, aber tf.Graph erfordert, dass seine Eingaben einen bestimmten Datentyp und eine bestimmte Dimension haben. Oder Sie können Nebenaufgaben wie das Lesen von Befehlszeilenargumenten, das Auslösen eines Fehlers oder das Arbeiten mit einem komplexeren Python-Objekt ausführen; keines dieser Dinge kann in einem tf.Graph .

Function schließt diese Lücke, indem sie Ihren Code in zwei Phasen aufteilt:

1) In der ersten Phase, die als " Tracing " bezeichnet wird, erstellt Function einen neuen tf.Graph . Python - Code läuft normal, aber alle TensorFlow Operationen (wie das Hinzufügen von zwei Tensoren) werden abgegrenzt: sie durch die erfasst werden tf.Graph und nicht ausgeführt.

2) In der zweiten Stufe wird ein tf.Graph der alles enthält, was in der ersten Stufe zurückgestellt wurde. Diese Phase ist viel schneller als die Verfolgungsphase.

Abhängig von seinen Eingaben wird Function nicht immer die erste Stufe ausführen, wenn sie aufgerufen wird. Siehe "Regeln für die Rückverfolgung" unten, um ein besseres Gefühl dafür zu bekommen, wie diese Bestimmung getroffen wird. Wenn Sie die erste Stufe überspringen und nur die zweite Stufe ausführen, erhalten Sie die hohe Leistung von TensorFlow.

Wenn Function beschließt, tf.Graph die tf.Graph unmittelbar die zweite Phase, sodass beim Aufrufen der Function tf.Graph erstellt und tf.Graph . Später werden Sie sehen, wie Sie mit get_concrete_function nur die Tracing-Phase get_concrete_function .

Wenn wir Argumente unterschiedlichen Typs an eine Function , werden beide Phasen ausgeführt:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Beachten Sie, dass TensorFlow beim wiederholten Aufrufen einer Function mit demselben Argumenttyp die Verfolgungsphase überspringt und einen zuvor verfolgten Graphen wiederverwendet, da der generierte Graph identisch wäre.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Sie können pretty_printed_concrete_signatures() , um alle verfügbaren Spuren pretty_printed_concrete_signatures() :

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Bisher haben Sie gesehen, dass tf.function eine zwischengespeicherte, dynamische Dispatch-Schicht über der Graphenverfolgungslogik von TensorFlow erstellt. Um die Terminologie genauer zu beschreiben:

  • Ein tf.Graph ist die rohe, tf.Graph , portable Darstellung einer TensorFlow-Berechnung.
  • Eine ConcreteFunction tf.Graph eine tf.Graph .
  • Eine Function verwaltet einen Cache von ConcreteFunction s und wählt den richtigen für Ihre Eingaben aus.
  • tf.function eine Python-Funktion und gibt ein Function Objekt zurück.
  • Tracing erstellt einen tf.Graph und verpackt ihn in eine ConcreteFunction , auch als Trace bezeichnet.

Regeln für die Rückverfolgung

Eine Function bestimmt, ob eine verfolgte ConcreteFunction wiederverwendet werden soll, indem sie einen Cache-Schlüssel aus den args und kwargs einer Eingabe berechnet. Ein Cache-Schlüssel ist ein Schlüssel, der eine ConcreteFunction basierend auf den Eingabeargumenten und -kwargs des Function Aufrufs gemäß den folgenden Regeln identifiziert (die sich ändern können):

  • Der für einen tf.Tensor generierte Schlüssel ist seine Form und sein dtype.
  • Der für eine tf.Variable generierte Schlüssel ist eine eindeutige Variablen-ID.
  • Der für ein Python-Primitiven (wie int , float , str ) generierte Schlüssel ist sein Wert.
  • Der Schlüssel erzeugt für verschachtelte dict s, list s, tuple s, namedtuple s, und attr s ist die abgeflachte Tupel von blattTasten (siehenest.flatten ). (Infolge dieser Reduzierung führt der Aufruf einer konkreten Funktion mit einer anderen Verschachtelungsstruktur als der beim Tracing verwendeten zu einem TypeError).
  • Bei allen anderen Python-Typen ist der Schlüssel für das Objekt eindeutig. Auf diese Weise wird eine Funktion oder Methode für jede Instanz, mit der sie aufgerufen wird, unabhängig verfolgt.

Rückverfolgung steuern

Die Rückverfolgung, bei der Ihre Function mehr als eine Spur erstellt, hilft sicherzustellen, dass TensorFlow für jeden Eingabesatz korrekte Diagramme generiert. Tracing ist jedoch eine teure Operation! Wenn Ihre Function bei jedem Aufruf einen neuen Graphen tf.function , werden Sie feststellen, dass Ihr Code langsamer ausgeführt wird, als wenn Sie tf.function nicht verwenden.

Um das Ablaufverfolgungsverhalten zu steuern, können Sie die folgenden Techniken verwenden:

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-14ebce7b7ee8>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-14ebce7b7ee8>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
  • Geben Sie eine [Keine] tf.TensorSpec in tf.TensorSpec an, um Flexibilität bei der Wiederverwendung von tf.TensorSpec zu ermöglichen.

    Da TensorFlow Tensoren basierend auf ihrer Form abgleicht, ermöglicht die Verwendung einer None Dimension als Platzhalter der Function s, Spuren für Eingaben unterschiedlicher Größe wiederzuverwenden. Eingaben unterschiedlicher Größe können auftreten, wenn Sie Sequenzen unterschiedlicher Länge oder Bilder unterschiedlicher Größe für jeden Stapel haben (siehe zum Beispiel die Tutorials zu Transformer und Deep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • Wandeln Sie Python-Argumente in Tensoren um, um die Rückverfolgung zu reduzieren.

    Häufig werden Python-Argumente verwendet, um Hyperparameter und num_layers=10 zu steuern - zum Beispiel num_layers=10 oder training=True oder nonlinearity='relu' . Wenn sich also das Python-Argument ändert, ist es sinnvoll, dass Sie den Graphen zurückverfolgen müssen.

    Es ist jedoch möglich, dass ein Python-Argument nicht verwendet wird, um die Diagrammkonstruktion zu steuern. In diesen Fällen kann eine Änderung des Python-Werts eine unnötige Rückverfolgung auslösen. Nehmen Sie zum Beispiel diese Trainingsschleife, die AutoGraph dynamisch ausrollt. Trotz der mehrfachen Traces ist der erzeugte Graph tatsächlich identisch, sodass ein erneutes Tracen nicht erforderlich ist.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Wenn Sie die Rückverfolgung erzwingen müssen, erstellen Sie eine neue Function . Separate Function Objekte teilen garantiert keine Traces.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Erhalten konkreter Funktionen

Jedes Mal, wenn eine Funktion verfolgt wird, wird eine neue konkrete Funktion erstellt. Sie können eine konkrete Funktion direkt get_concrete_function , indem Sie get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

Beim Drucken einer ConcreteFunction wird eine Zusammenfassung ihrer Eingabeargumente (mit Typen) und ihres Ausgabetyps angezeigt.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Sie können die Signatur einer konkreten Funktion auch direkt abrufen.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Die Verwendung eines konkreten Trace mit inkompatiblen Typen führt zu einem Fehler

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Sie werden vielleicht feststellen, dass Python-Argumente in der Eingabesignatur einer konkreten Funktion besonders behandelt werden. Vor TensorFlow 2.3 wurden Python-Argumente einfach aus der Signatur der konkreten Funktion entfernt. Ab TensorFlow 2.3 verbleiben Python-Argumente in der Signatur, müssen jedoch den während der Ablaufverfolgung festgelegten Wert annehmen.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1725, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1770, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Abrufen von Grafiken

Jede konkrete Funktion ist ein aufrufbarer Wrapper um einen tf.Graph . Obwohl das Abrufen des tatsächlichen tf.Graph Objekts normalerweise nicht erforderlich ist, können Sie es leicht von jeder konkreten Funktion tf.Graph .

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Debuggen

Im Allgemeinen ist das Debuggen von Code im Eager-Modus einfacher als in tf.function . Sie sollten sicherstellen, dass Ihr Code im Eager-Modus fehlerfrei ausgeführt wird, bevor Sie ihn mit tf.function . Um den Debugging-Prozess zu unterstützen, können Sie tf.config.run_functions_eagerly(True) aufrufen, um tf.config.run_functions_eagerly(True) global zu deaktivieren und wieder zu tf.function .

Wenn Sie Probleme aufspüren, die nur in tf.function , finden Sie hier einige Tipps:

  • Einfache alte Python- print werden nur während der Ablaufverfolgung ausgeführt und helfen Ihnen dabei, aufzuspüren, wann Ihre Funktion (erneut) verfolgt wird.
  • tf.print Aufrufe werden jedes Mal ausgeführt und können Ihnen dabei helfen, Zwischenwerte während der Ausführung aufzuspüren.
  • tf.debugging.enable_check_numerics ist eine einfache Möglichkeit, herauszufinden, wo NaNs und Inf erstellt werden.
  • pdb (der Python-Debugger ) kann Ihnen helfen zu verstehen, was während der Ablaufverfolgung vor sich geht. ( pdb : pdb bringt Sie in den AutoGraph-transformierten Quellcode.)

AutoGraph-Transformationen

AutoGraph ist eine Bibliothek, die standardmäßig in tf.function ist und eine Teilmenge des Python-Eager-Codes in tf.function TensorFlow- tf.function umwandelt. Dazu gehört auch die Ablaufsteuerung wie if , for , while .

TensorFlow- tf.cond wie tf.cond und tf.while_loop weiterhin, aber der Kontrollfluss ist oft einfacher zu schreiben und zu verstehen, wenn er in Python geschrieben ist.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.710546374 0.327660799 0.393230557 0.545059443 0.666661739]
[0.611019373 0.316417336 0.374141902 0.496808201 0.582779706]
[0.54484427 0.306263864 0.357609242 0.45960331 0.52468282]
[0.496646136 0.297034383 0.343106419 0.429760844 0.481306106]
[0.459475428 0.288596332 0.330247819 0.405121386 0.44728902]
[0.429656595 0.280842364 0.318743408 0.384322464 0.419668049]
[0.405034214 0.273684502 0.308370233 0.366455346 0.396650732]
[0.384248167 0.267049909 0.298953682 0.350887358 0.377079546]
[0.366391063 0.260877609 0.290354759 0.337162226 0.360168517]
[0.350830942 0.255116194 0.282461286 0.324941576 0.345362455]
[0.337112248 0.2497219 0.275181532 0.313968241 0.332256317]
[0.324896872 0.244657204 0.268439621 0.304042816 0.320546716]
[0.313927948 0.239889801 0.262172282 0.295007944 0.310001194]
[0.304006279 0.235391632 0.256326199 0.286737591 0.300438195]
[0.294974595 0.231138244 0.250856102 0.279129326 0.291713566]
[0.286706954 0.227108166 0.245723218 0.272099048 0.283711195]
[0.279101074 0.223282441 0.240894228 0.265576899 0.276336372]
[0.272072881 0.219644368 0.23634018 0.259504348 0.269510925]
[0.26555258 0.216179058 0.232035935 0.253831863 0.263169676]
[0.259481668 0.212873235 0.227959365 0.24851726 0.257257849]
[0.253810644 0.209715009 0.224091038 0.243524343 0.251728892]
[0.248497337 0.206693679 0.220413819 0.238821834 0.246543139]
[0.243505597 0.20379965 0.216912434 0.2343826 0.241666421]
[0.238804176 0.20102416 0.213573262 0.230182901 0.237069115]
[0.234365895 0.198359385 0.210384145 0.226201892 0.232725471]
[0.230167076 0.195798069 0.207334161 0.222421184 0.228612974]
[0.226186857 0.19333373 0.204413429 0.218824506 0.224711776]
[0.222406894 0.190960392 0.201613098 0.215397388 0.221004337]
[0.218810901 0.188672557 0.198925063 0.212126866 0.217475086]
[0.215384394 0.186465234 0.196342021 0.209001362 0.214110211]
[0.212114424 0.184333771 0.193857312 0.206010461 0.210897282]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20898944, 0.18227392, 0.19146483, 0.2031447 , 0.20782518],
      dtype=float32)>

Wenn Sie neugierig sind, können Sie sich den generierten Code-Autogramm ansehen.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

Bedingungen

AutoGraph konvertiert einige if <condition> tf.cond in die entsprechenden tf.cond Aufrufe. Diese Ersetzung erfolgt, wenn <condition> ein Tensor ist. Andernfalls wird die if Anweisung als Python-Bedingung ausgeführt.

Eine Python-Bedingung wird während des Tracings ausgeführt, sodass genau ein Zweig der Bedingung zum Graphen hinzugefügt wird. Ohne AutoGraph wäre dieser verfolgte Graph nicht in der Lage, den alternativen Zweig zu nehmen, wenn ein datenabhängiger Kontrollfluss vorhanden ist.

tf.cond verfolgt und fügt beide Verzweigungen der Bedingung zum Graphen hinzu, wobei zur Ausführungszeit dynamisch eine Verzweigung ausgewählt wird. Tracing kann unbeabsichtigte Nebenwirkungen haben; Weitere Informationen finden Sie unter AutoGraph-Tracing-Effekte .

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Weitere Einschränkungen für AutoGraph-konvertierte if-Anweisungen finden Sie in der Referenzdokumentation .

Schleifen

Autograph wird einige konvertieren for und while Aussagen in die äquivalente TensorFlow ops Looping, wie tf.while_loop . Wenn sie nicht konvertiert wird, wird die for oder while Schleife als Python-Schleife ausgeführt.

Diese Ersetzung erfolgt in den folgenden Situationen:

  • for x in y : Wenn y ein Tensor ist, wandeln Sie in tf.while_loop . In dem Sonderfall, in dem y eintf.data.Dataset , wird eine Kombination vontf.data.Dataset generiert.
  • while <condition> : Wenn <condition> ein Tensor ist, konvertieren Sie in tf.while_loop .

Eine Python-Schleife wird während des tf.Graph ausgeführt und fügt dem tf.Graph für jede Iteration der Schleife zusätzliche tf.Graph .

Eine TensorFlow-Schleife verfolgt den Schleifenkörper und wählt dynamisch aus, wie viele Iterationen zur Ausführungszeit ausgeführt werden sollen. Der Schleifenkörper kommt nur einmal im generierten tf.Graph .

Finden Sie in der Referenzdokumentation für zusätzliche Einschränkungen für Autogramm konvertiert for und while Aussagen.

Schleifen über Python-Daten

Ein häufiger Fallstrick ist das Schleifen von Python/NumPy-Daten innerhalb einer tf.function . Diese Schleife wird während des Ablaufverfolgungsprozesses ausgeführt, wobei für jede Iteration der Schleife eine Kopie Ihres Modells zum tf.Graph wird.

Wenn Sie die gesamte Trainingsschleife in tf.function , ist dies am sichersten, wenn Sie Ihre Daten alstf.data.Dataset damit AutoGraph die Trainingsschleife dynamisch ausrollt.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph

tf.data.Dataset.from_generator tf.data.Dataset.from_tensors von Python-/NumPy-Daten in ein Dataset tf.data.Dataset.from_generator Vergleich zu tf.data.Dataset.from_tensors . Ersteres behält die Daten in Python und tf.py_function sie über tf.py_function was Auswirkungen auf die Leistung haben kann, während letzteres eine Kopie der Daten als einen großen tf.constant() Knoten im Diagramm tf.constant() , was Auswirkungen auf den Speicher haben kann.

Das Lesen von Daten aus Dateien über TFRecordDataset , CsvDataset usw. ist die effektivste Art, Daten zu konsumieren, da TensorFlow selbst das asynchrone Laden und Vorabrufen von Daten verwalten kann, ohne Python einbeziehen zu müssen. Weitere tf.data Leitfaden tf.data : Build TensorFlow input Pipelines .

Akkumulieren von Werten in einer Schleife

Ein übliches Muster besteht darin, Zwischenwerte aus einer Schleife zu akkumulieren. Normalerweise wird dies durch Anhängen an eine Python-Liste oder Hinzufügen von Einträgen zu einem Python-Wörterbuch erreicht. Da dies jedoch Python-Nebeneffekte sind, funktionieren sie in einer dynamisch entrollten Schleife nicht wie erwartet. Verwenden Sie tf.TensorArray , um Ergebnisse aus einer dynamisch entrollten Schleife zu sammeln.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.60458577, 0.3308612 , 0.7878152 , 0.3223114 ],
        [0.9110272 , 1.0819752 , 1.7657743 , 1.2409766 ],
        [1.7235098 , 1.5416101 , 2.2929285 , 1.9181627 ]],

       [[0.89487076, 0.22811687, 0.342862  , 0.5752872 ],
        [1.0133923 , 0.28650808, 0.9558767 , 1.0829899 ],
        [1.9280962 , 1.1437279 , 0.9857702 , 1.4834155 ]]], dtype=float32)>

Einschränkungen

TensorFlow Function weist einige Einschränkungen auf, die Sie beim Konvertieren einer Python-Funktion in eine Function beachten sollten.

Ausführen von Python-Nebeneffekten

Nebenwirkungen wie Drucken, Anhängen an Listen und mutieren von Globals können sich innerhalb einer Function unerwartet verhalten und manchmal zweimal oder nicht alle ausgeführt werden. Sie treten nur auf, wenn Sie zum ersten Mal eine Function mit einer Reihe von Eingaben aufrufen. Danach wird der verfolgte tf.Graph ausgeführt, ohne den Python-Code auszuführen.

Als allgemeine Faustregel gilt, dass Sie sich in Ihrer Logik nicht auf Python-Nebeneffekte verlassen und diese nur zum Debuggen Ihrer Traces verwenden. Andernfalls sind TensorFlow-APIs wie tf.data , tf.print , tf.summary , tf.Variable.assign und tf.TensorArray der beste Weg, um sicherzustellen, dass Ihr Code bei jedem Aufruf von der TensorFlow-Laufzeit ausgeführt wird.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Wenn Sie Python-Code bei jedem Aufruf einer Function tf.py_function ist tf.py_function eine Exit-Schraffur. Der Nachteil von tf.py_function besteht darin, dass es nicht portabel oder besonders performant ist, nicht mit SavedModel gespeichert werden kann und in verteilten (Multi-GPU, TPU) Setups nicht gut funktioniert. Da tf.py_function in den Graphen verdrahtet werden muss, tf.py_function es alle Ein-/Ausgänge in Tensoren um.

Globale und freie Python-Variablen ändern

Das Ändern globaler und freier Python- Variablen zählt als Python-Nebeneffekt und geschieht daher nur während des Tracings.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Sie sollten es vermeiden, Container wie Listen, Diktate oder andere Objekte zu mutieren, die außerhalb der Function . Verwenden Sie stattdessen Argumente und TF-Objekte. Im Abschnitt "Werte in einer Schleife akkumulieren" finden Sie beispielsweise ein Beispiel, wie listenartige Operationen implementiert werden können.

In einigen Fällen können Sie den Zustand erfassen und manipulieren, wenn es sich um eine tf.Variable . Auf diese Weise werden die Gewichtungen von Keras-Modellen bei wiederholten Aufrufen derselben ConcreteFunction aktualisiert.

Verwenden von Python-Iteratoren und -Generatoren

Viele Python-Funktionen, wie Generatoren und Iteratoren, verlassen sich auf die Python-Laufzeit, um den Status zu verfolgen. Im Allgemeinen funktionieren diese Konstrukte zwar wie erwartet im Eager-Modus, sind jedoch Beispiele für Python-Nebeneffekte und treten daher nur während des Tracings auf.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Genauso wie TensorFlow über ein spezialisiertes tf.TensorArray für tf.TensorArray verfügt, verfügt es über einen spezialisierten tf.data.Iterator für Iterationskonstrukte. Eine Übersicht finden Sie im Abschnitt zu AutoGraph-Transformationen . Außerdem kann die tf.data API bei der Implementierung von Generatormustern helfen:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

Löschen von tf.Variablen zwischen Function

Ein weiterer Fehler, auf den Sie möglicherweise stoßen, ist eine Garbage-Collector-Variable. ConcreteFunction s behalten nur WeakRefs für die Variablen bei, über denen sie schließen, daher müssen Sie eine Referenz auf alle Variablen beibehalten.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

# The original variable object gets garbage collected, since there are no more
# references to it.
external_var = tf.Variable(4)
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-9a93d2e07632>", line 16, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError:  Could not find variable _AnonymousVar3. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]] [Op:__inference_f_782]

Function call stack:
f

Bekannte Probleme

Wenn Ihre Function nicht richtig ausgewertet wird, kann der Fehler durch diese bekannten Probleme erklärt werden, die in Zukunft behoben werden sollen.

Abhängig von Python globale und freie Variablen

Function erstellt eine neue ConcreteFunction wenn sie mit einem neuen Wert eines Python-Arguments aufgerufen wird. Dies ist jedoch nicht für die Python-Closure, Globals oder Nichtlokale dieser Function Fall. Wenn sich ihr Wert zwischen Aufrufen der Function ändert, verwendet die Function weiterhin die Werte, die sie bei der Verfolgung hatte. Dies unterscheidet sich von der Funktionsweise regulärer Python-Funktionen.

Aus diesem Grund empfehlen wir einen funktionalen Programmierstil, der Argumente verwendet, anstatt äußere Namen zu schließen.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Sie können über äußere Namen schließen, solange Sie deren Werte nicht aktualisieren.

Abhängig von Python-Objekten

Die Empfehlung, Python-Objekte als Argumente an tf.function zu tf.function weist eine Reihe bekannter Probleme auf, von denen erwartet wird, dass sie in Zukunft behoben werden. Im Allgemeinen können Sie sich auf eine konsistente Ablaufverfolgung verlassen, wenn Sie eine Python-Primitive oder eine tf.nest kompatible Struktur als Argument verwenden oder eine andere Instanz eines Objekts an eine Function . Function jedoch keinen neuen Trace, wenn Sie dasselbe Objekt übergeben und nur seine Attribute ändern .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

Die Verwendung derselben Function zum Auswerten der aktualisierten Instanz des Modells ist fehlerhaft, da das aktualisierte Modell denselben Cache-Schlüssel wie das Originalmodell hat.

Aus diesem Grund empfehlen wir Ihnen, Ihre Function zu schreiben, um eine Abhängigkeit von veränderlichen Function zu vermeiden oder neue Objekte zu erstellen.

Wenn dies nicht möglich ist, besteht eine Problemumgehung darin, jedes Mal neue Function s zu erstellen, wenn Sie Ihr Objekt ändern, um eine Rückverfolgung zu erzwingen:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Da das Zurückverfolgen teuer sein kann , können Sie tf.Variable s als tf.Variable , die für einen ähnlichen Effekt mutiert (aber nicht geändert, Vorsicht!) werden können, ohne dass ein Zurückverfolgen erforderlich ist.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

tf.Variablen erstellen

Function unterstützt nur die einmalige Erstellung von Variablen beim ersten Aufruf und die anschließende Wiederverwendung. Sie können tf.Variables in neuen Traces erstellen. Das Erstellen neuer Variablen in nachfolgenden Aufrufen ist derzeit nicht erlaubt, wird es aber in Zukunft sein.

Beispiel:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-8a0913e250e0>", line 7, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-8a0913e250e0>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:769 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

Sie können Variablen innerhalb einer Function erstellen, solange diese Variablen nur beim ersten Ausführen der Funktion erstellt werden.

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Verwendung mit mehreren Keras-Optimierern

Möglicherweise stoßen Sie auf ValueError: tf.function-decorated function tried to create variables on non-first call. wenn Sie mehr als einen Keras-Optimierer mit einer tf.function . Dieser Fehler tritt auf, weil Optimierer intern tf.Variables erstellen, wenn sie zum ersten Mal tf.Variables anwenden.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d3d3937dbf1a>", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    <ipython-input-1-d3d3937dbf1a>:9 train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:636 apply_gradients  **
        self._create_all_weights(var_list)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:821 _create_all_weights
        _ = self.iterations
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:828 __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:988 iterations
        aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:1194 add_weight
        aggregation=aggregation)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py:815 _add_variable_with_custom_getter
        **kwargs_for_getter)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:139 make_variable
        shape=variable_shape if variable_shape else None)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:221 _variable_v1_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:769 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

Wenn Sie den Optimierer während des Trainings ändern müssen, können Sie das Problem umgehen, indem Sie für jeden Optimierer eine neue Function erstellen und die ConcreteFunction direkt aufrufen.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

Verwendung mit mehreren Keras-Modellen

Sie können auch auf ValueError: tf.function-decorated function tried to create variables on non-first call. stoßen ValueError: tf.function-decorated function tried to create variables on non-first call. wenn verschiedene Modellinstanzen an dieselbe Function .

Dieser Fehler tritt auf, weil Keras-Modelle ( deren Eingabeform nicht definiert ist ) und Keras-Layer tf.Variables s erstellen, wenn sie zum ersten Mal aufgerufen werden. Möglicherweise versuchen Sie, diese Variablen innerhalb einer Function zu initialisieren, die bereits aufgerufen wurde. Um diesen Fehler zu vermeiden, versuchen Sie, model.build(input_shape) , um alle Gewichtungen zu initialisieren, bevor Sie das Modell trainieren.

Weiterlesen

Informationen zum Exportieren und Laden einer Function finden Sie im SavedModel-Handbuch . Weitere Informationen zu Diagrammoptimierungen, die nach dem Tracing durchgeführt werden, finden Sie in der Grappler-Anleitung . Informationen zum Optimieren Ihrer Datenpipeline und zum Profilieren Ihres Modells finden Sie im Profiler-Leitfaden .