Cette page a été traduite par l'API Cloud Translation.
Switch to English

Une meilleure performance avec tf.function

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Bloc - notes Télécharger

Dans tensorflow 2, l'exécution est impatient activée par défaut. L'interface utilisateur est intuitive et flexible (en cours d'exécution unique des opérations est beaucoup plus facile et plus rapide), mais cela peut se faire au détriment de la performance et la déployabilité.

Vous pouvez utiliser tf.function pour faire des graphiques de vos programmes. Il est un outil de transformation qui crée des graphiques de flux de données Python indépendants de votre code Python. Cela vous aidera à créer des modèles et portables performants, et il est nécessaire d'utiliser SavedModel .

Ce guide vous aidera à conceptualiser comment tf.function fonctionne sous le capot afin que vous puissiez l' utiliser efficacement.

Les principaux plats à emporter et recommandations sont les suivantes:

  • Mise au point en mode hâte, puis décorer avec @tf.function .
  • Ne vous fiez pas sur les effets secondaires Python comme la mutation d'ajouter ses objets ou liste.
  • tf.function fonctionne mieux avec ops tensorflow; appels numpy et Python sont converties en constantes.

Installer

 import tensorflow as tf
 

Définir une fonction d'aide pour démontrer les types d'erreurs que vous pourriez rencontrer:

 import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
 

Notions de base

Usage

A Function vous définissez est comme une opération de tensorflow de base: Vous pouvez l' exécuter avec impatience; vous pouvez calculer des gradients; etc.

 @tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
 
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
 v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Vous pouvez utiliser la Function de l' intérieur des autres Function s.

 @tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
 
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s peut être plus rapide que le code désireux, en particulier pour les graphiques avec de nombreuses petites opérations. Mais pour les graphiques avec quelques opérations coûteuses (comme) convolutions, vous ne pouvez pas voir beaucoup speedup.

 import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

 
Eager conv: 0.0023194860004878137
Function conv: 0.0036776439992536325
Note how there's not much difference in performance for convolutions

Tracé

des moyens de frappe dynamiques de Python que vous pouvez appeler des fonctions avec une variété de types d'arguments, et Python peuvent faire quelque chose de différent dans chaque scénario.

Cependant, pour créer un graphique tensorflow, statique dtypes et les dimensions de forme sont nécessaires. tf.function comble cette lacune en enveloppant une fonction python pour créer une Function objet. Sur la base des entrées données, la Function sélectionne le graphique approprié pour les entrées données, retraçant la fonction Python nécessaire. Une fois que vous comprenez pourquoi et quand le traçage se produit, il est beaucoup plus facile à utiliser tf.function efficacement!

Vous pouvez appeler une Function avec des arguments de types différents pour voir ce comportement polymorphique en action.

 @tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

 
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Notez que si vous appelez à plusieurs reprises une Function avec le même type d'argument, tensorflow réutilisera un graphique précédemment tracé, comme le graphique généré serait identique.

 # This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
 
tf.Tensor(b'bb', shape=(), dtype=string)

(Le changement suivant est disponible dans tensorflow la nuit, et sera disponible en tensorflow 2.3.)

Vous pouvez utiliser pretty_printed_concrete_signatures() pour voir toutes les traces disponibles:

 print(double.pretty_printed_concrete_signatures())
 
double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

Jusqu'à présent, vous avez vu que tf.function crée un cache, couche d'expédition dynamique sur la logique de traçage du graphe de tensorflow. Pour être plus précis sur la terminologie:

  • Un tf.Graph est la première, la langue agnostique, représentation portable de votre calcul.
  • Un ConcreteFunction est une enveloppe avec impatience-exécution autour d' un tf.Graph .
  • Une Function gère un cache de ConcreteFunction s et prend la bonne pour vos entrées.
  • tf.function encapsule une fonction Python, retournant une Function objet.

L'obtention de fonctions de béton

Chaque fois qu'une fonction est tracée, une nouvelle fonction est créée en béton. Vous pouvez directement obtenir une fonction de béton, à l'aide get_concrete_function .

 print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

 
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

 # You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
 
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(Le changement suivant est disponible dans tensorflow la nuit, et sera disponible en tensorflow 2.3.)

Impression d' un ConcreteFunction affiche un résumé de ses arguments d'entrée (avec des types) et son type de sortie.

 print(double_strings)
 
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Vous pouvez également récupérer directement la signature d'une fonction de béton.

 print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
 
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

En utilisant une trace de béton avec des types incompatibles renvoie une erreur

 with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
 
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-15-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Vous pouvez remarquer que les arguments Python bénéficient d'un traitement spécial dans la signature d'entrée d'une fonction de béton. Avant tensorflow 2.3, Python arguments ont été tout simplement retirés de la signature de la fonction concrète. A partir de tensorflow 2.3, Python arguments restent dans la signature, mais sont contraints de prendre l'ensemble de la valeur au cours de traçage.

 @tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
 
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

 assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
 
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-17-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

graphiques obtenir

Chaque fonction de béton est une enveloppe appelable autour d' un tf.Graph . Bien que la récupération de la réelle tf.Graph objet n'est pas quelque chose que vous aurez normalement besoin de le faire, vous pouvez l' obtenir facilement de toute fonction de béton.

 graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

 
[] -> a
['a', 'a'] -> add
['add'] -> Identity

débogage

En général, le code de débogage est plus facile en mode hâte que l' intérieur tf.function . Vous devez vous assurer que votre code est exécuté sans erreur en mode hâte avant de décorer avec tf.function . Pour aider dans le processus de débogage, vous pouvez appeler tf.config.run_functions_eagerly(True) à l' échelle mondiale et désactiver réactivez tf.function .

Lorsque traquer les problèmes qui apparaissent uniquement dans tf.function , voici quelques conseils:

  • Plain Old Python print appels exécutent seulement pendant le suivi, vous aider à traquer lorsque votre fonction obtient (re) tracé.
  • tf.print appels exécuter à chaque fois, et peut vous aider à traquer les valeurs intermédiaires lors de l' exécution.
  • tf.debugging.enable_check_numerics est un moyen facile de traquer où NaN et Inf sont créés.
  • pdb peut vous aider à comprendre ce qui se passe pendant le suivi. (Caveat: PDB vous déposera dans le code source transformé AutoGraph.)

Tracing sémantique

Cache règles clés

Une Function détermine si la réutilisation d' une fonction concrète tracée par le calcul d' une clé de mémoire cache à partir des arguments et kwargs d'une entrée.

  • La clé générée pour un tf.Tensor argument est sa forme et DTYPE.
  • A partir de tensorflow 2.3, la clé générée pour un tf.Variable argument est son id() .
  • La clé générée pour une est sa valeur primitive python. La clé générée pour imbriquée dict s, list s, tuple s, namedtuple s et attr s est tuple aplaties. (En raison de cet aplatissement, appelant une fonction de béton avec une structure d'emboîtement différent de celui utilisé pendant le suivi entraînera une TypeError).
  • Pour tous les autres types python, les touches sont basées sur l'objet id() de sorte que les procédés sont tracées de façon indépendante pour chaque instance d'une classe.

contrôle de retraçage

Retracing aide assure que tensorflow génère des graphiques corrects pour chaque ensemble d'entrées. Toutefois, le suivi est une opération coûteuse! Si votre Function retrace un nouveau graphique pour chaque appel, vous verrez que votre code exécute plus lentement que si vous ne pas utiliser tf.function .

Pour contrôler le comportement de traçage, vous pouvez utiliser les techniques suivantes:

  • Spécifiez input_signature dans tf.function au traçage des limites.
 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Spécifiez une dimension [Aucun] dans tf.TensorSpec pour permettre une certaine souplesse dans la réutilisation des traces.

    Depuis tensorflow correspond tenseurs en fonction de leur forme, en utilisant une None dimension comme joker permettra la Function s aux traces de réutilisation pour l' entrée taille variable. Variablement taille d' entrée peut se produire si vous avez des séquences de longueurs différentes, ou des images de différentes tailles pour chaque lot (Voir transformateur et rêve profond tutoriels par exemple).

 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Cast arguments Python à tenseurs pour réduire retracing.

    Souvent, les arguments Python sont utilisés pour hyperparam'etres de contrôle et constructions graphiques - par exemple, num_layers=10 ou training=True ou nonlinearity='relu' - nonlinearity='relu' . Donc, si l'argument Python change, il est logique que vous auriez à revenir sur le graphique.

    Cependant, il est possible qu'un argument de Python n'est pas utilisé pour contrôler la construction graphique. Dans ces cas, un changement de la valeur Python peut déclencher retraçage inutile. Prenons, par exemple, cette boucle de formation, qui sera AutoGraph Déroulez dynamiquement. Malgré les multiples traces, le graphique généré est en fait identique, donc retraçage est inutile.

 def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
 
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Si vous avez besoin de la force retraçage, créer une nouvelle Function . Séparés Function objets sont garantis pas des traces d'actions.

 def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
 
Tracing!
Executing
Tracing!
Executing

effets secondaires Python

Effets secondaires Python comme l' impression, annexant aux listes, et ne se produisent que muter GLOBALS la première fois que vous appelez une Function avec un ensemble d'entrées. Par la suite, le tracé tf.Graph est réexécutée, sans exécuter le code Python.

La règle générale est d'utiliser uniquement des effets secondaires Python pour déboguer vos traces. Dans le cas contraire, tensorflow ops comme tf.Variable.assign , tf.print et tf.summary sont la meilleure façon de vous assurer que votre code sera tracé et exécuté par le moteur d' exécution de tensorflow à chaque appel.

 @tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

 
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

De nombreuses fonctionnalités Python, tels que les générateurs et itérateurs, comptent sur le temps d'exécution Python pour garder la trace de l'état. En général, alors que ces constructions fonctionnent comme prévu en mode hâte, beaucoup de choses inattendues peuvent se produire à l' intérieur d' une Function .

Pour donner un exemple, l'avancement iterator état est un effet secondaire Python, et donc ne se produit que pendant le suivi.

 external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

 
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Certaines constructions d'itération sont pris en charge par AutoGraph. Voir la section sur les transformations Autograph pour un aperçu.

Si vous souhaitez exécuter du code Python lors de chaque appel d'une Function , tf.py_function est une trappe de sortie. L'inconvénient de tf.py_function est que ce n'est pas portable ou particulièrement performant, elle ne fonctionne pas bien distribué (multi-GPU, TPU) configurations. En outre, étant donné que tf.py_function doit être branché sur le graphique, il jette toutes les entrées / sorties tenseurs.

API comme tf.gather , tf.stack et tf.TensorArray peuvent vous aider à mettre en œuvre des modèles communs dans tensorflow boucle native.

 external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

 
Python side effect
Python side effect
Python side effect

Variables

Vous pouvez rencontrer une erreur lors de la création d' une nouvelle tf.Variable dans une fonction. Ce gardes d'erreur contre la divergence de comportement sur les appels répétés: En mode hâte, une fonction crée une nouvelle variable à chaque appel, mais dans une Function , une nouvelle variable ne peuvent pas être créés en raison de la réutilisation des traces.

 @tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
 
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-26-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-26-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


Vous pouvez créer des variables dans une Function aussi longtemps que ces variables ne sont créées que la première fois que la fonction est exécutée.

 class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
 
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Une autre erreur que vous pouvez rencontrer est une variable ramasse-miettes. Contrairement aux fonctions Python normales, les fonctions concrètes ne conservent que WeakRefs aux variables qu'ils ferment plus, vous devez conserver une référence à toutes les variables.

 external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
 
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-28-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Autograph Transformations

Autograph est une bibliothèque qui est activé par défaut dans tf.function , et transforme un sous - ensemble de codes désireux python dans ops tensorflow graphique compatible. Cela inclut le flux de contrôle comme if , for , while .

Tensorflow ops comme tf.cond et tf.while_loop continuent à travailler, mais le flux de contrôle est souvent plus facile d'écrire et de comprendre quand écrit en Python.

 # Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
 
[0.448926926 0.896036148 0.703306437 0.446930766 0.20440042]
[0.421016544 0.714362323 0.6064623 0.419372857 0.201600626]
[0.397786468 0.613405049 0.541632056 0.396401972 0.198913112]
[0.378053397 0.546519518 0.494222373 0.376866162 0.196330562]
[0.361015767 0.497907132 0.457561225 0.359982818 0.1938463]
[0.346108437 0.460469633 0.428094476 0.3451989 0.191454232]
[0.332919776 0.43046692 0.403727621 0.332110822 0.189148799]
[0.321141869 0.405711472 0.383133948 0.320416152 0.18692489]
[0.310539037 0.384825289 0.365426034 0.309883147 0.184777796]
[0.300927401 0.366890609 0.349984437 0.300330788 0.182703182]
[0.292161077 0.351268977 0.336361736 0.291615278 0.180697069]
[0.284122646 0.337500453 0.324225426 0.283620834 0.178755745]
[0.276716352 0.325244069 0.313322544 0.276252925 0.176875815]
[0.269863278 0.314240903 0.303456694 0.269433528 0.175054088]
[0.263497591 0.304290265 0.294472754 0.263097644 0.17328763]
[0.257564 0.295233846 0.2862463 0.257190555 0.171573699]
[0.25201565 0.286944896 0.278676242 0.25166589 0.169909731]
[0.246812463 0.279320478 0.271679461 0.246483982 0.168293342]
[0.24192 0.272276044 0.265186876 0.241610721 0.166722313]
[0.237308443 0.265741408 0.259140551 0.237016559 0.165194541]
[0.23295185 0.25965777 0.253491491 0.232675791 0.163708091]
[0.228827521 0.253975391 0.248197898 0.228565902 0.162261128]
[0.224915475 0.248651937 0.243223906 0.224667087 0.160851941]
[0.221198082 0.243651047 0.238538548 0.220961839 0.159478888]
[0.217659682 0.238941342 0.23411487 0.217434615 0.158140466]
[0.214286327 0.23449555 0.229929343 0.214071587 0.156835243]
[0.211065561 0.230289876 0.225961298 0.210860386 0.155561864]
[0.207986191 0.226303399 0.222192511 0.207789883 0.154319063]
[0.20503816 0.222517684 0.2186068 0.204850093 0.153105617]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20221236, 0.2189164 , 0.21518978, 0.20203198, 0.15192041],
      dtype=float32)>

