Rejoignez la communauté SIG TFX-Addons et contribuez à rendre TFX encore meilleur ! Rejoignez SIG TFX-Addons

Premiers pas avec TensorFlow Transform

Ce guide présente les concepts de base de tf.Transform et comment les utiliser. Ce sera:

  • Définissez une fonction de prétraitement , une description logique du pipeline qui transforme les données brutes en données utilisées pour entraîner un modèle d'apprentissage automatique.
  • Affichez l'implémentation d' Apache Beam utilisée pour transformer les données en convertissant la fonction de prétraitement en pipeline Beam .
  • Afficher des exemples d'utilisation supplémentaires.

Définir une fonction de prétraitement

La fonction de prétraitement est le concept le plus important de tf.Transform . La fonction de prétraitement est une description logique d'une transformation de l'ensemble de données. La fonction de prétraitement accepte et renvoie un dictionnaire de tenseurs, où un tenseur signifie Tensor ou SparseTensor . Il existe deux types de fonctions utilisées pour définir la fonction de prétraitement:

  1. Toute fonction qui accepte et renvoie des tenseurs. Celles-ci ajoutent des opérations TensorFlow au graphique qui transforment les données brutes en données transformées.
  2. L'un des analyseurs fournis par tf.Transform . Les analyseurs acceptent et renvoient également des tenseurs, mais contrairement aux fonctions TensorFlow, ils n'ajoutent pas d' opérations au graphique. Au lieu de cela, les analyseurs tf.Transform à calculer une opération passe-plein en dehors de TensorFlow. Ils utilisent les valeurs de tenseur d'entrée sur l'ensemble de données pour générer un tenseur constant qui est renvoyé en sortie. Par exemple, tft.min calcule le minimum d'un tenseur sur l'ensemble de données. tf.Transform fournit un ensemble fixe d'analyseurs, mais cela sera étendu dans les versions futures.

Exemple de fonction de prétraitement

En combinant des analyseurs et des fonctions TensorFlow standard, les utilisateurs peuvent créer des pipelines flexibles pour transformer les données. La fonction de prétraitement suivante transforme chacune des trois fonctionnalités de différentes manières et combine deux des fonctionnalités:

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam

def preprocessing_fn(inputs):
  x = inputs['x']
  y = inputs['y']
  s = inputs['s']
  x_centered = x - tft.mean(x)
  y_normalized = tft.scale_to_0_1(y)
  s_integerized = tft.compute_and_apply_vocabulary(s)
  x_centered_times_y_normalized = x_centered * y_normalized
  return {
      'x_centered': x_centered,
      'y_normalized': y_normalized,
      'x_centered_times_y_normalized': x_centered_times_y_normalized,
      's_integerized': s_integerized
  }

Ici, x , y et s sont des Tensor s qui représentent des caractéristiques d'entrée. Le premier nouveau tenseur créé, x_centered , est construit en appliquant tft.mean à x et en le soustrayant de x . tft.mean(x) renvoie un tenseur représentant la moyenne du tenseur x . x_centered est le tenseur x avec la moyenne soustraite.

Le deuxième nouveau tenseur, y_normalized , est créé de la même manière mais en utilisant la méthode de commodité tft.scale_to_0_1 . Cette méthode fait quelque chose de similaire au calcul de x_centered , à savoir calculer un maximum et un minimum et les utiliser pour mettre à l'échelle y .

Le tenseur s_integerized montre un exemple de manipulation de chaîne. Dans ce cas, nous prenons une chaîne et la mappons à un entier. Cela utilise la fonction de commoditétft.compute_and_apply_vocabulary . Cette fonction utilise un analyseur pour calculer les valeurs uniques prises par les chaînes d'entrée, puis utilise les opérations TensorFlow pour convertir les chaînes d'entrée en indices dans la table de valeurs uniques.

La dernière colonne montre qu'il est possible d'utiliser les opérations TensorFlow pour créer de nouvelles entités en combinant des tenseurs.

La fonction de prétraitement définit un pipeline d'opérations sur un ensemble de données. Pour appliquer le pipeline, nous nous appuyons sur une implémentation concrète de l'API tf.Transform . L'implémentation Apache Beam fournit PTransform qui applique la fonction de prétraitement d'un utilisateur aux données. Le flux de travail typique d'un utilisateur tf.Transform construira une fonction de prétraitement, puis l'incorporera dans un pipeline Beam plus grand, créant les données pour la formation.

