Aide à protéger la Grande barrière de corail avec tensorflow sur Kaggle Rejoignez Défi

Introduction aux graphes et tf.function

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Aperçu

Ce guide va sous la surface de TensorFlow et Keras pour montrer comment fonctionne TensorFlow. Si vous voulez au lieu de commencer immédiatement avec Keras, consultez la collection de guides KERAS .

Dans ce guide, vous apprendrez comment TensorFlow vous permet d'apporter des modifications simples à votre code pour obtenir des graphiques, comment les graphiques sont stockés et représentés, et comment vous pouvez les utiliser pour accélérer vos modèles.

Ceci est un aperçu grand-image qui couvre la façon dont tf.function vous permet de passer d'exécution désireux d'exécution graphique. Pour une spécification plus complète de tf.function , allez à la tf.function Guide .

Que sont les graphiques ?

Dans les trois précédents guides, vous avez couru tensorflow avec impatience. Cela signifie que les opérations TensorFlow sont exécutées par Python, opération par opération, et renvoient les résultats à Python.

Alors que l'exécution rapide présente plusieurs avantages uniques, l'exécution graphique permet la portabilité en dehors de Python et a tendance à offrir de meilleures performances. Graphique des moyens d'exécution que les calculs de tenseurs sont exécutées sous forme de graphique tensorflow, parfois appelé tf.Graph ou simplement un « graphique ».

Les graphiques sont des structures de données qui contiennent un ensemble de tf.Operation objets, qui représentent des unités de calcul; et tf.Tensor objets, qui représentent les unités de données qui circulent entre les opérations. Ils sont définis dans un tf.Graph contexte. Étant donné que ces graphiques sont des structures de données, ils peuvent être enregistrés, exécutés et restaurés sans le code Python d'origine.

Voici à quoi ressemble un graphique TensorFlow représentant un réseau de neurones à deux couches lorsqu'il est visualisé dans TensorBoard.

Un simple graphique TensorFlow

Les avantages des graphiques

Avec un graphique, vous avez une grande flexibilité. Vous pouvez utiliser votre graphique TensorFlow dans des environnements qui n'ont pas d'interpréteur Python, comme les applications mobiles, les appareils intégrés et les serveurs principaux. Tensorflow utilise des graphiques comme le format des modèles enregistrés quand il les exporte à partir de Python.

Les graphiques sont également facilement optimisés, permettant au compilateur d'effectuer des transformations telles que :

  • En déduire la valeur statique de tenseurs par pliage noeuds de constants dans le calcul ( « pliage constant »).
  • Séparez les sous-parties d'un calcul qui sont indépendantes et répartissez-les entre des threads ou des périphériques.
  • Simplifiez les opérations arithmétiques en éliminant les sous-expressions courantes.

Il existe un système d'optimisation ensemble, Grappler , pour réaliser cela et d' autres speedups.

En bref, les graphiques sont extrêmement utiles et laissez votre tensorflow courir vite, courir en parallèle et fonctionnent efficacement sur plusieurs appareils.

Cependant, vous souhaitez toujours définir vos modèles d'apprentissage automatique (ou d'autres calculs) en Python pour plus de commodité, puis construire automatiquement des graphiques lorsque vous en avez besoin.

Installer

import tensorflow as tf
import timeit
from datetime import datetime

Tirer parti des graphiques

Vous créez et exécutez un graphique dans tensorflow en utilisant tf.function , que ce soit comme un appel direct ou comme décorateur. tf.function prend une fonction régulière en entrée et renvoie une Function . Une Function est un appelable Python qui construit des graphiques tensorflow de la fonction Python. Vous utilisez une Function de la même manière que son équivalent Python.

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

A l'extérieur, une Function ressemble à une fonction régulière vous écrire en utilisant des opérations de tensorflow. Au- dessous , cependant, il est très différent. Une Function encapsule plusieurs tf.Graph s derrière une API . C'est ainsi que la Function est en mesure de vous donner les avantages de l' exécution graphique , comme la vitesse et la déployabilité.

tf.function applique à une fonction et toutes les autres fonctions qu'elle appelle:

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)

Si vous avez utilisé 1.x tensorflow, vous remarquerez que , à aucun moment vous devez définir un Placeholder ou tf.Session .

Conversion de fonctions Python en graphiques

