Cette page a été traduite par l'API Cloud Translation.
Switch to English

Meilleures performances avec tf.function

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le carnet

Dans TensorFlow 2, l'exécution hâtive est activée par défaut. L'interface utilisateur est intuitive et flexible (l'exécution d'opérations ponctuelles est beaucoup plus facile et plus rapide), mais cela peut se faire au détriment des performances et de la déployabilité.

Vous pouvez utiliser tf.function pour créer des graphiques à partir de vos programmes. C'est un outil de transformation qui crée des graphiques de flux de données indépendants de Python à partir de votre code Python. Cela vous aidera à créer des modèles performants et portables, et il est nécessaire d'utiliser SavedModel .

Ce guide vous aidera à conceptualiser le fonctionnement de tf.function sous le capot afin que vous puissiez l'utiliser efficacement.

Les principaux points à retenir et recommandations sont:

  • Déboguez en mode impatient, puis décorez avec @tf.function .
  • Ne vous fiez pas aux effets secondaires Python comme la mutation d'objet ou les ajouts de liste.
  • tf.function fonctionne mieux avec les opérations TensorFlow; Les appels NumPy et Python sont convertis en constantes.

Installer

import tensorflow as tf

Définissez une fonction d'assistance pour illustrer les types d'erreurs que vous pourriez rencontrer:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Basiques

Usage

Une Function vous définissez est comme une opération TensorFlow principale: vous pouvez l'exécuter avec empressement; vous pouvez calculer des dégradés; etc.

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Vous pouvez utiliser des Function dans d'autres Function .

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function peuvent être plus rapides que le code avide, en particulier pour les graphiques avec de nombreuses petites opérations. Mais pour les graphiques avec quelques opérations coûteuses (comme les convolutions), vous ne verrez peut-être pas beaucoup d'accélération.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.002407395000091128
Function conv: 0.004000883000117028
Note how there's not much difference in performance for convolutions

Tracé

Le typage dynamique de Python signifie que vous pouvez appeler des fonctions avec une variété de types d'arguments, et Python peut faire quelque chose de différent dans chaque scénario.

Pourtant, pour créer un graphe TensorFlow, des dtypes statiques et des dimensions de forme sont nécessaires. tf.function comble cette lacune en tf.function une fonction Python pour créer un objet Function . Sur la base des entrées données, la Function sélectionne le graphe approprié pour les entrées données, retraçant la fonction Python si nécessaire. Une fois que vous avez compris pourquoi et quand le traçage se produit, il est beaucoup plus facile d'utiliser efficacement tf.function !

Vous pouvez appeler une Function avec des arguments de types différents pour voir ce comportement polymorphe en action.

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Notez que si vous appelez à plusieurs reprises une Function avec le même type d'argument, TensorFlow réutilisera un graphe précédemment tracé, car le graphe généré serait identique.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Vous pouvez utiliser pretty_printed_concrete_signatures() pour voir toutes les traces disponibles:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Jusqu'à présent, vous avez vu que tf.function crée une couche de répartition dynamique en cache sur la logique de traçage des graphiques de TensorFlow. Pour être plus précis sur la terminologie:

  • Un tf.Graph est la représentation brute, tf.Graph du langage et portable de votre calcul.
  • Un ConcreteFunction est un wrapper qui s'exécute avec impatience autour d'un tf.Graph .
  • Une Function gère un cache de ConcreteFunction et choisit la bonne pour vos entrées.
  • tf.function une fonction Python, renvoyant un objet Function .

Obtenir des fonctions concrètes

Chaque fois qu'une fonction est tracée, une nouvelle fonction concrète est créée. Vous pouvez obtenir directement une fonction concrète, en utilisant get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(Le changement suivant est disponible dans TensorFlow tous les soirs et sera disponible dans TensorFlow 2.3.)

L'impression d'une ConcreteFunction affiche un résumé de ses arguments d'entrée (avec types) et de son type de sortie.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Vous pouvez également récupérer directement la signature d'une fonction concrète.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

L'utilisation d'une trace concrète avec des types incompatibles générera une erreur

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Vous remarquerez peut-être que les arguments Python reçoivent un traitement spécial dans la signature d'entrée d'une fonction concrète. Avant TensorFlow 2.3, les arguments Python étaient simplement supprimés de la signature de la fonction concrète. À partir de TensorFlow 2.3, les arguments Python restent dans la signature, mais sont contraints de prendre la valeur définie lors du traçage.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Obtention de graphiques