Traitement par lots

Le traitement par lots est une partie importante de TensorFlow. Étant donné que l'un des objectifs de tf.Transform est de fournir un graphe TensorFlow pour le prétraitement qui peut être incorporé dans le graphe de tf.Transform (et, éventuellement, le graphe d'apprentissage), le traitement par lots est également un concept important dans tf.Transform .

Bien que cela ne soit pas évident dans l'exemple ci-dessus, la fonction de prétraitement définie par l'utilisateur reçoit des tenseurs représentant des lots et non des instances individuelles, comme cela se produit pendant l'entraînement et la diffusion avec TensorFlow. D'autre part, les analyseurs effectuent un calcul sur l'ensemble de l'ensemble de données qui renvoie une valeur unique et non un lot de valeurs. x est un Tensor avec une forme de (batch_size,) , tandis que tft.mean(x) est un Tensor avec une forme de () . La soustraction x - tft.mean(x) diffuse où la valeur de tft.mean(x) est soustraite de chaque élément du lot représenté par x .

Implémentation Apache Beam

Alors que la fonction de prétraitement est conçue comme une description logique d'un pipeline de prétraitement implémenté sur plusieurs frameworks de traitement de données, tf.Transform fournit une implémentation canonique utilisée sur Apache Beam. Cette implémentation démontre la fonctionnalité requise d'une implémentation. Il n'y a pas d'API formelle pour cette fonctionnalité, de sorte que chaque implémentation peut utiliser une API qui est idiomatique pour son cadre de traitement de données particulier.

L'implémentation Apache Beam fournit deux PTransform utilisés pour traiter les données pour une fonction de prétraitement. Ce qui suit montre l'utilisation du composite PTransform AnalyzeAndTransformDataset : PTransform AnalyzeAndTransformDataset shows the usage for the composite PTransform AnalyzeAndTransformDataset :

raw_data = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'}
]

raw_data_metadata = ...
transformed_dataset, transform_fn = (
    (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
        preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset

Le contenu transformed_data est affiché ci-dessous et contient les colonnes transformées dans le même format que les données brutes. En particulier, les valeurs de s_integerized sont [0, 1, 0] ces valeurs dépendent de la façon dont les mots hello et world ont été mappés en nombres entiers, ce qui est déterministe. Pour la colonne x_centered , nous avons soustrait la moyenne de sorte que les valeurs de la colonne x , qui étaient [1.0, 2.0, 3.0] , sont devenues [-1.0, 0.0, 1.0] . De même, le reste des colonnes correspond à leurs valeurs attendues.

[{u's_integerized': 0,
  u'x_centered': -1.0,
  u'x_centered_times_y_normalized': -0.0,
  u'y_normalized': 0.0},
 {u's_integerized': 1,
  u'x_centered': 0.0,
  u'x_centered_times_y_normalized': 0.0,
  u'y_normalized': 0.5},
 {u's_integerized': 0,
  u'x_centered': 1.0,
  u'x_centered_times_y_normalized': 1.0,
  u'y_normalized': 1.0}]

Les données raw_data et transformed_data sont des ensembles de données. Les deux sections suivantes montrent comment l'implémentation Beam représente les ensembles de données et comment lire et écrire des données sur le disque. L'autre valeur de retour, transform_fn , représente la transformation appliquée aux données, décrite en détail ci-dessous.

AnalyzeAndTransformDataset est la composition des deux transformations fondamentales fournies par l'implémentation AnalyzeDataset et TransformDataset . Les deux extraits de code suivants sont donc équivalents:

transformed_data, transform_fn = (
    my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)
transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()

transform_fn est une fonction pure qui représente une opération appliquée à chaque ligne de l'ensemble de données. En particulier, les valeurs de l'analyseur sont déjà calculées et traitées comme des constantes. Dans l'exemple, le transform_fn contient comme constantes la moyenne de la colonne x , les min et max de la colonne y et le vocabulaire utilisé pour mapper les chaînes sur des entiers.

Une caractéristique importante de tf.Transform est que transform_fn représente une carte sur des lignes - c'est une fonction pure appliquée à chaque ligne séparément. Tous les calculs d'agrégation des lignes sont effectués dans AnalyzeDataset . De plus, le transform_fn est représenté sous la forme d'un Graph TensorFlow qui peut être intégré dans le graphe de diffusion.

AnalyzeAndTransformDataset est fourni pour les optimisations dans ce cas particulier. Il s'agit du même modèle utilisé dans scikit-learn , fournissant les méthodes fit , transform et fit_transform .

Formats et schéma de données

L'implémentation TFT Beam accepte deux formats de données d'entrée différents. Le format "instance dict" (comme vu dans l'exemple ci-dessus et dans simple_example.py ) est un format intuitif et convient aux petits ensembles de données tandis que le format TFXIO ( Apache Arrow ) offre des performances améliorées et convient aux grands ensembles de données.

L'implémentation Beam indique dans quel format la PCollection d'entrée se trouverait par les "métadonnées" accompagnant la PCollection:

(raw_data, raw_data_metadata) | tft.AnalyzeDataset(...)
  • Si raw_data_metadata est un dataset_metadata.DatasetMetadata (voir ci-dessous, "Le format 'instance dict'"), alors raw_data devrait être au format "instance dict".
  • Si raw_data_metadata est un tfxio.TensorAdapterConfig (voir ci-dessous, section "Le format TFXIO"), alors raw_data devrait être au format TFXIO.

Le format "instance dict"

Dans les exemples de code précédents, le code définissant raw_data_metadata est omis. Les métadonnées contiennent le schéma qui définit la disposition des données afin qu'elles soient lues et écrites dans divers formats. Même le format en mémoire montré dans la dernière section n'est pas auto-descriptif et nécessite le schéma pour être interprété comme des tenseurs.

Voici la définition du schéma pour les données d'exemple:

from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

raw_data_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
        's': tf.io.FixedLenFeature([], tf.string),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'x': tf.io.FixedLenFeature([], tf.float32),
    }))