Toutes les fonctions que vous écrivez avec tensorflow contiendra un mélange de intégré dans les opérations de TF et de la logique Python, comme if-then des clauses, des boucles, break , return , continue , et plus encore. Alors que les opérations tensorflow sont facilement capturés par un tf.Graph besoins, logiques spécifiques Python pour subir une étape supplémentaire afin de faire partie du graphique. tf.function utilise une bibliothèque appelée AutoGraph ( tf.autograph ) pour convertir le code Python en code générant le graphique.

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
First branch, with graph: 1
Second branch, with graph: 0

Bien qu'il soit peu probable que vous ayez besoin de visualiser les graphiques directement, vous pouvez inspecter les sorties pour vérifier les résultats exacts. Ce ne sont pas faciles à lire, donc pas besoin de trop regarder !

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            (do_return, retval_) = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const"
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_1"
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_2"
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_3"
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity"
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_4"
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity_1"
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const"
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity"
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity_1"
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 808
  min_consumer: 12
}

La plupart du temps, tf.function fonctionnera sans considérations particulières. Cependant, il y a des mises en garde et le guide tf.function peut aider ici, ainsi que la référence complète Autograph

Polymorphisme: une Function , de nombreux graphiques

Un tf.Graph est spécialisé pour un type spécifique d'entrées (par exemple, avec un tenseurs spécifique dtype ou des objets avec le même id() ).

Chaque fois que vous invoquez une Function avec de nouveaux dtypes et des formes dans ses arguments, la Function crée une nouvelle tf.Graph pour les nouveaux arguments. Les dtypes et les formes d'une tf.Graph entrées de » sont connus comme une signature d'entrée ou juste une signature.

Les Function stocke le tf.Graph correspondant à cette signature dans un ConcreteFunction . Un ConcreteFunction est une enveloppe autour d' un tf.Graph .

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

Si la Function a déjà été appelé avec cette signature, la Function ne crée pas une nouvelle tf.Graph .

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

Parce qu'il est soutenu par plusieurs graphiques, une Function est polymorphes. Cela permet de soutenir les types plus qu'une seule entrée tf.Graph pourrait représenter, ainsi que pour optimiser chaque tf.Graph pour une meilleure performance.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)

L' utilisation tf.function

Jusqu'à présent, vous avez appris comment convertir une fonction Python dans un graphique en utilisant simplement tf.function comme un décorateur ou un emballage. Mais dans la pratique, se tf.function au travail peut être délicat correctement! Dans les sections suivantes, vous apprendrez comment vous pouvez faire votre travail de code comme prévu avec tf.function .

Exécution graphique vs exécution avide

Le code dans une Function peut être exécutée à la fois avec impatience et comme un graphique. Par défaut, la Function exécute son code sous forme de graphique:

@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
tf.Tensor([6 1 7 8 0], shape=(5,), dtype=int32)
tf.Tensor([6 0 1 8 6], shape=(5,), dtype=int32)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=14>

Pour vérifier que votre Function « graphe de fait le même calcul que son équivalent fonction Python, vous pouvez le faire exécuter avec enthousiasme avec tf.config.run_functions_eagerly(True) . Ceci est un interrupteur qui désactive la Function la capacité de créer et exécuter des graphiques, au lieu d' exécuter le code normalement.

tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=14>
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

Cependant, la Function peut se comporter différemment dans le graphique et l' exécution avide. Python print fonction est un exemple de la façon dont ces deux modes différents. Soit un chèque de ce qui se passe lorsque vous insérez une print déclaration à votre fonction et l' appeler à plusieurs reprises.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

Observez ce qui est imprimé :

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!

Le rendu est-il surprenant ? get_MSE seulement imprimé une fois , même si elle a été appelé trois fois.

Pour expliquer, l' print instruction est exécutée lorsque la Function exécute le code d' origine afin de créer le graphique dans un processus connu sous le nom « traçage » . Tracing capture les opérations tensorflow dans un graphique et l' print n'est pas capturé dans le graphique. Ce graphique est alors exécuté pour les trois appels sans jamais courir à nouveau le code Python.

Pour vérifier l'intégrité, désactivons l'exécution du graphique pour comparer :

# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.config.run_functions_eagerly(False)

print est un effet secondaire Python, et il y a d' autres différences que vous devriez être au courant lors de la conversion d' une fonction en Function .

Exécution non stricte