Chaque fonction concrète est un wrapper appelable autour d'un tf.Graph . Bien que la récupération de l'objet tf.Graph réel ne soit pas quelque chose que vous devez normalement faire, vous pouvez l'obtenir facilement à partir de n'importe quelle fonction concrète.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

[] -> a
['a', 'a'] -> add
['add'] -> Identity

Débogage

En général, le code de débogage est plus facile en mode hâte qu'à l'intérieur de tf.function . Vous devez vous assurer que votre code s'exécute sans erreur en mode hâte avant de décorer avec tf.function . Pour vous aider dans le processus de débogage, vous pouvez appeler tf.config.run_functions_eagerly(True) pour désactiver et réactiver globalement tf.function .

Lorsque vous recherchez des problèmes qui n'apparaissent que dans tf.function , voici quelques conseils:

  • Les anciens appels d' print Python ne s'exécutent que pendant le traçage, ce qui vous aide à localiser lorsque votre fonction est (re) tracée.
  • tf.print appels tf.print s'exécuteront à chaque fois et peuvent vous aider à retrouver les valeurs intermédiaires pendant l'exécution.
  • tf.debugging.enable_check_numerics est un moyen facile de localiser où les NaN et Inf sont créés.
  • pdb peut vous aider à comprendre ce qui se passe pendant le traçage. (Attention: PDB vous déposera dans le code source transformé par AutoGraph.)

Tracer la sémantique

Règles de clé de cache

