Types d'extensions

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHubTélécharger le cahier

Installer

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

Types d'extensions

Les types définis par l'utilisateur peuvent rendre les projets plus lisibles, modulaires et maintenables. Cependant, la plupart des API TensorFlow ont une prise en charge très limitée des types Python définis par l'utilisateur. Cela inclut à la fois les API de haut niveau (telles que Keras , tf.function , tf.SavedModel ) et les API de niveau inférieur (telles que tf.while_loop et tf.concat ). Les types d'extension TensorFlow peuvent être utilisés pour créer des types orientés objet définis par l'utilisateur qui fonctionnent de manière transparente avec les API de TensorFlow. Pour créer un type d'extension, définissez simplement une classe Python avec tf.experimental.ExtensionType comme base et utilisez des annotations de type pour spécifier le type de chaque champ.

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

La classe de base tf.experimental.ExtensionType fonctionne de la même manière que typing.NamedTuple et @dataclasses.dataclass de la bibliothèque Python standard. En particulier, il ajoute automatiquement un constructeur et des méthodes spéciales (telles que __repr__ et __eq__ ) basées sur les annotations de type de champ.

En règle générale, les types d'extension ont tendance à appartenir à l'une des deux catégories suivantes :

  • Structures de données , qui regroupent une collection de valeurs associées et peuvent fournir des opérations utiles basées sur ces valeurs. Les structures de données peuvent être assez générales (comme l'exemple TensorGraph ci-dessus) ; ou ils peuvent être hautement personnalisés pour un modèle spécifique.

  • Types de type Tensor , qui spécialisent ou étendent le concept de "Tensor". Les types de cette catégorie ont un rank , une shape et généralement un dtype ; et il est logique de les utiliser avec des opérations Tensor (telles que tf.stack , tf.add ou tf.matmul ). MaskedTensor et CSRSparseMatrix sont des exemples de types de type tenseur.

API prises en charge

Les types d'extension sont compatibles avec les API TensorFlow suivantes :

  • Keras : les types d'extension peuvent être utilisés comme entrées et sorties pour les Models et Layers Keras.
  • tf.data.Dataset : les types d'extension peuvent être inclus dans des Datasets de données et renvoyés par des Iterators d'ensemble de données.
  • Hub Tensorflow : les types d'extension peuvent être utilisés comme entrées et sorties pour les modules tf.hub .
  • SavedModel : les types d'extension peuvent être utilisés comme entrées et sorties pour les fonctions SavedModel .
  • tf.function : les types d'extension peuvent être utilisés comme arguments et valeurs de retour pour les fonctions enveloppées avec le décorateur @tf.function .
  • boucles while : les types d'extension peuvent être utilisés comme variables de boucle dans tf.while_loop , et peuvent être utilisés comme arguments et valeurs de retour pour le corps de la boucle while.
  • conditionals : les types d'extension peuvent être sélectionnés de manière conditionnelle à l'aide tf.cond et tf.case .
  • py_function : les types d'extension peuvent être utilisés comme arguments et valeurs de retour pour l'argument func de tf.py_function .
  • Opérations Tensor : les types d'extension peuvent être étendus pour prendre en charge la plupart des opérations TensorFlow qui acceptent les entrées Tensor (par exemple, tf.matmul , tf.gather et tf.reduce_sum ). Voir la section " Expédition " ci-dessous pour plus d'informations.
  • stratégie de distribution : les types d'extension peuvent être utilisés comme valeurs par réplica.

Pour plus de détails, consultez la section "API TensorFlow prenant en charge les ExtensionTypes" ci-dessous.

Exigences

Types de champs

Tous les champs (ou variables d'instance) doivent être déclarés et une annotation de type doit être fournie pour chaque champ. Les annotations de type suivantes sont prises en charge :

Taper Exemple
Entiers Python i: int
Flotteurs Python f: float
Chaînes Python s: str
Booléens Python b: bool
Python Aucun n: None
Formes tenseurs shape: tf.TensorShape
Types de tenseur dtype: tf.DType
Tenseurs t: tf.Tensor
Types d'extensions mt: MyMaskedTensor
Tenseurs en lambeaux rt: tf.RaggedTensor
Tenseurs clairsemés st: tf.SparseTensor
Tranches indexées s: tf.IndexedSlices
Tenseurs optionnels o: tf.experimental.Optional
Tapez les unions int_or_float: typing.Union[int, float]
Tuples params: typing.Tuple[int, float, tf.Tensor, int]
Tuples de longueur var lengths: typing.Tuple[int, ...]
Mappages tags: typing.Mapping[str, tf.Tensor]
Valeurs facultatives weight: typing.Optional[tf.Tensor]

Mutabilité

Les types d'extension doivent être immuables. Cela garantit qu'ils peuvent être correctement suivis par les mécanismes de traçage de graphes de TensorFlow. Si vous souhaitez muter une valeur de type d'extension, envisagez plutôt de définir des méthodes qui transforment les valeurs. Par exemple, plutôt que de définir une méthode set_mask pour muter un MaskedTensor , vous pouvez définir une méthode replace_mask qui renvoie un nouveau MaskedTensor :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

Fonctionnalité ajoutée par ExtensionType

La classe de base ExtensionType fournit les fonctionnalités suivantes :

  • Un constructeur ( __init__ ).
  • Une méthode de représentation imprimable ( __repr__ ).
  • Opérateurs d'égalité et d'inégalité ( __eq__ ).
  • Une méthode de validation ( __validate__ ).
  • Immutabilité forcée.
  • Un TypeSpec imbriqué.
  • Prise en charge de l'envoi de l'API Tensor.

Consultez la section "Personnalisation des types d'extension" ci-dessous pour plus d'informations sur la personnalisation de cette fonctionnalité.

Constructeur

Le constructeur ajouté par ExtensionType prend chaque champ comme argument nommé (dans l'ordre dans lequel ils ont été répertoriés dans la définition de classe). Ce constructeur vérifiera le type de chaque paramètre et les convertira si nécessaire. En particulier, les champs Tensor sont convertis à l'aide tf.convert_to_tensor ; Les champs de Tuple sont convertis en tuple s ; et les champs de Mapping sont convertis en dicts immuables.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)

Le constructeur lève une TypeError si une valeur de champ ne peut pas être convertie en son type déclaré :

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None

La valeur par défaut d'un champ peut être spécifiée en définissant sa valeur au niveau de la classe :

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)

Représentation imprimable

ExtensionType ajoute une méthode de représentation imprimable par défaut ( __repr__ ) qui inclut le nom de la classe et la valeur de chaque champ :

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)