Le proto de Schema contient les informations nécessaires pour analyser les données de leur format sur disque ou en mémoire, en tenseurs. Il est généralement construit en appelant schema_utils.schema_from_feature_spec avec des clés de fonction de mappage dict aux tf.io.FixedLenFeature , tf.io.VarLenFeature et tf.io.SparseFeature . Consultez la documentation de tf.parse_example pour plus de détails.

Ci-dessus, nous utilisons tf.io.FixedLenFeature pour indiquer que chaque fonctionnalité contient un nombre fixe de valeurs, dans ce cas une seule valeur scalaire. Étant tf.Transform que tf.Transform instances, le Tensor réel représentant l' tf.Transform aura la forme (None,) où la dimension inconnue est la dimension du lot.

Le format TFXIO

Avec ce format, les données devraient être contenues dans un pyarrow.RecordBatch . Pour les données tabulaires, notre implémentation Apache Beam accepte les Arrow RecordBatch qui se composent de colonnes des types suivants:

  • pa.list_(<primitive>) , où <primitive> est pa.int64() , pa.float32() pa.binary() ou pa.large_binary() .

  • pa.large_list(<primitive>)

Le jeu de données d'entrée de jouet que nous avons utilisé ci-dessus, lorsqu'il est représenté sous la forme d'un RecordBatch , ressemble à ce qui suit:

raw_data = [
    pa.record_batch([
        pa.array([[1], [2], [3]], pa.list_(pa.float32())),
        pa.array([[1], [2], [3]], pa.list_(pa.float32())),
        pa.array([['hello'], ['world'], ['hello']], pa.list_(pa.binary())),
    ], ['x', 'y', 's'])
]