Si vous êtes curieux, vous pouvez vérifier l'autographe de code génère.

 print(tf.autograph.to_code(f.python_function))
 
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


conditionals

AutoGraph convertira certains if <condition> déclarations en équivalent tf.cond appels. Cette substitution est faite si <condition> est un Tensor. Dans le cas contraire, la if instruction est exécutée comme condition Python.

Un Python conditionnel exécute pendant le suivi, si exactement une branche du conditionnel seront ajoutés au graphique. Sans AutoGraph, ce graphique tracé serait incapable de prendre la branche alternative en cas de débit de contrôle dépendant des données.

tf.cond traces et ajoute les deux branches du conditionnel au graphe, en sélectionnant de façon dynamique une branche au moment de l' exécution. Le traçage peut avoir des effets secondaires indésirables; voir les effets de traçage AutoGraph pour plus.

 @tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
 
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Voir la documentation de référence des restrictions supplémentaires sur AutoGraph-converties si les déclarations.

Boucles

AutoGraph convertira certains for et while déclarations dans le ops en boucle tensorflow équivalent, comme tf.while_loop . Si non converti, le for ou while boucle est exécutée en boucle Python.

Cette substitution se fait dans les situations suivantes:

  • for x in y : si y est un Tensor, converti à tf.while_loop . Dans le cas particulier où y est un tf.data.Dataset , une combinaison de tf.data.Dataset ops sont générés.
  • while <condition> : si <condition> est un Tensor, converti à tf.while_loop .