Opérateurs d'égalité

ExtensionType ajoute des opérateurs d'égalité par défaut ( __eq__ et __ne__ ) qui considèrent deux valeurs égales si elles ont le même type et que tous leurs champs sont égaux. Les champs tensoriels sont considérés comme égaux s'ils ont la même forme et sont égaux élément par élément pour tous les éléments.

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True
a == b: False
a == a.values: False

Méthode de validation

ExtensionType ajoute une méthode __validate__ , qui peut être remplacée pour effectuer des contrôles de validation sur les champs. Il est exécuté après l'appel du constructeur et après que les champs ont été vérifiés et convertis en leurs types déclarés, de sorte qu'il peut supposer que tous les champs ont leurs types déclarés.

L'exemple suivant met à jour MaskedTensor pour valider les shape s et dtype s de ses champs :

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # wrong dtype for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible

Immutabilité forcée

ExtensionType remplace les méthodes __setattr__ et __delattr__ pour empêcher la mutation, garantissant que les valeurs de type d'extension sont immuables.

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.

TypeSpec imbriqué

Chaque classe ExtensionType a une classe TypeSpec correspondante, qui est créée automatiquement et stockée sous <extension_type_name>.Spec .

Cette classe capture toutes les informations d'une valeur, à l' exception des valeurs des tenseurs imbriqués. En particulier, le TypeSpec d'une valeur est créé en remplaçant tout Tensor, ExtensionType ou CompositeTensor imbriqué par son TypeSpec .

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records dtype and shape, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
TensorSpec(shape=(), dtype=tf.string, name=None)
ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})

Les valeurs TypeSpec peuvent être construites explicitement, ou elles peuvent être construites à partir d'une valeur ExtensionType en utilisant tf.type_spec_from_value :

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec sont utilisées par TensorFlow pour diviser les valeurs en un composant statique et un composant dynamique :

  • Le composant statique (qui est fixé au moment de la construction du graphe) est encodé avec un tf.TypeSpec .
  • La composante dynamique (qui peut varier à chaque exécution du graphe) est encodée sous la forme d'une liste de tf.Tensor s.