Semblable à DatasetMetadata étant nécessaire pour accompagner le format "instance dict", un tfxio.TensorAdapterConfig est nécessaire pour accompagner les RecordBatch es. Il se compose du schéma Arrow des RecordBatch es et des TensorRepresentations pour déterminer de manière unique comment les colonnes de RecordBatch es peuvent être interprétées comme des TensorFlow Tensors (y compris mais sans s'y limiter tf.Tensor, tf.SparseTensor).

TensorRepresentations est un Dict[Text, TensorRepresentation] qui établit la relation entre un Tensor que preprocessing_fn accepte et les colonnes des RecordBatch es. Par example:

tensor_representation = {
    'x': text_format.Parse(
        """dense_tensor { column_name: "col1" shape { dim { size: 2 } } }"""
        schema_pb2.TensorRepresentation())
}

Signifie que les inputs['x'] dans preprocessing_fn doivent être un tf.Tensor dense, dont les valeurs proviennent d'une colonne de nom 'col1' dans l'entrée RecordBatch es, et sa forme (par lots) doit être [batch_size, 2] .

TensorRepresentation est un Protobuf défini dans les métadonnées TensorFlow .

Compatibilité avec TensorFlow

tf.Transform permet d'exporter le transform_fn ci-dessus en tant que TF 1.x ou TF 2.x SavedModel. Le comportement par défaut avant la version 0.30 exportait un SavedModel TF 1.x. À partir de la version 0.30 , le comportement par défaut est d'exporter un SavedModel TF 2.x à moins que les comportements TF 2.x ne soient explicitement désactivés (en appelant tf.compat.v1.disable_v2_behavior() par exemple).

Si vous utilisez des concepts TF 1.x tels que les Estimators et les Sessions , vous pouvez conserver le comportement précédent en transmettant force_tf_compat_v1=True à tft_beam.Context si vous utilisez tf.Transform comme bibliothèque autonome ou au composant Transform dans TFX.

Lors de l' exportation du transform_fn comme TF 2.x SavedModel, le preprocessing_fn devrait être traçable à l' aide tf.function . De plus, si vous exécutez votre pipeline à distance (par exemple avec DataflowRunner ), assurez-vous que preprocessing_fn et toutes les dépendances sont correctement empaquetées, comme décrit ici .

Les problèmes connus liés à l'utilisation de tf.Transform pour exporter un SavedModel TF 2.x sont documentés ici .

Entrée et sortie avec Apache Beam

Jusqu'à présent, nous avons vu des données d'entrée et de sortie dans des listes python (de RecordBatch es ou de dictionnaires d'instances). Il s'agit d'une simplification qui repose sur la capacité d'Apache Beam à travailler avec des listes ainsi que sur sa représentation principale des données, la PCollection .

Une PCollection est une représentation de données qui fait partie d'un pipeline Beam. Un pipeline Beam est formé en appliquant divers PTransform , y compris AnalyzeDataset et TransformDataset , et en exécutant le pipeline. Une PCollection n'est pas créée dans la mémoire du binaire principal, mais est plutôt distribuée entre les workers (bien que cette section utilise le mode d'exécution en mémoire).

Pré-conserves PCollection Sources ( TFXIO )

Le format RecordBatch que notre implémentation accepte est un format commun que d'autres bibliothèques TFX acceptent. Par conséquent, TFX propose des «sources» pratiques (alias TFXIO ) qui lisent des fichiers de différents formats sur le disque et produisent des RecordBatch et peuvent également donner TensorAdapterConfig , y compris des TensorRepresentations inférées.

Ces TFXIO se trouvent dans le package tfx_bsl ( tfx_bsl.public.tfxio ).

Exemple: ensemble de données «Revenu du recensement»

L'exemple suivant nécessite à la fois la lecture et l'écriture de données sur le disque et la représentation des données sous forme de PCollection (pas de liste), voir: census_example.py . Ci-dessous, nous montrons comment télécharger les données et exécuter cet exemple. L'ensemble de données «Census Income» est fourni par l' UCI Machine Learning Repository . Cet ensemble de données contient à la fois des données catégoriques et numériques.

Les données sont au format CSV, voici les deux premières lignes:

39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K
50, Self-emp-not-inc, 83311, Bachelors, 13, Married-civ-spouse, Exec-managerial, Husband, White, Male, 0, 0, 13, United-States, <=50K

Les colonnes de l'ensemble de données sont catégoriques ou numériques. Cet ensemble de données décrit un problème de classification: prédire la dernière colonne où l'individu gagne plus ou moins de 50K par an. Cependant, du point de vue de tf.Transform , cette étiquette n'est qu'une autre colonne catégorielle.

Nous utilisons un TFXIO pré- TFXIO , BeamRecordCsvTFXIO pour traduire les lignes CSV en RecordBatches . TFXIO nécessite deux informations importantes:

  • un schéma de métadonnées TensorFlow qui contient des informations de type et de forme sur chaque colonne CSV. TensorRepresentation sont une partie facultative du schéma; s'ils ne sont pas fournis (ce qui est le cas dans cet exemple), ils seront déduits des informations de type et de forme. On peut obtenir le schéma soit en utilisant une fonction d'assistance que nous fournissons pour traduire à partir des spécifications d'analyse TF (illustrées dans cet exemple), soit en exécutant la validation des données TensorFlow .

  • une liste de noms de colonnes, dans leur ordre d'apparition dans le fichier CSV. Notez que ces noms doivent correspondre aux noms de fonctionnalités dans le schéma.

Dans cet exemple, nous education-num l'absence de la fonction education-num . Cela signifie qu'il est représenté comme un tf.io.VarLenFeature dans le feature_spec, et en tant que tf.SparseTensor dans le preprocessing_fn . D'autres fonctionnalités deviendront des tf.Tensor du même nom dans le preprocessing_fn .

csv_tfxio = tfxio.BeamRecordCsvTFXIO(
    physical_format='text', column_names=ordered_columns, schema=SCHEMA)

record_batches = (
    p
    | 'ReadTrainData' >> textio.ReadFromText(train_data_file)
    | ...  # fix up csv lines
    | 'ToRecordBatches' >> csv_tfxio.BeamSource())

tensor_adapter_config = csv_tfxio.TensorAdapterConfig()

Notez que nous avons dû faire quelques corrections supplémentaires après la lecture des lignes CSV. Sinon, nous pourrions compter sur CsvTFXIO pour gérer à la fois la lecture des fichiers et la traduction en RecordBatch es:

csv_tfxio = tfxio.CsvTFXIO(train_data_file, column_name=ordered_columns,
                           schema=SCHEMA)
record_batches = p | 'TFXIORead' >> csv_tfxio.BeamSource()
tensor_adapter_config = csv_tfxio.TensorAdapterConfig()

Le prétraitement est similaire à l'exemple précédent, sauf que la fonction de prétraitement est générée par programme au lieu de spécifier manuellement chaque colonne. Dans la fonction de prétraitement ci-dessous, NUMERICAL_COLUMNS et CATEGORICAL_COLUMNS sont des listes contenant les noms des colonnes numériques et catégorielles:

def preprocessing_fn(inputs):
  """Preprocess input columns into transformed columns."""
  # Since we are modifying some features and leaving others unchanged, we
  # start by setting `outputs` to a copy of `inputs.
  outputs = inputs.copy()

  # Scale numeric columns to have range [0, 1].
  for key in NUMERIC_FEATURE_KEYS:
    outputs[key] = tft.scale_to_0_1(outputs[key])

  for key in OPTIONAL_NUMERIC_FEATURE_KEYS:
    # This is a SparseTensor because it is optional. Here we fill in a default
    # value when it is missing.
      sparse = tf.sparse.SparseTensor(outputs[key].indices, outputs[key].values,
                                      [outputs[key].dense_shape[0], 1])
      dense = tf.sparse.to_dense(sp_input=sparse, default_value=0.)
    # Reshaping from a batch of vectors of size 1 to a batch to scalars.
    dense = tf.squeeze(dense, axis=1)
    outputs[key] = tft.scale_to_0_1(dense)

  # For all categorical columns except the label column, we generate a
  # vocabulary but do not modify the feature.  This vocabulary is instead
  # used in the trainer, by means of a feature column, to convert the feature
  # from a string to an integer id.
  for key in CATEGORICAL_FEATURE_KEYS:
    tft.vocabulary(inputs[key], vocab_filename=key)

  # For the label column we provide the mapping from string to index.
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys=['>50K', '<=50K'],
      values=tf.cast(tf.range(2), tf.int64),
      key_dtype=tf.string,
      value_dtype=tf.int64)
  table = tf.lookup.StaticHashTable(initializer, default_value=-1)

  outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])

  return outputs