Une boucle Python exécute pendant le traçage, l' ajout d' opérations supplémentaires au tf.Graph pour chaque itération de la boucle.

Une boucle tensorflow trace le corps de la boucle, et dynamiquement sélectionne le nombre d'itérations à exécuter au moment de l'exécution. Le corps de la boucle apparaît une seule fois dans le produit tf.Graph .

Voir la documentation de référence des restrictions supplémentaires sur AutoGraph convertis for et while déclarations.

Bouclez données Python

Un piège commun est de boucle sur les données Python / NumPy dans un tf.function . Cette boucle exécutera au cours du processus de traçage, en ajoutant une copie de votre modèle à la tf.Graph pour chaque itération de la boucle.

Si vous voulez envelopper la boucle de formation dans toute tf.function , la meilleure façon de le faire est d'envelopper vos données en tant que tf.data.Dataset afin que AutoGraph sera Déroulez dynamiquement la boucle de formation.

 def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
 
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Quand vous emballez des données Python / NumPy dans un Dataset, garder à l' esprit tf.data.Dataset.from_generator par rapport tf.data.Dataset.from_tensors . L'ancien gardera les données dans le python et la récupération via tf.py_function qui peut avoir des conséquences sur les performances, alors que ce dernier regrouper une copie des données comme une grande tf.constant() noeud dans le graphe, qui peut avoir des implications de mémoire.