L'exécution du graphe n'exécute que les opérations nécessaires pour produire les effets observables, ce qui comprend :

  • La valeur de retour de la fonction
  • Effets secondaires bien connus documentés tels que :
    • Les opérations d' entrée / sortie, comme tf.print
    • Des opérations de déboguage, telles que les fonctions assert dans tf.debugging
    • Mutations de tf.Variable

Ce comportement est généralement connu sous le nom d'"exécution non stricte" et diffère de l'exécution rapide, qui passe par toutes les opérations du programme, nécessaires ou non.

En particulier, la vérification des erreurs d'exécution ne compte pas comme un effet observable. Si une opération est ignorée parce qu'elle est inutile, elle ne peut pas générer d'erreurs d'exécution.

Dans l'exemple suivant, l'opération « inutile » tf.gather est ignorée lors de l' exécution graphique, donc l'erreur d'exécution InvalidArgumentError ne se pose pas , car il serait en exécution avide. Ne vous fiez pas à une erreur générée lors de l'exécution d'un graphique.

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
tf.Tensor([0.], shape=(1,), dtype=float32)
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph exection. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
tf.Tensor([0.], shape=(1,), dtype=float32)

tf.function les meilleures pratiques

Il peut prendre un certain temps pour se habituer au comportement de la Function . Pour commencer rapidement, les nouveaux utilisateurs doivent jouer avec la décoration de fonctions de jouets avec @tf.function pour obtenir de l' expérience d'aller de hâte à l' exécution graphique.

La conception de tf.function peut être votre meilleur pari pour écrire des programmes tensorflow compatible graphique. Voici quelques conseils:

  • Bascule entre l' exécution et le graphique désireux tôt et souvent avec tf.config.run_functions_eagerly à Pinpoint si / quand les deux modes diverger.
  • Créer tf.Variable de l'extérieur de la fonction Python et les modifier à l'intérieur. La même chose vaut pour les objets qui utilisent tf.Variable , comme keras.layers , keras.Model s et tf.optimizers .
  • Évitez d' écrire des fonctions qui dépendent des variables Python extérieures , à l' exclusion tf.Variable s et des objets Keras.
  • Préférez écrire des fonctions qui prennent des tenseurs et d'autres types de TensorFlow en entrée. Vous pouvez passer dans d' autres types d'objets , mais soyez prudent !
  • Inclure autant que possible le calcul sous un tf.function afin de maximiser le gain de performance. Par exemple, décorez une étape d'entraînement entière ou la boucle d'entraînement entière.

Vu l'accélération

tf.function améliore généralement les performances de votre code, mais la quantité de vitesse en fonction du type de calcul que vous exécutez. Les petits calculs peuvent être dominés par le surcoût lié à l'appel d'un graphe. Vous pouvez mesurer la différence de performance comme suit :

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
Eager execution: 2.0122516460000384
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
Graph execution: 0.6084441319999883

tf.function est couramment utilisé pour accélérer les boucles de formation, et vous pouvez en apprendre davantage à ce sujet dans L' écriture d' une boucle de formation à partir de zéro avec Keras.

Performances et compromis

Les graphiques peuvent accélérer votre code, mais le processus de création de ces derniers a une certaine surcharge. Pour certaines fonctions, la création du graphe prend plus de temps que l'exécution du graphe. Cet investissement est généralement rapidement remboursé grâce à l'amélioration des performances des exécutions ultérieures, mais il est important de savoir que les premières étapes de toute formation sur un grand modèle peuvent être plus lentes en raison du traçage.

Quelle que soit la taille de votre modèle, vous voulez éviter de tracer fréquemment. Le tf.function guide explique comment les spécifications d'entrée ensemble et arguments du tenseur d'utilisation pour éviter retracing. Si vous constatez que vous obtenez des performances inhabituellement médiocres, il est judicieux de vérifier si vous revenez accidentellement.

Quand une Function traçait?

Pour savoir si votre Function est calque, ajouter une print déclaration à son code. En règle générale, la Function exécute l' print relevé à chaque fois qu'il trace.

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

Les nouveaux arguments Python déclenchent toujours la création d'un nouveau graphe, d'où le traçage supplémentaire.

Prochaines étapes

Vous pouvez en savoir plus sur tf.function sur la page de référence de l' API et en suivant la Meilleure performance avec tf.function guide.