Une différence par rapport à l'exemple précédent est que la colonne d'étiquette spécifie manuellement le mappage de la chaîne à un index. Ainsi, '>50' est mappé à 0 et '<=50K' est mappé à 1 car il est utile de savoir quel index dans le modèle entraîné correspond à quelle étiquette.

La variable record_batches représente une PCollection de pyarrow.RecordBatch es. Le tensor_adapter_config est donné par csv_tfxio , qui est déduit de SCHEMA (et finalement, dans cet exemple, des spécifications d'analyse TF).

La dernière étape consiste à écrire les données transformées sur le disque et a une forme similaire à la lecture des données brutes. Le schéma utilisé pour ce faire fait partie de la sortie de AnalyzeAndTransformDataset qui déduit un schéma pour les données de sortie. Le code à écrire sur le disque est indiqué ci-dessous. Le schéma fait partie des métadonnées mais utilise les deux de manière interchangeable dans l'API tf.Transform (c.-à-d. Passer les métadonnées à ExampleProtoCoder ). Sachez que cela écrit dans un format différent. Au lieu de textio.WriteToText , utilisez la prise en charge intégrée de Beam pour le format TFRecord et utilisez un codeur pour encoder les données comme Example protos. Il s'agit d'un meilleur format à utiliser pour la formation, comme indiqué dans la section suivante. transformed_eval_data_base fournit le nom de fichier de base pour les fragments individuels qui sont écrits.

transformed_data | "WriteTrainData" >> tfrecordio.WriteToTFRecord(
    transformed_eval_data_base,
    coder=tft.coders.ExampleProtoCoder(transformed_metadata))

En plus des données d'entraînement, transform_fn est également écrit avec les métadonnées:

_ = (
    transform_fn
    | 'WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir))
transformed_metadata | 'WriteMetadata' >> tft_beam.WriteMetadata(
    transformed_metadata_file, pipeline=p)

Exécutez tout le pipeline Beam avec p.run().wait_until_finish() . Jusqu'à ce point, le pipeline Beam représente un calcul différé et distribué. Il fournit des instructions sur ce qui sera fait, mais les instructions n'ont pas été exécutées. Cet appel final exécute le pipeline spécifié.

Téléchargez le jeu de données du recensement

Téléchargez l'ensemble de données du recensement à l'aide des commandes shell suivantes:

  wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data
  wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test

Lors de l'exécution du script census_example.py , transmettez le répertoire contenant ces données comme premier argument. Le script crée un sous-répertoire temporaire pour ajouter les données prétraitées.

Intégration avec la formation TensorFlow

La dernière section de census_example.py montre comment les données prétraitées sont utilisées pour entraîner un modèle. Consultez la documentation Estimators pour plus de détails. La première étape consiste à construire un Estimator qui nécessite une description des colonnes prétraitées. Chaque colonne numérique est décrite comme une real_valued_column qui est un wrapper autour d'un vecteur dense avec une taille fixe ( 1 dans cet exemple). Chaque colonne catégorielle est mappée d'une chaîne à des entiers, puis est transmise à l' indicator_column . tft.TFTransformOutput est utilisé pour trouver le chemin du fichier de vocabulaire pour chaque fonctionnalité catégorielle.

real_valued_columns = [feature_column.real_valued_column(key)
                       for key in NUMERIC_FEATURE_KEYS]

one_hot_columns = [
    tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_file(
            key=key,
            vocabulary_file=tf_transform_output.vocabulary_file_by_name(
                vocab_filename=key)))
    for key in CATEGORICAL_FEATURE_KEYS]

estimator = tf.estimator.LinearClassifier(real_valued_columns + one_hot_columns)

L'étape suivante consiste à créer un générateur pour générer la fonction d'entrée pour la formation et l'évaluation. Le diffère de la formation utilisée par tf.Learn car une spécification de fonctionnalité n'est pas requise pour analyser les données transformées. Utilisez plutôt les métadonnées des données transformées pour générer une spécification d'entité.

def _make_training_input_fn(tf_transform_output, transformed_examples,
                            batch_size):
  ...
  def input_fn():
    """Input function for training and eval."""
    dataset = tf.data.experimental.make_batched_features_dataset(
        ..., tf_transform_output.transformed_feature_spec(), ...)

    transformed_features = tf.compat.v1.data.make_one_shot_iterator(
        dataset).get_next()
    ...

  return input_fn

Le code restant est identique à l'utilisation de la classe Estimator . L'exemple contient également du code pour exporter le modèle au format SavedModel . Le modèle exporté peut être utilisé par Tensorflow Serving ou Cloud ML Engine .