Une Function détermine s'il faut réutiliser une fonction concrète tracée en calculant une clé de cache à partir des args et kwargs d'une entrée.

  • La clé générée pour un argument tf.Tensor est sa forme et son type.
  • À partir de TensorFlow 2.3, la clé générée pour un argument tf.Variable est son id() .
  • La clé générée pour une primitive Python est sa valeur. La clé générée pour les dict , list s, tuple s, namedtuple s et attr s imbriqués est le tuple aplati. (À la suite de cet aplatissement, l'appel d'une fonction concrète avec une structure d'imbrication différente de celle utilisée lors du traçage entraînera un TypeError).
  • Pour tous les autres types Python, les clés sont basées sur l'objet id() afin que les méthodes soient tracées indépendamment pour chaque instance d'une classe.

Contrôle du retracement

Le retracement permet de garantir que TensorFlow génère des graphiques corrects pour chaque ensemble d'entrées. Cependant, le traçage est une opération coûteuse! Si votre Function retrace un nouveau graphique pour chaque appel, vous constaterez que votre code s'exécute plus lentement que si vous n'utilisiez pas tf.function .

Pour contrôler le comportement du traçage, vous pouvez utiliser les techniques suivantes:

  • Spécifiez input_signature dans tf.function pour limiter le traçage.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Spécifiez une dimension [Aucun] dans tf.TensorSpec pour permettre une flexibilité dans la réutilisation des traces.

    Étant donné que TensorFlow correspond aux tenseurs en fonction de leur forme, l'utilisation d'une dimension None comme caractère générique permettra à Function s de réutiliser des traces pour une entrée de taille variable. Des entrées de tailles variables peuvent se produire si vous avez des séquences de longueur différente ou des images de tailles différentes pour chaque lot (voir les didacticiels Transformer et Deep Dream par exemple).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Convertissez les arguments Python en Tensors pour réduire le retracement

    Souvent, les arguments Python sont utilisés pour contrôler les hyperparamètres et les constructions de graphes - par exemple, num_layers=10 ou training=True ou nonlinearity='relu' . Donc, si l'argument Python change, il est logique que vous deviez retracer le graphique.

    Cependant, il est possible qu'un argument Python ne soit pas utilisé pour contrôler la construction du graphe. Dans ces cas, une modification de la valeur Python peut déclencher un retracement inutile. Prenez, par exemple, cette boucle d'entraînement, qu'AutoGraph déroulera dynamiquement. Malgré les multiples traces, le graphique généré est en fait identique, donc le retracement n'est pas nécessaire.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Si vous devez forcer le retracement, créez une nouvelle Function . Il est garanti que les objets Function séparés ne partagent pas les traces.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Effets secondaires Python

Les effets secondaires Python tels que l'impression, l'ajout aux listes et la mutation des globaux ne se produisent que la première fois que vous appelez une Function avec un ensemble d'entrées. Ensuite, le tf.Graph tracé est tf.Graph , sans exécuter le code Python.

La règle générale est de n'utiliser que les effets secondaires Python pour déboguer vos traces. Sinon, les opérations TensorFlow telles que tf.Variable.assign , tf.print et tf.summary sont le meilleur moyen de garantir que votre code sera tracé et exécuté par le runtime TensorFlow à chaque appel.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

De nombreuses fonctionnalités Python, telles que les générateurs et les itérateurs, s'appuient sur le runtime Python pour suivre l'état. En général, alors que ces constructions fonctionnent comme prévu en mode hâtif, de nombreuses choses inattendues peuvent se produire dans une Function .

Pour donner un exemple, l'avancement de l'état de l'itérateur est un effet secondaire Python et ne se produit donc que pendant le traçage.

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Certaines constructions d'itération sont prises en charge via AutoGraph. Consultez la section sur les transformations AutoGraph pour une présentation .

Si vous souhaitez exécuter du code Python lors de chaque appel d'une Function , tf.py_function est une trappe de sortie. L'inconvénient de tf.py_function est qu'il n'est ni portable ni particulièrement performant, ni ne fonctionne bien dans les configurations distribuées (multi-GPU, TPU). De plus, comme tf.py_function doit être câblé dans le graphe, il convertit toutes les entrées / sorties en tenseurs.

Les API telles que tf.gather , tf.stack et tf.TensorArray peuvent vous aider à implémenter des modèles de bouclage courants dans TensorFlow natif.

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect

Variables

Vous pouvez rencontrer une erreur lors de la création d'un nouveau tf.Variable dans une fonction. Cette erreur protège contre la divergence de comportement lors d'appels répétés: en mode hâtif, une fonction crée une nouvelle variable à chaque appel, mais dans une Function , une nouvelle variable peut ne pas être créée en raison de la réutilisation de la trace.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


Vous pouvez créer des variables à l'intérieur d'une Function tant que ces variables ne sont créées que la première fois que la fonction est exécutée.

class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Une autre erreur que vous pouvez rencontrer est une variable récupérée par la mémoire. Contrairement aux fonctions Python normales, les fonctions concrètes ne conservent que les WeakRefs aux variables sur lesquelles elles se ferment, vous devez donc conserver une référence à toutes les variables.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Transformations AutoGraph

AutoGraph est une bibliothèque activée par défaut dans tf.function et transforme un sous-ensemble de code impatient Python en opérations TensorFlow compatibles avec les graphes. Cela inclut le flux de contrôle comme if , for , while .

Les opérations TensorFlow comme tf.cond et tf.while_loop continuent de fonctionner, mais le flux de contrôle est souvent plus facile à écrire et à comprendre lorsqu'il est écrit en Python.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.224704742 0.895507693 0.0398198366 0.98112452 0.278468847]
[0.220997646 0.71410346 0.0397988036 0.753552318 0.271487355]
[0.217468739 0.61324358 0.0397778042 0.637263417 0.265008271]
[0.214104146 0.546406269 0.0397568382 0.563033342 0.258973926]
[0.210891485 0.497821957 0.0397359058 0.510224521 0.253335565]
[0.207819641 0.460402519 0.0397150069 0.470120102 0.248051569]
[0.204878598 0.430412233 0.0396941416 0.438296348 0.243086234]
[0.202059314 0.405665785 0.039673306 0.412231296 0.2384087]
[0.199353606 0.384786367 0.039652504 0.39036563 0.23399213]
[0.196754038 0.366856933 0.0396317355 0.371675402 0.229813099]
[0.194253832 0.351239443 0.039611 0.355456293 0.225851]
[0.191846803 0.337474287 0.0395902954 0.341205537 0.222087651]
[0.189527303 0.325220674 0.0395696238 0.3285532 0.218506947]
[0.187290132 0.314219803 0.0395489857 0.317220151 0.215094551]
[0.185130537 0.304271102 0.0395283774 0.30699119 0.211837649]
[0.183044136 0.295216352 0.0395078026 0.297697395 0.208724767]
[0.181026861 0.286928833 0.0394872613 0.289204 0.205745578]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.17907499, 0.27930567, 0.03946675, 0.281402  , 0.20289075],
      dtype=float32)>

Si vous êtes curieux, vous pouvez inspecter le code généré par autographe.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Conditionnels

AutoGraph convertira certaines instructions if <condition> en appels équivalents à tf.cond . Cette substitution est effectuée si <condition> est un Tensor. Sinon, l'instruction if est exécutée en tant que conditionnel Python.