Par exemple, tf.function retrace sa fonction encapsulée chaque fois qu'un argument a une TypeSpec inédite :

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))

Pour plus d'informations, reportez-vous au Guide des fonctions tf .

Personnalisation des types d'extension

En plus de simplement déclarer les champs et leurs types, les types d'extension peuvent :

  • Remplacer la représentation imprimable par défaut ( __repr__ ).
  • Définir les méthodes.
  • Définir les méthodes de classe et les méthodes statiques.
  • Définir les propriétés.
  • Remplacez le constructeur par défaut ( __init__ ).
  • Remplacez l'opérateur d'égalité par défaut ( __eq__ ).
  • Définissez des opérateurs (tels que __add__ et __lt__ ).
  • Déclarez les valeurs par défaut des champs.
  • Définir des sous-classes.

Remplacer la représentation imprimable par défaut

Vous pouvez remplacer cet opérateur de conversion de chaîne par défaut pour les types d'extension. L'exemple suivant met à jour la classe MaskedTensor pour générer une représentation sous forme de chaîne plus lisible lorsque les valeurs sont imprimées en mode Eager.

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>

Définir des méthodes

Les types d'extension peuvent définir des méthodes, comme n'importe quelle classe Python normale. Par exemple, le type MaskedTensor pourrait définir une méthode with_default qui renvoie une copie de self avec des valeurs masquées remplacées par une valeur default donnée. Les méthodes peuvent éventuellement être annotées avec le décorateur @tf.function .

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>

Définition des méthodes de classe et des méthodes statiques

Les types d'extension peuvent définir des méthodes à l'aide des décorateurs @classmethod et @staticmethod . Par exemple, le type MaskedTensor pourrait définir une méthode de fabrique qui masque tout élément avec une valeur donnée :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values == value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>

Définir les propriétés

Les types d'extension peuvent définir des propriétés à l'aide du décorateur @property , comme n'importe quelle classe Python normale. Par exemple, le type MaskedTensor pourrait définir une propriété dtype qui est un raccourci pour le dtype des valeurs :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32

Remplacer le constructeur par défaut

Vous pouvez remplacer le constructeur par défaut pour les types d'extension. Les constructeurs personnalisés doivent définir une valeur pour chaque champ déclaré ; et après le retour du constructeur personnalisé, tous les champs seront vérifiés et les valeurs seront converties comme décrit ci-dessus.

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Vous pouvez également envisager de laisser le constructeur par défaut tel quel, mais d'ajouter une ou plusieurs méthodes de fabrique. Par exemple:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Remplacement de l'opérateur d'égalité par défaut ( __eq__ )

Vous pouvez remplacer l'opérateur __eq__ par défaut pour les types d'extension. L'exemple suivant met à jour MaskedTensor pour ignorer les éléments masqués lors de la comparaison d'égalité.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)

Utiliser des références directes

Si le type d'un champ n'a pas encore été défini, vous pouvez utiliser une chaîne contenant le nom du type à la place. Dans l'exemple suivant, la chaîne "Node" est utilisée pour annoter le champ children car le type de Node n'a pas encore été (entièrement) défini.

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))

Définition des sous-classes

Les types d'extension peuvent être sous-classés en utilisant la syntaxe Python standard. Les sous-classes de type d'extension peuvent ajouter de nouveaux champs, méthodes et propriétés ; et peut remplacer le constructeur, la représentation imprimable et l'opérateur d'égalité. L'exemple suivant définit une classe TensorGraph de base qui utilise trois champs Tensor pour encoder un ensemble d'arêtes entre les nœuds. Il définit ensuite une sous-classe qui ajoute un champ Tensor pour enregistrer une "valeur de caractéristique" pour chaque nœud. La sous-classe définit également une méthode pour propager les valeurs des caractéristiques le long des bords.

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10.  0.  2.  5. -1.  0.], shape=(6,), dtype=float32)
After propagating: tf.Tensor([10. 12.  4.  4. -1.  0.], shape=(6,), dtype=float32)

Définir des champs privés

Les champs d'un type d'extension peuvent être marqués comme privés en les préfixant d'un trait de soulignement (suivant les conventions Python standard). Cela n'a aucune incidence sur la manière dont TensorFlow traite les champs ; mais sert simplement de signal à tous les utilisateurs du type d'extension que ces champs sont privés.