La lecture des données à partir de fichiers via TFRecordDataset / CsvDataset / etc. est le moyen le plus efficace de consommer des données, comme alors tensorflow lui-même peut gérer le chargement asynchrone et préchargement des données, sans avoir à impliquer Python. Pour en savoir plus, consultez le guide de tf.data .

Les valeurs d'accumulation en boucle

Un motif commun est d'accumuler des valeurs intermédiaires à partir d'une boucle. Normalement, cela se fait en ajoutant à une liste Python ou l'ajout d'entrées à un dictionnaire Python. Cependant, comme ce sont des effets secondaires Python, ils ne fonctionnent pas comme prévu dans une boucle dynamique déroulée. Utilisation tf.TensorArray pour accumuler les résultats d'une boucle dynamique déroulée.

 batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
 
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.2486304 , 0.0612042 , 0.69624186, 0.28587592],
        [1.2193475 , 0.2389338 , 1.5216837 , 0.38649392],
        [1.7640524 , 1.1970762 , 2.3265643 , 0.81419575]],

       [[0.36599267, 0.41830885, 0.73540664, 0.63987565],
        [0.48354673, 1.1808103 , 1.7210082 , 0.8333106 ],
        [0.7138835 , 1.2030114 , 1.8544207 , 1.1647347 ]]], dtype=float32)>

Pour en savoir plus

Pour en savoir plus sur la façon d'exporter et de charger une Function , consultez le Guide SavedModel . Pour en savoir plus sur l' optimisation des graphiques qui sont effectuées après le suivi, voir le Guide Grappler . Pour savoir comment optimiser votre pipeline de données et le profil de votre modèle, consultez le Guide Profiler .