Un conditionnel Python s'exécute pendant le traçage, donc exactement une branche du conditionnel sera ajoutée au graphique. Sans AutoGraph, ce graphique tracé ne pourrait pas prendre la branche alternative s'il existe un flux de contrôle dépendant des données.

tf.cond trace et ajoute les deux branches du conditionnel au graphique, en sélectionnant dynamiquement une branche au moment de l'exécution. Le traçage peut avoir des effets secondaires involontaires; voir Effets de traçage AutoGraph pour en savoir plus.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Consultez la documentation de référence pour des restrictions supplémentaires sur les instructions if converties par AutoGraph.

Boucles

AutoGraph convertira certaines instructions for et while en opérations de boucle TensorFlow équivalentes, comme tf.while_loop . tf.while_loop . Si non converti, le for ou while boucle est exécutée en boucle Python.

Cette substitution est effectuée dans les situations suivantes:

  • for x in y : si y est un Tensor, convertissez en tf.while_loop . tf.while_loop . Dans le cas particulier où y est un tf.data.Dataset , une combinaison d'opérations tf.data.Dataset est générée.
  • while <condition> : si <condition> est un Tensor, convertissez en tf.while_loop . tf.while_loop .

Une boucle Python s'exécute pendant le traçage, ajoutant des opérations supplémentaires au tf.Graph pour chaque itération de la boucle.

Une boucle TensorFlow trace le corps de la boucle et sélectionne dynamiquement le nombre d'itérations à exécuter au moment de l'exécution. Le corps de la boucle n'apparaît qu'une seule fois dans le tf.Graph généré.

Consultez la documentation de référence pour connaître les restrictions supplémentaires sur les instructions for et while converties en AutoGraph.

Boucle sur les données Python

Un écueil courant est de boucler sur les données Python / Numpy dans une fonction tf.function . Cette boucle s'exécutera pendant le processus de traçage, ajoutant une copie de votre modèle au tf.Graph pour chaque itération de la boucle.

Si vous souhaitez tf.function toute la boucle d'entraînement dans tf.function , le moyen le plus sûr de le faire est d' tf.data.Dataset vos données sous la forme d'un tf.data.Dataset afin qu'AutoGraph déroule dynamiquement la boucle d'apprentissage.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Lorsque vous encapsulez des données Python / Numpy dans un ensemble de données, tf.data.Dataset.from_generator tf.data.Dataset.from_tensors tf.data.Dataset.from_generator par rapport à tf.data.Dataset.from_tensors . Le premier conservera les données en Python et les récupérera via tf.py_function ce qui peut avoir des implications sur les performances, tandis que le second regroupera une copie des données sous la forme d'un grand nœud tf.constant() dans le graphique, ce qui peut avoir des implications sur la mémoire.

Lecture de données à partir de fichiers via TFRecordDataset / CsvDataset / etc. est le moyen le plus efficace de consommer des données, car TensorFlow lui-même peut gérer le chargement asynchrone et la prélecture des données, sans avoir à impliquer Python. Pour en savoir plus, consultez le guide tf.data .

Accumuler des valeurs dans une boucle

Un modèle courant consiste à accumuler des valeurs intermédiaires à partir d'une boucle. Normalement, cela est accompli en ajoutant à une liste Python ou en ajoutant des entrées à un dictionnaire Python. Cependant, comme ce sont des effets secondaires Python, ils ne fonctionneront pas comme prévu dans une boucle déroulée dynamiquement. Utilisez tf.TensorArray pour accumuler les résultats d'une boucle déroulée dynamiquement.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.9854791 , 0.5162524 , 0.14062047, 0.04950547],
        [1.8820469 , 0.67421603, 0.40786874, 0.7679055 ],
        [2.8815444 , 1.1567757 , 1.0627073 , 0.8880433 ]],

       [[0.94119024, 0.19776726, 0.24890792, 0.4663092 ],
        [1.4591933 , 1.123581  , 0.35438073, 1.4392309 ],
        [2.0026946 , 1.9165647 , 0.37988353, 1.8128917 ]]], dtype=float32)>

Lectures complémentaires

Pour savoir comment exporter et charger une Function , consultez le guide SavedModel . Pour en savoir plus sur les optimisations de graphes effectuées après le traçage, consultez le guide Grappler . Pour savoir comment optimiser votre pipeline de données et profiler votre modèle, consultez le guide Profiler .