Personnalisation de TypeSpec d'ExtensionType

Chaque classe ExtensionType a une classe TypeSpec correspondante, qui est créée automatiquement et stockée sous <extension_type_name>.Spec . Pour plus d'informations, consultez la section "TypeSpec imbriqué" ci-dessus.

Pour personnaliser le TypeSpec , définissez simplement votre propre classe imbriquée nommée Spec , et ExtensionType l'utilisera comme base pour le TypeSpec construit automatiquement. Vous pouvez personnaliser la classe Spec en :

  • Remplacement de la représentation imprimable par défaut.
  • Remplacement du constructeur par défaut.
  • Définir des méthodes, des méthodes de classe, des méthodes statiques et des propriétés.

L'exemple suivant personnalise la classe MaskedTensor.Spec pour en faciliter l'utilisation :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

Envoi de l'API Tensor

Les types d'extension peuvent être "de type tenseur", dans le sens où ils spécialisent ou étendent l'interface définie par le type tf.Tensor . Des exemples de types d'extensions de type tenseur incluent RaggedTensor , SparseTensor et MaskedTensor . Les décorateurs de répartition peuvent être utilisés pour remplacer le comportement par défaut des opérations TensorFlow lorsqu'elles sont appliquées à des types d'extension de type tenseur. TensorFlow définit actuellement trois décorateurs de répartition :

Dispatch pour une seule API

Le décorateur tf.experimental.dispatch_for_api remplace le comportement par défaut d'une opération TensorFlow spécifiée lorsqu'elle est appelée avec la signature spécifiée. Par exemple, vous pouvez utiliser ce décorateur pour spécifier comment tf.stack doit traiter les valeurs MaskedTensor :

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

Cela remplace l'implémentation par défaut de tf.stack chaque fois qu'il est appelé avec une liste de valeurs MaskedTensor (puisque l'argument values est annoté avec typing.List[MaskedTensor] ):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>

Pour permettre à tf.stack de gérer des listes de valeurs mixtes MaskedTensor et Tensor , vous pouvez affiner l'annotation de type pour le paramètre values et mettre à jour le corps de la fonction de manière appropriée :

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>

Pour obtenir une liste des API pouvant être remplacées, consultez la documentation de l'API pour tf.experimental.dispatch_for_api .

Dispatch pour toutes les API élémentaires unaires

Le décorateur tf.experimental.dispatch_for_unary_elementwise_apis remplace le comportement par défaut de toutes les opérations élémentaires unaires (telles que tf.math.cos ) chaque fois que la valeur du premier argument (généralement nommé x ) correspond à l'annotation de type x_type . La fonction décorée doit prendre deux arguments :

  • api_func : une fonction qui prend un seul paramètre et effectue l'opération élément par élément (par exemple, tf.abs ).
  • x : Le premier argument de l'opération élément par élément.

L'exemple suivant met à jour toutes les opérations élémentaires unaires pour gérer le type MaskedTensor :

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

Cette fonction sera désormais utilisée chaque fois qu'une opération élément par élément unaire est appelée sur un MaskedTensor .

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>

Envoi pour les API binaires toutes par éléments

De même, tf.experimental.dispatch_for_binary_elementwise_apis peut être utilisé pour mettre à jour toutes les opérations élémentaires binaires pour gérer le type MaskedTensor :

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>

Pour obtenir la liste des API élément par élément qui sont remplacées, consultez la documentation de l'API pour tf.experimental.dispatch_for_unary_elementwise_apis et tf.experimental.dispatch_for_binary_elementwise_apis .

Types d'extensions batchables

Un ExtensionType est batchable si une seule instance peut être utilisée pour représenter un lot de valeurs. En règle générale, cela se fait en ajoutant des dimensions de lot à tous les Tensor s imbriqués. Les API TensorFlow suivantes exigent que toutes les entrées de type d'extension puissent être traitées par lots :

Par défaut, BatchableExtensionType crée des valeurs par lots en regroupant tous les Tensor s, CompositeTensor s et ExtensionType s imbriqués. Si cela ne convient pas à votre classe, vous devrez utiliser tf.experimental.ExtensionTypeBatchEncoder pour remplacer ce comportement par défaut. Par exemple, il ne serait pas approprié de créer un lot de valeurs tf.SparseTensor en empilant simplement les champs values , indices et dense_shape individuels des tenseurs clairsemés -- dans la plupart des cas, vous ne pouvez pas empiler ces tenseurs, car ils ont des formes incompatibles ; et même si vous le pouviez, le résultat ne serait pas un SparseTensor valide.

Exemple BatchableExtensionType : Réseau

Par exemple, considérons une classe Network simple utilisée pour l'équilibrage de charge, qui suit la quantité de travail restant à faire à chaque nœud et la quantité de bande passante disponible pour déplacer le travail entre les nœuds :

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

Pour rendre ce type utilisable par lot, remplacez le type de base par BatchableExtensionType et ajustez la forme de chaque champ pour inclure des dimensions de lot facultatives. L'exemple suivant ajoute également un champ de shape pour garder une trace de la forme du lot. Ce champ de shape n'est pas requis par tf.data.Dataset ou tf.map_fn , mais il est requis par tf.Keras .

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

Vous pouvez ensuite utiliser tf.data.Dataset pour parcourir un lot de réseaux :

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>

Et vous pouvez également utiliser map_fn pour appliquer une fonction à chaque élément du lot :

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

API TensorFlow prenant en charge les ExtensionTypes

@tf.fonction

tf.function est un décorateur qui précalcule les graphiques TensorFlow pour les fonctions Python, ce qui peut considérablement améliorer les performances de votre code TensorFlow. Les valeurs de type d'extension peuvent être utilisées de manière transparente avec les fonctions @tf.function .

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Si vous souhaitez spécifier explicitement le input_signature pour tf.function , vous pouvez le faire en utilisant TypeSpec du type d'extension.

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)

