Cette page a été traduite par l'API Cloud Translation.
Switch to English

tf.function

Version de TensorFlow 1 Afficher la source sur GitHub

Compile une fonction dans un graphe TensorFlow appelable.

Utilisé dans les cahiers

Utilisé dans le guide Utilisé dans les tutoriels

tf.function construit un appelable qui exécute un graphe TensorFlow ( tf.Graph ) créé en compilant trace les opérations TensorFlow dans func , exécutant effectivement func comme un graphe TensorFlow.

Exemple d'utilisation:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>

Caractéristiques

func peut utiliser un flux de contrôle dépendant des données, y compris des instructions if , for , while break , continue et return :

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>

La fermeture de func peut inclure des objets tf.Tensor et tf.Variable :

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func peut également utiliser des opérations avec des effets secondaires, tels que tf.print , tf.Variable et autres:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

À la place, utilisez des collections TensorFlow comme tf.TensorArray :

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

tf.function est polymorphe

En interne, tf.function peut créer plusieurs graphiques, pour prendre en charge des arguments avec différents types de données ou formes, car TensorFlow peut créer des graphiques plus efficaces, spécialisés sur les formes et les dtypes. tf.function traite également toute valeur Python pure comme des objets opaques et crée un graphique séparé pour chaque ensemble d'arguments Python qu'il rencontre.

Pour obtenir un graphe individuel, utilisez la méthode get_concrete_function de l'appelable créé par tf.function . Il peut être appelé avec les mêmes arguments que func et retourne un objet tf.Graph spécial:

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
@tf.function
def f(x):
  return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)  # Slow - builds new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
f1 is f2
True

Les arguments numériques Python ne doivent être utilisés que lorsqu'ils prennent peu de valeurs distinctes, telles que des hyperparamètres comme le nombre de couches dans un réseau neuronal.

Signatures d'entrée

Pour les arguments Tensor, tf.function instancie un graphique distinct pour chaque ensemble unique de formes d'entrée et de types de données. L'exemple ci-dessous crée deux graphiques distincts, chacun spécialisé dans une forme différente:

@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False

Une "signature d'entrée" peut être fournie en option à tf.function pour contrôler les graphiques tracés. La signature d'entrée spécifie la forme et le type de chaque argument Tensor de la fonction à l'aide d'un objet tf.TensorSpec . Des formes plus générales peuvent être utilisées. Ceci est utile pour éviter de créer plusieurs graphiques lorsque les tenseurs ont des formes dynamiques. Il restreint également la forme et le type de données des tenseurs qui peuvent être utilisés:

@tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])