Fonctions concrètes

Les fonctions concrètes encapsulent des graphes tracés individuels qui sont construits par tf.function . Les types d'extension peuvent être utilisés de manière transparente avec des fonctions concrètes.

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Opérations de flux de contrôle

Les types d'extension sont compatibles avec les opérations de flux de contrôle de TensorFlow :

# Example: using tf.cond to select between two MaskedTensors.  Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>

Flux de contrôle des autographes

Les types d'extension sont également pris en charge par les instructions de flux de contrôle dans tf.function (en utilisant autograph). Dans l'exemple suivant, les instructions if et for sont automatiquement converties en opérations tf.cond et tf.while_loop , qui prennent en charge les types d'extension.

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]>
<MaskedTensor [4.5, _, 6.5]>

Keras

tf.keras est l'API de haut niveau de TensorFlow pour la création et la formation de modèles d'apprentissage en profondeur. Les types d'extension peuvent être transmis en tant qu'entrées à un modèle Keras, transmis entre les couches Keras et renvoyés par les modèles Keras. Keras impose actuellement deux exigences aux types d'extension :

  • Ils doivent être batchables (voir "Batchable ExtensionTypes" ci-dessus).
  • Le doit avoir un champ ou une propriété nommé shape . shape[0] est supposé être la dimension du lot.

Les deux sous-sections suivantes donnent des exemples montrant comment les types d'extension peuvent être utilisés avec Keras.

Exemple Keras : Network

Pour le premier exemple, considérez la classe Network définie dans la section "Batchable ExtensionTypes" ci-dessus, qui peut être utilisée pour le travail d'équilibrage de charge entre les nœuds. Sa définition est reprise ici :

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network w/ 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

Vous pouvez définir une nouvelle couche Keras qui traite les Network s.

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above, in "Batchable ExtensionTypes" section.
    return balance_work_greedy(inputs)

Vous pouvez ensuite utiliser ces calques pour créer un modèle simple. Pour alimenter un ExtensionType dans un modèle, vous pouvez utiliser une couche tf.keras.layer.Input avec type_spec défini sur TypeSpec du type d'extension. Si le modèle Keras sera utilisé pour traiter des lots, alors le type_spec doit inclure la dimension du lot.

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

Enfin, vous pouvez appliquer le modèle à un seul réseau et à un lot de réseaux.

model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>

Exemple Keras : MaskedTensor

Dans cet exemple, MaskedTensor est étendu pour prendre en charge Keras . shape est définie comme une propriété calculée à partir du champ de values . Keras nécessite que vous ajoutiez cette propriété à la fois au type d'extension et à son TypeSpec . MaskedTensor définit également une variable __name__ , qui sera requise pour la sérialisation SavedModel (ci-dessous).

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

Ensuite, les décorateurs de répartition sont utilisés pour remplacer le comportement par défaut de plusieurs API TensorFlow. Étant donné que ces API sont utilisées par les couches Keras standard (telles que la couche Dense ), leur remplacement nous permettra d'utiliser ces couches avec MaskedTensor . Pour les besoins de cet exemple, matmul pour les tenseurs masqués est défini pour traiter les valeurs masquées comme des zéros (c'est-à-dire pour ne pas les inclure dans le produit).

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

Vous pouvez ensuite construire un modèle Keras qui accepte les entrées MaskedTensor , en utilisant des couches Keras standard :

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3
1/1 [==============================] - 1s 955ms/step - loss: 10.2833
Epoch 2/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
Epoch 3/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
tf.Tensor(
[[-0.09944128]
 [-0.7225147 ]
 [-1.3020657 ]], shape=(3, 1), dtype=float32)

Modèle enregistré

Un SavedModel est un programme TensorFlow sérialisé, comprenant à la fois des pondérations et des calculs. Il peut être construit à partir d'un modèle Keras ou d'un modèle personnalisé. Dans les deux cas, les types d'extension peuvent être utilisés de manière transparente avec les fonctions et les méthodes définies par un SavedModel.

SavedModel peut enregistrer des modèles, des couches et des fonctions qui traitent des types d'extension, tant que les types d'extension ont un champ __name__ . Ce nom est utilisé pour enregistrer le type d'extension, afin qu'il puisse être localisé lorsque le modèle est chargé.

Exemple : enregistrement d'un modèle Keras

Les modèles Keras qui utilisent des types d'extension peuvent être enregistrés à l'aide SavedModel .

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.09944128],
       [-0.7225147 ],
       [-1.3020657 ]], dtype=float32)>

Exemple : enregistrement d'un modèle personnalisé

SavedModel peut également être utilisé pour enregistrer des sous-classes tf.Module personnalisées avec des fonctions qui traitent les types d'extension.

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
<MaskedTensor [_, 200.0, _]>

Chargement d'un SavedModel lorsque l'ExtensionType n'est pas disponible

Si vous chargez un SavedModel qui utilise un ExtensionType , mais que ExtensionType n'est pas disponible (c'est-à-dire qu'il n'a pas été importé), un avertissement s'affichera et TensorFlow reviendra à l'utilisation d'un objet "type d'extension anonyme". Cet objet aura les mêmes champs que le type d'origine, mais il n'y aura pas de personnalisation supplémentaire que vous avez ajoutée pour le type, comme des méthodes ou des propriétés personnalisées.

Utilisation d'ExtensionTypes avec le service TensorFlow

Actuellement, le service TensorFlow (et d'autres consommateurs du dictionnaire "signatures" SavedModel) nécessite que toutes les entrées et sorties soient des tenseurs bruts. Si vous souhaitez utiliser le service TensorFlow avec un modèle qui utilise des types d'extension, vous pouvez ajouter des méthodes wrapper qui composent ou décomposent les valeurs de type d'extension à partir de tenseurs. Par exemple:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
<tf.Tensor: shape=(), dtype=float32, numpy=12.0>

Jeux de données

tf.data est une API qui vous permet de créer des pipelines d'entrée complexes à partir de pièces simples et réutilisables. Sa structure de données de base est tf.data.Dataset , qui représente une séquence d'éléments, dans laquelle chaque élément est constitué d'un ou plusieurs composants.

Créer des ensembles de données avec des types d'extension

Les ensembles de données peuvent être créés à partir de valeurs de type d'extension à l'aide Dataset.from_tensors , Dataset.from_tensor_slices ou Dataset.from_generator :

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
<MaskedTensor [0, 1, 2, 3]>
<MaskedTensor [4, 5, 6, 7]>
<MaskedTensor [8, 9, 10, 11]>
<MaskedTensor [12, 13, 14, 15]>
<MaskedTensor [16, 17, 18, 19]>
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>

Regrouper et dissocier des ensembles de données avec des types d'extension

Les ensembles de données avec des types d'extension peuvent être groupés et non groupés à l'aide Dataset.batch et Dataset.unbatch .

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]>
<MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]>
<MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>