Graph-based Neural Structured Learning in TFX

This tutorial describes graph regularization from the Neural Structured Learning framework and demonstrates an end-to-end workflow for sentiment classification in a TFX pipeline.

Overview

This notebook classifies movie reviews as positive or negative using the text of the review. This is an example of binary classification, an important and widely applicable kind of machine learning problem.

We will demonstrate the use of graph regularization in this notebook by building a graph from the given input. The general recipe for building a graph-regularized model using the Neural Structured Learning (NSL) framework when the input does not contain an explicit graph is as follows:

 1. Create embeddings for each text sample in the input. This can be done using pre-trained models such as word2vec, Swivel, BERT, etc.
 2. Build a graph based on these embeddings by using a similarity metric such as the 'L2' distance, 'cosine' distance, etc. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples.
 3. Generate training data from the above synthesized graph and sample features. The resulting training data will contain neighbor features in addition to the original node features.
 4. Create a neural network as a base model using Estimators.
 5. Wrap the base model with the add_graph_regularization wrapper function, which is provided by the NSL framework, to create a new graph Estimator model. This new model will include a graph regularization loss as the regularization term in its training objective.
 6. Train and evaluate the graph Estimator model.
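
To make the idea behind steps 5 and 6 more concrete, here is a minimal pure-Python sketch of how a graph regularization term can be added to a supervised loss. It is not NSL's implementation: the function names, the squared L2 distance, and the `multiplier` weight are assumptions chosen only for illustration.

```python
import math


def binary_cross_entropy(label, prob):
    """Supervised loss for one sample; `prob` is the predicted P(label == 1)."""
    eps = 1e-7  # Clip to avoid log(0).
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(label, prob, sample_repr, neighbors, multiplier=0.1):
    """Supervised loss plus a weighted penalty for differing from neighbors.

    `neighbors` is a list of (neighbor_repr, edge_weight) pairs. All names and
    the default `multiplier` are hypothetical, chosen only for this sketch.
    """
    supervised = binary_cross_entropy(label, prob)
    neighbor_penalty = sum(
        weight * squared_l2(sample_repr, neighbor_repr)
        for neighbor_repr, weight in neighbors)
    return supervised + multiplier * neighbor_penalty


# One positive review whose model output is close to one neighbor and far
# from another. The representation here is simply the predicted probability.
loss_without_graph = graph_regularized_loss(1, 0.9, [0.9], [])
loss_with_graph = graph_regularized_loss(
    1, 0.9, [0.9], [([0.8], 0.99), ([0.1], 0.99)])
print(loss_without_graph, loss_with_graph)
```

The extra term grows when a sample's output differs from the outputs of its graph neighbors, which is what nudges similar reviews toward similar predictions.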

In this tutorial, we integrate the above workflow into a TFX pipeline using several custom TFX components as well as a custom graph-regularized trainer component.

Below is the schematic of our TFX pipeline. Orange boxes represent off-the-shelf TFX components and pink boxes represent custom TFX components.

Figure: TFX pipeline

Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we're running in Colab. Local systems can of course be upgraded separately.

try:
 import colab
 !pip install --upgrade pip
except:
 pass
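
The cell above relies on a bare `except` to detect whether it runs in Colab. As a hedged alternative, the check can be isolated in a small helper that only catches import failures. This sketch assumes that the Colab runtime exposes a `google.colab` package, which this tutorial does not state; the helper name `running_in_colab` is hypothetical.

```python
import importlib


def running_in_colab():
    """Returns True if the `google.colab` package can be imported."""
    try:
        # Assumption: this package is only importable inside a Colab runtime.
        importlib.import_module("google.colab")
    except ImportError:
        return False
    return True


print("Running in Colab:", running_in_colab())
```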

Install required packages

!pip install -q -U \
 tfx==0.23.0 \
 neural-structured-learning \
 tensorflow-hub \
 tensorflow-datasets
ERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.

tensorflow-metadata 0.24.0 requires absl-py<0.11,>=0.9, but you'll have absl-py 0.8.1 which is incompatible.
apache-beam 2.24.0 requires dill<0.3.2,>=0.3.1.1, but you'll have dill 0.3.2 which is incompatible.
google-api-python-client 1.12.3 requires httplib2<1dev,>=0.15.0, but you'll have httplib2 0.9.2 which is incompatible.
tfx-bsl 0.23.0 requires tensorflow-metadata<0.24,>=0.23, but you'll have tensorflow-metadata 0.24.0 which is incompatible.
tensorflow-transform 0.23.0 requires tensorflow-metadata<0.24,>=0.23, but you'll have tensorflow-metadata 0.24.0 which is incompatible.
tensorflow-model-analysis 0.23.0 requires tensorflow-metadata<0.24,>=0.23, but you'll have tensorflow-metadata 0.24.0 which is incompatible.
tensorflow-data-validation 0.23.1 requires tensorflow-metadata<0.24,>=0.23, but you'll have tensorflow-metadata 0.24.0 which is incompatible.

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.

Dependencies and imports

import apache_beam as beam
import gzip as gzip_lib
import numpy as np
import os
import pprint
import shutil
import tempfile
import urllib
import uuid
pp = pprint.PrettyPrinter()

import tensorflow as tf
import neural_structured_learning as nsl

import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.import_example_gen.component import ImportExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.utils.dsl_utils import external_input

from tfx.types import artifact
from tfx.types import artifact_utils
from tfx.types import channel
from tfx.types import standard_artifacts
from tfx.types.standard_artifacts import Examples

from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component

from tensorflow_metadata.proto.v0 import anomalies_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2

import tensorflow_data_validation as tfdv
import tensorflow_transform as tft
import tensorflow_model_analysis as tfma
import tensorflow_hub as hub
import tensorflow_datasets as tfds

print("TF Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
  "GPU is",
  "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
print("NSL Version: ", nsl.__version__)
print("TFX Version: ", tfx.__version__)
print("TFDV version: ", tfdv.__version__)
print("TFT version: ", tft.__version__)
print("TFMA version: ", tfma.__version__)
print("Hub version: ", hub.__version__)
print("Beam version: ", beam.__version__)
TF Version: 2.3.1
Eager mode: True
GPU is available
NSL Version: 1.3.1
TFX Version: 0.23.0
TFDV version: 0.23.1
TFT version: 0.23.0
TFMA version: 0.23.0
Hub version: 0.9.0
Beam version: 2.24.0

IMDB dataset

The IMDB dataset contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews. In addition, there are 50,000 unlabeled movie reviews.

Download preprocessed IMDB dataset

The following code downloads the IMDB dataset (or uses a cached copy if it has already been downloaded) using TFDS. To speed up this notebook, we will use only 10,000 labeled reviews and 10,000 unlabeled reviews for training, and 10,000 test reviews for evaluation.

train_set, eval_set = tfds.load(
  "imdb_reviews:1.0.0",
  split=["train[:10000]+unsupervised[:10000]", "test[:10000]"],
  shuffle_files=False)
Downloading and preparing dataset imdb_reviews/plain_text/1.0.0 (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /home/kbuilder/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteIPYLOW/imdb_reviews-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteIPYLOW/imdb_reviews-test.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteIPYLOW/imdb_reviews-unsupervised.tfrecord

Warning:absl:Dataset is using deprecated text encoder API which will be removed soon. Please use the plain_text version of the dataset and migrate to `tensorflow_text`.

Dataset imdb_reviews downloaded and prepared to /home/kbuilder/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.

Let's look at a few reviews from the training set:

for tfrecord in train_set.take(4):
 print("Review: {}".format(tfrecord["text"].numpy().decode("utf-8")[:300]))
 print("Label: {}\n".format(tfrecord["label"].numpy()))
Review: This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda pi
Label: 0

Review: I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Cons
Label: 0

Review: Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to e
Label: 0

Review: This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cr
Label: 1


def _dict_to_example(instance):
 """Converts a dict of NumPy arrays to a tf.train.Example."""
 feature = {}
 for key, value in instance.items():
  if value is None:
   feature[key] = tf.train.Feature()
  elif np.issubdtype(value.dtype, np.integer):
   feature[key] = tf.train.Feature(
     int64_list=tf.train.Int64List(value=value.tolist()))
  elif value.dtype == np.float32:
   feature[key] = tf.train.Feature(
     float_list=tf.train.FloatList(value=value.tolist()))
  else:
   feature[key] = tf.train.Feature(
     bytes_list=tf.train.BytesList(value=value.tolist()))
 return tf.train.Example(features=tf.train.Features(feature=feature))


examples_path = tempfile.mkdtemp(prefix="tfx-data")
train_path = os.path.join(examples_path, "train.tfrecord")
eval_path = os.path.join(examples_path, "eval.tfrecord")

for path, dataset in [(train_path, train_set), (eval_path, eval_set)]:
 with tf.io.TFRecordWriter(path) as writer:
  for example in dataset:
   writer.write(
     _dict_to_example({
       "label": np.array([example["label"].numpy()]),
       "text": np.array([example["text"].numpy()]),
     }).SerializeToString())

Run TFX components interactively

In the cells that follow, you will construct TFX components and run each one interactively within the InteractiveContext to obtain ExecutionResult objects. This mirrors the process of an orchestrator running components in a TFX DAG based on when the dependencies for each component are met.

context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/metadata.sqlite.
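
To illustrate what "running components when their dependencies are met" means, the following sketch orders the components used in this notebook (up to the Transform step) with a plain topological sort. It is only an illustration and not how a TFX orchestrator is implemented; the `DEPENDENCIES` table and `execution_order` helper are hypothetical names, with the edges taken from how the components are wired together in the cells of this notebook.

```python
from collections import deque

# Each component maps to the upstream components whose outputs it consumes.
DEPENDENCIES = {
    "ExampleGen": [],
    "IdentifyExamples": ["ExampleGen"],
    "StatisticsGen": ["IdentifyExamples"],
    "SchemaGen": ["StatisticsGen"],
    "ExampleValidator": ["StatisticsGen", "SchemaGen"],
    "SynthesizeGraph": ["IdentifyExamples"],
    "Transform": ["IdentifyExamples", "SchemaGen"],
}


def execution_order(dependencies):
    """Returns one valid run order using Kahn's topological sort."""
    remaining = {name: set(deps) for name, deps in dependencies.items()}
    ready = deque(sorted(name for name, deps in remaining.items() if not deps))
    order = []
    while ready:
        current = ready.popleft()
        order.append(current)
        for name in sorted(remaining):
            deps = remaining[name]
            if current in deps:
                deps.remove(current)
                if not deps:  # All inputs are now available.
                    ready.append(name)
    if len(order) != len(dependencies):
        raise ValueError("Dependency cycle detected")
    return order


order = execution_order(DEPENDENCIES)
print(order)
```

In the interactive setting we play the role of the orchestrator ourselves by calling `context.run()` on each component in an order like this one.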

The ExampleGen component

In any ML development process, the first step when starting code development is to ingest the training and test datasets. The ExampleGen component brings data into the TFX pipeline.

Create an ExampleGen component and run it.

input_data = external_input(examples_path)

input_config = example_gen_pb2.Input(splits=[
  example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
  example_gen_pb2.Input.Split(name='eval', pattern='eval.tfrecord')
])

example_gen = ImportExampleGen(input=input_data, input_config=input_config)

context.run(example_gen, enable_cache=True)
WARNING:tensorflow:From <ipython-input-1-6617f383c251>:1: external_input (from tfx.utils.dsl_utils) is deprecated and will be removed in a future version.
Instructions for updating:
external_input is deprecated, directly pass the uri to ExampleGen.

Warning:absl:The "input" argument to the ImportExampleGen component has been deprecated by "input_base". Please update your usage as support for this argument will be removed soon.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.

Warning:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.

for artifact in example_gen.outputs['examples'].get():
 print(artifact)

print('\nexample_gen.outputs is a {}'.format(type(example_gen.outputs)))
print(example_gen.outputs)

print(example_gen.outputs['examples'].get()[0].split_names)
Artifact(artifact: id: 1
type_id: 5
uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/ImportExampleGen/examples/1"
properties {
 key: "split_names"
 value {
  string_value: "[\"train\", \"eval\"]"
 }
}
custom_properties {
 key: "input_fingerprint"
 value {
  string_value: "split:train,num_files:1,total_bytes:27706811,xor_checksum:1602753958,sum_checksum:1602753958\nsplit:eval,num_files:1,total_bytes:13374744,xor_checksum:1602753960,sum_checksum:1602753960"
 }
}
custom_properties {
 key: "name"
 value {
  string_value: "examples"
 }
}
custom_properties {
 key: "payload_format"
 value {
  string_value: "FORMAT_TF_EXAMPLE"
 }
}
custom_properties {
 key: "pipeline_name"
 value {
  string_value: "interactive-2020-10-15T09_26_00.686186"
 }
}
custom_properties {
 key: "producer_component"
 value {
  string_value: "ImportExampleGen"
 }
}
custom_properties {
 key: "span"
 value {
  string_value: "0"
 }
}
custom_properties {
 key: "state"
 value {
  string_value: "published"
 }
}
, artifact_type: id: 5
name: "Examples"
properties {
 key: "span"
 value: INT
}
properties {
 key: "split_names"
 value: STRING
}
properties {
 key: "version"
 value: INT
}
)

example_gen.outputs is a <class 'tfx.types.node_common._PropertyDictWrapper'>
{'examples': Channel(
  type_name: Examples
  artifacts: [Artifact(artifact: id: 1
type_id: 5
uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/ImportExampleGen/examples/1"
properties {
 key: "split_names"
 value {
  string_value: "[\"train\", \"eval\"]"
 }
}
custom_properties {
 key: "input_fingerprint"
 value {
  string_value: "split:train,num_files:1,total_bytes:27706811,xor_checksum:1602753958,sum_checksum:1602753958\nsplit:eval,num_files:1,total_bytes:13374744,xor_checksum:1602753960,sum_checksum:1602753960"
 }
}
custom_properties {
 key: "name"
 value {
  string_value: "examples"
 }
}
custom_properties {
 key: "payload_format"
 value {
  string_value: "FORMAT_TF_EXAMPLE"
 }
}
custom_properties {
 key: "pipeline_name"
 value {
  string_value: "interactive-2020-10-15T09_26_00.686186"
 }
}
custom_properties {
 key: "producer_component"
 value {
  string_value: "ImportExampleGen"
 }
}
custom_properties {
 key: "span"
 value {
  string_value: "0"
 }
}
custom_properties {
 key: "state"
 value {
  string_value: "published"
 }
}
, artifact_type: id: 5
name: "Examples"
properties {
 key: "span"
 value: INT
}
properties {
 key: "split_names"
 value: STRING
}
properties {
 key: "version"
 value: INT
}
)]
)}
["train", "eval"]

The component's outputs include 2 artifacts:

 • training examples (10,000 labeled reviews + 10,000 unlabeled reviews)
 • evaluation examples (10,000 labeled reviews)

The IdentifyExamples custom component

To use NSL, each instance must have a unique ID. We create a custom component that adds such a unique ID to all instances across all splits. We use Apache Beam so that we can easily scale to large datasets if needed.

def make_example_with_unique_id(example, id_feature_name):
 """Adds a unique ID to the given `tf.train.Example` proto.

 This function uses Python's 'uuid' module to generate a universally unique
 identifier for each example.

 Args:
  example: An instance of a `tf.train.Example` proto.
  id_feature_name: The name of the feature in the resulting `tf.train.Example`
   that will contain the unique identifier.

 Returns:
  A new `tf.train.Example` proto that includes a unique identifier as an
  additional feature.
 """
 result = tf.train.Example()
 result.CopyFrom(example)
 unique_id = uuid.uuid4()
 result.features.feature.get_or_create(
   id_feature_name).bytes_list.MergeFrom(
     tf.train.BytesList(value=[str(unique_id).encode('utf-8')]))
 return result


@component
def IdentifyExamples(orig_examples: InputArtifact[Examples],
           identified_examples: OutputArtifact[Examples],
           id_feature_name: Parameter[str],
           component_name: Parameter[str]) -> None:

 # Get a list of the splits in input_data
 splits_list = artifact_utils.decode_split_names(
   split_names=orig_examples.split_names)

 for split in splits_list:
  input_dir = os.path.join(orig_examples.uri, split)
  output_dir = os.path.join(identified_examples.uri, split)
  os.mkdir(output_dir)
  with beam.Pipeline() as pipeline:
   (pipeline
    | 'ReadExamples' >> beam.io.ReadFromTFRecord(
      os.path.join(input_dir, '*'),
      coder=beam.coders.coders.ProtoCoder(tf.train.Example))
    | 'AddUniqueId' >> beam.Map(make_example_with_unique_id, id_feature_name)
    | 'WriteIdentifiedExamples' >> beam.io.WriteToTFRecord(
      file_path_prefix=os.path.join(output_dir, 'data_tfrecord'),
      coder=beam.coders.coders.ProtoCoder(tf.train.Example),
      file_name_suffix='.gz'))

 # For completeness, encode the splits names and payload_format.
 # We could also just use orig_examples.split_names.
 identified_examples.split_names = artifact_utils.encode_split_names(
   splits=splits_list)
 # TODO(b/168616829): Remove populating payload_format after tfx 0.25.0.
 identified_examples.set_string_custom_property(
   "payload_format",
   orig_examples.get_string_custom_property("payload_format"))

 return


identify_examples = IdentifyExamples(
  orig_examples=example_gen.outputs['examples'],
  component_name=u'IdentifyExamples',
  id_feature_name=u'id')
context.run(identify_examples, enable_cache=False)
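
Stripped of the Beam pipeline and the `tf.train.Example` protos, the core of this component is small. The sketch below shows the same idea on plain dictionaries; `add_unique_id` and the toy `reviews` list are hypothetical and exist only for illustration.

```python
import uuid


def add_unique_id(example, id_feature_name="id"):
    """Returns a copy of `example` (a plain dict) with a fresh UUID4 string."""
    result = dict(example)  # Shallow copy so the input is left untouched.
    result[id_feature_name] = str(uuid.uuid4())
    return result


reviews = [
    {"text": "great film", "label": 1},
    {"text": "terrible plot", "label": 0},
    {"text": "great film", "label": 1},  # Duplicate content, distinct ID.
]
identified = [add_unique_id(review) for review in reviews]
for row in identified:
    print(row["id"], row["label"])
```

Because the IDs are random rather than derived from the content, two identical reviews still become two distinct nodes in the graph.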

The StatisticsGen component

The StatisticsGen component computes descriptive statistics for your dataset. The statistics that it generates can be visualized for review, and are used for example validation and to infer a schema.

Create a StatisticsGen component and run it.

# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
  examples=identify_examples.outputs["identified_examples"])
context.run(statistics_gen, enable_cache=True)
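
As a rough intuition for what "descriptive statistics" means here, the sketch below summarizes a single feature column in plain Python. It is far simpler than what StatisticsGen computes; the `describe_feature` helper, the chosen statistics, and the toy columns are assumptions made for illustration.

```python
from collections import Counter


def describe_feature(values):
    """Summarizes one feature column; `None` marks a missing value."""
    present = [v for v in values if v is not None]
    stats = {
        "num_examples": len(values),
        "num_missing": len(values) - len(present),
        "num_unique": len(set(present)),
    }
    if present and all(isinstance(v, str) for v in present):
        # String features: report the average length in characters.
        stats["avg_length"] = sum(len(v) for v in present) / len(present)
    elif present:
        # Numeric features: report the range and the most common value.
        stats["min"] = min(present)
        stats["max"] = max(present)
        stats["top_value"] = Counter(present).most_common(1)[0][0]
    return stats


# Hypothetical toy columns, not taken from the IMDB data.
labels = [0, 0, 1, 1, 0, None]
texts = ["good", "bad", None, "okay", "bad"]
label_stats = describe_feature(labels)
text_stats = describe_feature(texts)
print(label_stats)
print(text_stats)
```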

The SchemaGen component

The SchemaGen component generates a schema for your data based on the statistics from StatisticsGen. It tries to infer the data types of each of your features, and the ranges of legal values for categorical features.

Create a SchemaGen component and run it.

# Generates schema based on statistics files.
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])
context.run(schema_gen, enable_cache=True)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_data_validation/utils/stats_util.py:229: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

The generated artifact is just a schema.pbtxt file containing a text representation of a schema_pb2.Schema:

train_uri = schema_gen.outputs['schema'].get()[0].uri
schema_filename = os.path.join(train_uri, 'schema.pbtxt')
schema = tfx.utils.io_utils.parse_pbtxt_file(
  file_name=schema_filename, message=schema_pb2.Schema())

It can be visualized using tfdv.display_schema() (we will look at this in more detail in a subsequent lab):

tfdv.display_schema(schema)
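
The kind of inference SchemaGen performs (a type per feature, plus a domain of legal values for categorical features) can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: `infer_schema`, the `max_domain_size` cutoff, and the toy rows are hypothetical and do not reflect how TFDV infers a schema.

```python
def infer_schema(rows, max_domain_size=5):
    """Infers a type and, for low-cardinality strings, a value domain."""
    schema = {}
    feature_names = sorted({name for row in rows for name in row})
    for name in feature_names:
        values = [row[name] for row in rows if row.get(name) is not None]
        type_names = {type(value).__name__ for value in values}
        entry = {
            # A feature with no values or several types is left undecided.
            "type": type_names.pop() if len(type_names) == 1 else "unknown",
            "required": len(values) == len(rows),
        }
        distinct = set(values)
        if entry["type"] == "str" and len(distinct) <= max_domain_size:
            entry["domain"] = sorted(distinct)
        schema[name] = entry
    return schema


# Hypothetical toy rows; `genre` is not a feature of the IMDB dataset.
rows = [
    {"label": 1, "genre": "drama"},
    {"label": 0, "genre": "comedy"},
    {"label": 1, "genre": None},
]
inferred = infer_schema(rows)
for name, entry in inferred.items():
    print(name, entry)
```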

The ExampleValidator component

The ExampleValidator performs anomaly detection, based on the statistics from StatisticsGen and the schema from SchemaGen. It looks for problems such as missing values, values of the wrong type, or categorical values outside of the domain of acceptable values.

Create an ExampleValidator component and run it.

# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
  statistics=statistics_gen.outputs['statistics'],
  schema=schema_gen.outputs['schema'])
context.run(validate_stats, enable_cache=False)
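
The three kinds of problems mentioned above (missing values, wrong types, values outside the domain) can be illustrated with a small row-by-row check. Note that this is a deliberate simplification: ExampleValidator works from aggregate statistics rather than individual rows, and the `find_anomalies` helper, the hand-written `SCHEMA`, and the toy rows are hypothetical.

```python
def find_anomalies(rows, schema):
    """Returns (row_index, feature, reason) tuples for schema violations."""
    anomalies = []
    for index, row in enumerate(rows):
        for name, spec in schema.items():
            value = row.get(name)
            if value is None:
                if spec.get("required", False):
                    anomalies.append((index, name, "missing value"))
            elif not isinstance(value, spec["type"]):
                anomalies.append((index, name, "wrong type"))
            elif "domain" in spec and value not in spec["domain"]:
                anomalies.append((index, name, "value outside domain"))
    return anomalies


# Hand-written schema for this sketch: 0 = negative, 1 = positive.
SCHEMA = {
    "text": {"type": str, "required": True},
    "label": {"type": int, "required": True, "domain": {0, 1}},
}
rows = [
    {"text": "a fine film", "label": 1},
    {"text": None, "label": 0},
    {"text": "dull", "label": "positive"},
    {"text": "fine", "label": 7},
]
anomalies = find_anomalies(rows, SCHEMA)
for anomaly in anomalies:
    print(anomaly)
```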

The SynthesizeGraph component

Graph construction involves creating embeddings for text samples and then using a similarity function to compare the embeddings.

We will use pretrained Swivel embeddings to create embeddings in the tf.train.Example format for each sample in the input. We will store the resulting embeddings in the TFRecord format along with the sample's ID. This is important and will allow us to match sample embeddings with corresponding nodes in the graph later.

Once we have the sample embeddings, we will use them to build a similarity graph, i.e., nodes in this graph will correspond to samples and edges in this graph will correspond to similarity between pairs of nodes.

Neural Structured Learning provides a graph building library to build a graph based on sample embeddings. It uses cosine similarity as the similarity measure to compare embeddings and build edges between them. It also allows us to specify a similarity threshold, which can be used to discard dissimilar edges from the final graph. In the following example, using 0.99 as the similarity threshold, we end up with a graph that has 115,368 bi-directional edges.

swivel_url = 'https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1'
hub_layer = hub.KerasLayer(swivel_url, input_shape=[], dtype=tf.string)


def _bytes_feature(value):
 """Returns a bytes_list from a string / byte."""
 return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _float_feature(value):
 """Returns a float_list from a float / double."""
 return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_embedding_example(example):
 """Create tf.Example containing the sample's embedding and its ID."""
 sentence_embedding = hub_layer(tf.sparse.to_dense(example['text']))

 # Flatten the sentence embedding back to 1-D.
 sentence_embedding = tf.reshape(sentence_embedding, shape=[-1])

 feature_dict = {
   'id': _bytes_feature(tf.sparse.to_dense(example['id']).numpy()),
   'embedding': _float_feature(sentence_embedding.numpy().tolist())
 }

 return tf.train.Example(features=tf.train.Features(feature=feature_dict))


def create_dataset(uri):
 tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
 return tf.data.TFRecordDataset(tfrecord_filenames, compression_type='GZIP')


def create_embeddings(train_path, output_path):
 dataset = create_dataset(train_path)
 embeddings_path = os.path.join(output_path, 'embeddings.tfr')

 feature_map = {
   'label': tf.io.FixedLenFeature([], tf.int64),
   'id': tf.io.VarLenFeature(tf.string),
   'text': tf.io.VarLenFeature(tf.string)
 }

 with tf.io.TFRecordWriter(embeddings_path) as writer:
  for tfrecord in dataset:
   tensor_dict = tf.io.parse_single_example(tfrecord, feature_map)
   embedding_example = create_embedding_example(tensor_dict)
   writer.write(embedding_example.SerializeToString())


def build_graph(output_path, similarity_threshold):
 embeddings_path = os.path.join(output_path, 'embeddings.tfr')
 graph_path = os.path.join(output_path, 'graph.tfv')
 nsl.tools.build_graph([embeddings_path], graph_path, similarity_threshold)


"""Custom Artifact type"""


class SynthesizedGraph(tfx.types.artifact.Artifact):
 """Output artifact of the SynthesizeGraph component"""
 TYPE_NAME = 'SynthesizedGraphPath'
 PROPERTIES = {
   'span': standard_artifacts.SPAN_PROPERTY,
   'split_names': standard_artifacts.SPLIT_NAMES_PROPERTY,
 }


@component
def SynthesizeGraph(identified_examples: InputArtifact[Examples],
          synthesized_graph: OutputArtifact[SynthesizedGraph],
          similarity_threshold: Parameter[float],
          component_name: Parameter[str]) -> None:

 # Get a list of the splits in input_data
 splits_list = artifact_utils.decode_split_names(
   split_names=identified_examples.split_names)

 # We build a graph only based on the 'train' split which includes both
 # labeled and unlabeled examples.
 train_input_examples_uri = os.path.join(identified_examples.uri, 'train')
 output_graph_uri = os.path.join(synthesized_graph.uri, 'train')
 os.mkdir(output_graph_uri)

 print('Creating embeddings...')
 create_embeddings(train_input_examples_uri, output_graph_uri)

 print('Synthesizing graph...')
 build_graph(output_graph_uri, similarity_threshold)

 synthesized_graph.split_names = artifact_utils.encode_split_names(
   splits=['train'])

 return


synthesize_graph = SynthesizeGraph(
  identified_examples=identify_examples.outputs['identified_examples'],
  component_name=u'SynthesizeGraph',
  similarity_threshold=0.99)
context.run(synthesize_graph, enable_cache=False)
Creating embeddings...
Synthesizing graph...
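
The thresholded similarity graph described in this section can be illustrated with a brute-force sketch in plain Python. This is not the `nsl.tools.build_graph` implementation: the function names, the all-pairs comparison, the `>=` test against the threshold, and the toy 2-D embeddings are assumptions made for illustration.

```python
import itertools
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def build_similarity_graph(embeddings, similarity_threshold):
    """Returns TSV lines `source<TAB>target<TAB>weight`, one per direction."""
    lines = []
    for (id_a, emb_a), (id_b, emb_b) in itertools.combinations(
            sorted(embeddings.items()), 2):
        similarity = cosine_similarity(emb_a, emb_b)
        if similarity >= similarity_threshold:
            # Emit both directions so that the edge is bi-directional.
            lines.append("{}\t{}\t{:.6f}".format(id_a, id_b, similarity))
            lines.append("{}\t{}\t{:.6f}".format(id_b, id_a, similarity))
    return lines


# Toy embeddings keyed by sample ID: "a" and "b" point in nearly the same
# direction, while "c" is orthogonal to "a".
embeddings = {
    "a": [1.0, 0.0],
    "b": [0.99, 0.05],
    "c": [0.0, 1.0],
}
edges = build_similarity_graph(embeddings, similarity_threshold=0.99)
for edge in edges:
    print(edge)
```

Comparing every pair is quadratic in the number of samples, so this sketch is only suitable for small inputs.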

train_uri = synthesize_graph.outputs["synthesized_graph"].get()[0].uri
os.listdir(train_uri)
['train']
graph_path = os.path.join(train_uri, "train", "graph.tfv")
print("node 1\t\t\t\t\tnode 2\t\t\t\t\tsimilarity")
!head {graph_path}
print("...")
!tail {graph_path}
node 1         node 2         similarity
c54d7b6d-5522-4c7f-80e8-63aefb40518d  48dc5b8a-2941-4de3-a92c-9a6829821632  0.991918
48dc5b8a-2941-4de3-a92c-9a6829821632  c54d7b6d-5522-4c7f-80e8-63aefb40518d  0.991918
4be77993-5b51-40fc-9ebd-ea4185243e0f  352566d1-7ecc-4299-8226-7ce88160661d  0.991171
352566d1-7ecc-4299-8226-7ce88160661d  4be77993-5b51-40fc-9ebd-ea4185243e0f  0.991171
4be77993-5b51-40fc-9ebd-ea4185243e0f  f57a5e51-2960-493e-980d-395826c35ee0  0.992568
f57a5e51-2960-493e-980d-395826c35ee0  4be77993-5b51-40fc-9ebd-ea4185243e0f  0.992568
3630bfa5-2c97-47c4-acfd-bec08a96bc4a  00dc9419-28f2-4852-8ed1-604384254f8c  0.993089
00dc9419-28f2-4852-8ed1-604384254f8c  3630bfa5-2c97-47c4-acfd-bec08a96bc4a  0.993089
3630bfa5-2c97-47c4-acfd-bec08a96bc4a  21e41556-c9ad-4c5e-a580-e5f2772c1ba4  0.991987
21e41556-c9ad-4c5e-a580-e5f2772c1ba4  3630bfa5-2c97-47c4-acfd-bec08a96bc4a  0.991987
...
2f63416d-12e9-40d1-970b-ab978a6d1e93  c4d07e3b-b991-42ba-9848-57a489080bab  0.993670
c4d07e3b-b991-42ba-9848-57a489080bab  2f63416d-12e9-40d1-970b-ab978a6d1e93  0.993670
829c875d-66ef-43d2-ab35-51fc7448b61d  f8fb2876-afc4-4e4b-af80-2c432377191b  0.990820
f8fb2876-afc4-4e4b-af80-2c432377191b  829c875d-66ef-43d2-ab35-51fc7448b61d  0.990820
bf272722-d225-4640-9edd-ec7674fc5734  97b231c3-7952-4826-bb01-a2935991a4e7  0.991107
97b231c3-7952-4826-bb01-a2935991a4e7  bf272722-d225-4640-9edd-ec7674fc5734  0.991107
f8fb2876-afc4-4e4b-af80-2c432377191b  42c5d827-4c77-4519-b128-48543104770f  0.990005
42c5d827-4c77-4519-b128-48543104770f  f8fb2876-afc4-4e4b-af80-2c432377191b  0.990005
9bbbebb5-eb68-42d5-a0a3-700a6ac33a78  e74d91e7-170a-46a9-abf6-d176ef810e54  0.993868
e74d91e7-170a-46a9-abf6-d176ef810e54  9bbbebb5-eb68-42d5-a0a3-700a6ac33a78  0.993868

!wc -l {graph_path}
230736 /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/SynthesizeGraph/synthesized_graph/6/train/graph.tfv
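
Since every edge is written once per direction, the line count above is twice the number of bi-directional edges. The sketch below checks this on the first four lines of the output shown earlier and then applies the same arithmetic to the full line count. It assumes the three fields in `graph.tfv` are tab-separated; `parse_edge` and `count_undirected_edges` are hypothetical helper names.

```python
NODE_A = "c54d7b6d-5522-4c7f-80e8-63aefb40518d"
NODE_B = "48dc5b8a-2941-4de3-a92c-9a6829821632"
NODE_C = "4be77993-5b51-40fc-9ebd-ea4185243e0f"
NODE_D = "352566d1-7ecc-4299-8226-7ce88160661d"

# The first four edges printed by `head` above, rebuilt as TSV lines.
SAMPLE_LINES = [
    "\t".join([NODE_A, NODE_B, "0.991918"]),
    "\t".join([NODE_B, NODE_A, "0.991918"]),
    "\t".join([NODE_C, NODE_D, "0.991171"]),
    "\t".join([NODE_D, NODE_C, "0.991171"]),
]


def parse_edge(line):
    """Splits one graph line into (source, target, weight)."""
    source, target, weight = line.rstrip("\n").split("\t")
    return source, target, float(weight)


def count_undirected_edges(lines):
    """Counts node pairs, treating a->b and b->a as the same edge."""
    pairs = set()
    for line in lines:
        source, target, _ = parse_edge(line)
        pairs.add(frozenset((source, target)))
    return len(pairs)


sample_edges = count_undirected_edges(SAMPLE_LINES)
TOTAL_LINES = 230736  # Line count reported by `wc -l` above.
bidirectional_edges = TOTAL_LINES // 2
print(sample_edges, bidirectional_edges)
```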

The Transform component

The Transform component performs data transformations and feature engineering. The results include an input TensorFlow graph which is used during both training and serving to preprocess the data before training or inference. This graph becomes part of the SavedModel that is the result of model training. Since the same input graph is used for both training and serving, the preprocessing will always be the same, and only needs to be written once.

The Transform component requires more code than many other components because of the arbitrary complexity of the feature engineering that you may need for the data and/or model that you're working with. It requires code files to be available which define the processing needed.

Each sample will include the following three features:

 1. id: The node ID of the sample.
 2. text_xf: An int64 list containing word IDs.
 3. label_xf: A singleton int64 identifying the target class of the review: 0 = negative, 1 = positive.
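
To show how a list of tokens can become a list of int64 word IDs such as `text_xf`, here is a minimal sketch of a frequency-ranked vocabulary with extra buckets for out-of-vocabulary (OOV) words. It illustrates the general technique only: the helper names, the alphabetical tie-breaking, and the CRC32 hash are assumptions and are not what TensorFlow Transform uses internally.

```python
import zlib
from collections import Counter


def build_vocabulary(token_lists, top_k):
    """Keeps the `top_k` most frequent tokens; ID 0 is the most frequent."""
    counts = Counter(token for tokens in token_lists for token in tokens)
    # Sort by descending count, then alphabetically for a deterministic order.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return {token: index for index, (token, _) in enumerate(ranked[:top_k])}


def token_to_id(token, vocabulary, num_oov_buckets):
    """Maps known tokens to their rank and unknown tokens to a hash bucket."""
    if token in vocabulary:
        return vocabulary[token]
    bucket = zlib.crc32(token.encode("utf-8")) % num_oov_buckets
    return len(vocabulary) + bucket  # OOV IDs follow the vocabulary IDs.


corpus = [["the", "film", "was", "great"], ["the", "plot", "was", "thin"]]
vocabulary = build_vocabulary(corpus, top_k=3)
ids = [token_to_id(token, vocabulary, num_oov_buckets=2)
       for token in ["the", "was", "unseen"]]
print(vocabulary, ids)
```

With a vocabulary of size V and B OOV buckets, every ID falls in the range 0 to V + B - 1, which fixes the size of the embedding table a downstream model needs.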

Let's define a module containing the preprocessing_fn() function that we will pass to the Transform component:

_transform_module_file = 'imdb_transform.py'
%%writefile {_transform_module_file}

import tensorflow as tf

import tensorflow_transform as tft

SEQUENCE_LENGTH = 100
VOCAB_SIZE = 10000
OOV_SIZE = 100

def tokenize_reviews(reviews, sequence_length=SEQUENCE_LENGTH):
 reviews = tf.strings.lower(reviews)
 reviews = tf.strings.regex_replace(reviews, r" '| '|^'|'$", " ")
 reviews = tf.strings.regex_replace(reviews, "[^a-z' ]", " ")
 tokens = tf.strings.split(reviews)[:, :sequence_length]
 start_tokens = tf.fill([tf.shape(reviews)[0], 1], "<START>")
 end_tokens = tf.fill([tf.shape(reviews)[0], 1], "<END>")
 tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)
 tokens = tokens[:, :sequence_length]
 tokens = tokens.to_tensor(default_value="<PAD>")
 pad = sequence_length - tf.shape(tokens)[1]
 tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values="<PAD>")
 return tf.reshape(tokens, [-1, sequence_length])

def preprocessing_fn(inputs):
 """tf.transform's callback function for preprocessing inputs.

 Args:
  inputs: map from feature keys to raw not-yet-transformed features.

 Returns:
  Map from string feature key to transformed feature operations.
 """
 outputs = {}
 outputs["id"] = inputs["id"]
 tokens = tokenize_reviews(_fill_in_missing(inputs["text"], ''))
 outputs["text_xf"] = tft.compute_and_apply_vocabulary(
   tokens,
   top_k=VOCAB_SIZE,
   num_oov_buckets=OOV_SIZE)
 outputs["label_xf"] = _fill_in_missing(inputs["label"], -1)
 return outputs

def _fill_in_missing(x, default_value):
 """Replace missing values in a SparseTensor.

 Fills in missing values of `x` with the default_value.

 Args:
  x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
   in the second dimension.
  default_value: the value with which to replace the missing values.

 Returns:
  A rank 1 tensor where missing values of `x` have been filled in.
 """
 return tf.squeeze(
   tf.sparse.to_dense(
     tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
     default_value),
   axis=1)
Writing imdb_transform.py
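
To see what `tokenize_reviews` does to a single review without running TensorFlow, the following pure-Python analogue applies the same sequence of steps to one string. It is a sketch for building intuition, not a drop-in replacement: `tokenize_review` is a hypothetical name, it uses Python's `re` module rather than TensorFlow's string ops, and `sequence_length=8` is chosen only to keep the output short (the module above uses 100).

```python
import re


def tokenize_review(review, sequence_length=8):
    """Pure-Python analogue of `tokenize_reviews` for a single string."""
    review = review.lower()
    # Drop apostrophes at word boundaries but keep those inside words.
    review = re.sub(r" '| '|^'|'$", " ", review)
    # Replace everything except letters, apostrophes and spaces.
    review = re.sub(r"[^a-z' ]", " ", review)
    tokens = review.split()[:sequence_length]
    tokens = ["<START>"] + tokens + ["<END>"]
    tokens = tokens[:sequence_length]  # <END> is cut off for long reviews.
    return tokens + ["<PAD>"] * (sequence_length - len(tokens))


short_tokens = tokenize_review("Don't miss it!")
long_tokens = tokenize_review("a b c d e f g h i j")
print(short_tokens)
print(long_tokens)
```

Every review ends up with exactly `sequence_length` tokens: short reviews are padded with `<PAD>`, and long reviews are truncated, in which case the `<END>` marker is dropped.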

Create and run the Transform component, referring to the files that were created above.

# Performs transformations and feature engineering in training and serving.
transform = Transform(
  examples=identify_examples.outputs['identified_examples'],
  schema=schema_gen.outputs['schema'],
  # TODO(b/169218106): Remove transformed_examples kwargs after bugfix is released.
  transformed_examples=channel.Channel(
    type=standard_artifacts.Examples,
    artifacts=[standard_artifacts.Examples()]),
  module_file=_transform_module_file)
context.run(transform)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tfx/components/transform/executor.py:485: Schema (from tensorflow_transform.tf_metadata.dataset_schema) is deprecated and will be removed in a future version.
Instructions for updating:
Schema is a deprecated, use schema_utils.schema_from_feature_spec to create a `Schema`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_transform/tf_utils.py:218: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.

Warning:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[str, Union[NoneType, _Dataset]], Union[Dict[str, Dict[str, PCollection]], NoneType]] instead.
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[str, Union[NoneType, _Dataset]], Union[Dict[str, Dict[str, PCollection]], NoneType]] instead.

Warning:tensorflow:Tensorflow version (2.3.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transform_graph/7/.temp_path/tftransform_tmp/11b6d4f9f3844b359227a3c768c5608d/saved_model.pb
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transform_graph/7/.temp_path/tftransform_tmp/8f11b64eb9504bd2bd71067216fee1db/saved_model.pb
WARNING:tensorflow:Tensorflow version (2.3.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 

Warning:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>

Warning:tensorflow:Tensorflow version (2.3.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 

Warning:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transform_graph/7/.temp_path/tftransform_tmp/3e8a5a5dc9af40df94c4c20167ed200f/assets
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transform_graph/7/.temp_path/tftransform_tmp/3e8a5a5dc9af40df94c4c20167ed200f/saved_model.pb
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary"

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary"

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary"

INFO:tensorflow:Saver not created because there are no variables in the graph to restore

The Transform component has 2 types of outputs:

 • transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
 • transformed_examples represents the preprocessed training and evaluation data.
transform.outputs
{'transform_graph': Channel(
  type_name: TransformGraph
  artifacts: [Artifact(artifact: id: 7
type_id: 16
uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transform_graph/7"
custom_properties {
 key: "name"
 value {
  string_value: "transform_graph"
 }
}
custom_properties {
 key: "pipeline_name"
 value {
  string_value: "interactive-2020-10-15T09_26_00.686186"
 }
}
custom_properties {
 key: "producer_component"
 value {
  string_value: "Transform"
 }
}
custom_properties {
 key: "state"
 value {
  string_value: "published"
 }
}
, artifact_type: id: 16
name: "TransformGraph"
)]
), 'transformed_examples': Channel(
  type_name: Examples
  artifacts: [Artifact(artifact: id: 8
type_id: 5
uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7"
properties {
 key: "split_names"
 value {
  string_value: "[\"train\", \"eval\"]"
 }
}
custom_properties {
 key: "name"
 value {
  string_value: "transformed_examples"
 }
}
custom_properties {
 key: "pipeline_name"
 value {
  string_value: "interactive-2020-10-15T09_26_00.686186"
 }
}
custom_properties {
 key: "producer_component"
 value {
  string_value: "Transform"
 }
}
custom_properties {
 key: "state"
 value {
  string_value: "published"
 }
}
, artifact_type: id: 5
name: "Examples"
properties {
 key: "span"
 value: INT
}
properties {
 key: "split_names"
 value: STRING
}
properties {
 key: "version"
 value: INT
}
)]
)}

Take a peek at the transform_graph artifact: it points to a directory containing 3 subdirectories:

train_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(train_uri)
['transform_fn', 'transformed_metadata', 'metadata']

The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original data. The transformed_metadata subdirectory contains the schema of the preprocessed data.

Take a look at some of the transformed examples and check that they were indeed processed as intended.

def pprint_examples(artifact, n_examples=3):
 print("artifact:", artifact)
 uri = os.path.join(artifact.uri, "train")
 print("uri:", uri)
 tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
 print("tfrecord_filenames:", tfrecord_filenames)
 dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
 for tfrecord in dataset.take(n_examples):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example.FromString(serialized_example)
  pp.pprint(example)
pprint_examples(transform.outputs['transformed_examples'].get()[0])
artifact: Artifact(artifact: id: 8
type_id: 5
uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7"
properties {
 key: "split_names"
 value {
  string_value: "[\"train\", \"eval\"]"
 }
}
custom_properties {
 key: "name"
 value {
  string_value: "transformed_examples"
 }
}
custom_properties {
 key: "pipeline_name"
 value {
  string_value: "interactive-2020-10-15T09_26_00.686186"
 }
}
custom_properties {
 key: "producer_component"
 value {
  string_value: "Transform"
 }
}
custom_properties {
 key: "state"
 value {
  string_value: "published"
 }
}
, artifact_type: id: 5
name: "Examples"
properties {
 key: "span"
 value: INT
}
properties {
 key: "split_names"
 value: STRING
}
properties {
 key: "version"
 value: INT
}
)
uri: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7/train
tfrecord_filenames: ['/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7/train/transformed_examples-00000-of-00001.gz']
features {
 feature {
  key: "id"
  value {
   bytes_list {
    value: "08903146-1233-49d7-ac8e-ac126c0a8b14"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 8
    value: 14
    value: 32
    value: 338
    value: 310
    value: 15
    value: 95
    value: 27
    value: 10001
    value: 9
    value: 31
    value: 1173
    value: 3153
    value: 43
    value: 495
    value: 10060
    value: 214
    value: 26
    value: 71
    value: 142
    value: 19
    value: 8
    value: 204
    value: 339
    value: 27
    value: 74
    value: 181
    value: 238
    value: 9
    value: 440
    value: 67
    value: 74
    value: 71
    value: 94
    value: 100
    value: 22
    value: 5442
    value: 8
    value: 1573
    value: 607
    value: 530
    value: 8
    value: 15
    value: 6
    value: 32
    value: 378
    value: 6292
    value: 207
    value: 2276
    value: 388
    value: 0
    value: 84
    value: 1023
    value: 154
    value: 65
    value: 155
    value: 52
    value: 0
    value: 10080
    value: 7871
    value: 65
    value: 250
    value: 74
    value: 3202
    value: 20
    value: 10000
    value: 3720
    value: 10020
    value: 10008
    value: 1282
    value: 3862
    value: 3
    value: 53
    value: 3952
    value: 110
    value: 1879
    value: 17
    value: 3153
    value: 14
    value: 166
    value: 19
    value: 2
    value: 1023
    value: 1007
    value: 9405
    value: 9
    value: 2
    value: 15
    value: 12
    value: 14
    value: 4504
    value: 4
    value: 109
    value: 158
    value: 1202
    value: 7
    value: 174
    value: 505
    value: 12
   }
  }
 }
}

features {
 feature {
  key: "id"
  value {
   bytes_list {
    value: "71e3f765-3bfd-4754-92fb-c258c43f78dc"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 7
    value: 23
    value: 75
    value: 494
    value: 5
    value: 748
    value: 2155
    value: 307
    value: 91
    value: 19
    value: 8
    value: 6
    value: 499
    value: 763
    value: 5
    value: 2
    value: 1690
    value: 4
    value: 200
    value: 593
    value: 57
    value: 1244
    value: 120
    value: 2364
    value: 3
    value: 4407
    value: 21
    value: 0
    value: 10081
    value: 3
    value: 263
    value: 42
    value: 6947
    value: 2
    value: 169
    value: 185
    value: 21
    value: 8
    value: 5143
    value: 7
    value: 1339
    value: 2155
    value: 81
    value: 0
    value: 18
    value: 14
    value: 1468
    value: 0
    value: 86
    value: 986
    value: 14
    value: 2259
    value: 1790
    value: 562
    value: 3
    value: 284
    value: 200
    value: 401
    value: 5
    value: 668
    value: 19
    value: 17
    value: 58
    value: 1934
    value: 4
    value: 45
    value: 14
    value: 4212
    value: 113
    value: 43
    value: 135
    value: 7
    value: 753
    value: 7
    value: 224
    value: 23
    value: 1155
    value: 179
    value: 4
    value: 0
    value: 18
    value: 19
    value: 7
    value: 191
    value: 0
    value: 2047
    value: 4
    value: 10
    value: 3
    value: 283
    value: 42
    value: 401
    value: 5
    value: 668
    value: 4
    value: 90
    value: 234
    value: 10023
    value: 227
   }
  }
 }
}

features {
 feature {
  key: "id"
  value {
   bytes_list {
    value: "eaad5638-befe-4556-8ef8-1b5061aaab34"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 4577
    value: 7158
    value: 0
    value: 10047
    value: 3778
    value: 3346
    value: 9
    value: 2
    value: 758
    value: 1915
    value: 3
    value: 2280
    value: 1511
    value: 3
    value: 2003
    value: 10020
    value: 225
    value: 786
    value: 382
    value: 16
    value: 39
    value: 203
    value: 361
    value: 5
    value: 93
    value: 11
    value: 11
    value: 19
    value: 220
    value: 21
    value: 341
    value: 2
    value: 10000
    value: 966
    value: 0
    value: 77
    value: 4
    value: 6677
    value: 464
    value: 10071
    value: 5
    value: 10042
    value: 630
    value: 2
    value: 10044
    value: 404
    value: 2
    value: 10044
    value: 3
    value: 5
    value: 10008
    value: 0
    value: 1259
    value: 630
    value: 106
    value: 10042
    value: 6721
    value: 10
    value: 49
    value: 21
    value: 0
    value: 2071
    value: 20
    value: 1292
    value: 4
    value: 0
    value: 431
    value: 11
    value: 11
    value: 166
    value: 67
    value: 2342
    value: 5815
    value: 12
    value: 575
    value: 21
    value: 0
    value: 1691
    value: 537
    value: 4
    value: 0
    value: 3605
    value: 307
    value: 0
    value: 10054
    value: 1563
    value: 3115
    value: 467
    value: 4577
    value: 3
    value: 1069
    value: 1158
    value: 5
    value: 23
    value: 4279
    value: 6677
    value: 464
    value: 20
    value: 10004
   }
  }
 }
}


The GraphAugmentation Component

Since we have the sample features and the synthesized graph, we can generate the augmented training data for Neural Structured Learning. The NSL framework provides a library to combine the graph and the sample features to produce the final training data for graph regularization. The resulting training data will include the original sample features as well as the features of their corresponding neighbors.

In this tutorial, we consider undirected edges and use a maximum of 3 neighbors per sample to augment the training data with graph neighbors.
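
As a rough illustration of what "augmenting with graph neighbors" means, the sketch below merges neighbor features into each example using plain Python dictionaries. It is only a sketch under stated assumptions, not the NSL library's implementation: pack_neighbors is a hypothetical helper, every example is treated as a seed (whereas the component in this tutorial passes labeled and unlabeled examples separately), and keeping the heaviest edges first is an assumption. The feature names (NL_nbr_0_id, NL_nbr_0_weight, NL_num_nbrs) follow the pattern visible in the augmented examples printed in this tutorial.

```python
from collections import defaultdict


def pack_neighbors(examples, edges, max_nbrs, add_undirected_edges=True):
    """Attach up to max_nbrs neighbor features to every example.

    examples: dict mapping example id -> feature dict
    edges: iterable of (source_id, target_id, weight) tuples
    """
    adjacency = defaultdict(dict)
    for src, dst, weight in edges:
        adjacency[src][dst] = weight
        if add_undirected_edges:
            # Mirror the edge unless the reverse direction was given explicitly.
            adjacency[dst].setdefault(src, weight)

    packed = {}
    for ex_id, features in examples.items():
        merged = dict(features)
        # Assumption: prefer the heaviest edges; break ties by neighbor id.
        nbrs = sorted(adjacency[ex_id].items(), key=lambda kv: (-kv[1], kv[0]))
        nbrs = [(n, w) for n, w in nbrs if n in examples][:max_nbrs]
        for i, (nbr_id, weight) in enumerate(nbrs):
            for key, value in examples[nbr_id].items():
                merged[f"NL_nbr_{i}_{key}"] = value
            merged[f"NL_nbr_{i}_weight"] = weight
        merged["NL_num_nbrs"] = len(nbrs)
        packed[ex_id] = merged
    return packed


examples = {
    "a": {"id": "a", "label_xf": 1},
    "b": {"id": "b", "label_xf": -1},
    "c": {"id": "c", "label_xf": 0},
}
edges = [("a", "b", 0.99), ("a", "c", 0.5)]
packed = pack_neighbors(examples, edges, max_nbrs=1)
directed = pack_neighbors(examples, edges, max_nbrs=1, add_undirected_edges=False)
print(packed["a"])
```

With undirected edges enabled, example "b" also gains "a" as a neighbor even though the edge list only mentions the a-to-b direction; with the flag disabled, "b" ends up with no neighbors.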

def split_train_and_unsup(input_uri):
 'Separate the labeled and unlabeled instances.'

 tmp_dir = tempfile.mkdtemp(prefix='tfx-data')
 tfrecord_filenames = [
   os.path.join(input_uri, filename) for filename in os.listdir(input_uri)
 ]
 train_path = os.path.join(tmp_dir, 'train.tfrecord')
 unsup_path = os.path.join(tmp_dir, 'unsup.tfrecord')
 with tf.io.TFRecordWriter(train_path) as train_writer, \
    tf.io.TFRecordWriter(unsup_path) as unsup_writer:
  for tfrecord in tf.data.TFRecordDataset(
    tfrecord_filenames, compression_type='GZIP'):
   example = tf.train.Example()
   example.ParseFromString(tfrecord.numpy())
   if ('label_xf' not in example.features.feature or
     example.features.feature['label_xf'].int64_list.value[0] == -1):
    writer = unsup_writer
   else:
    writer = train_writer
   writer.write(tfrecord.numpy())
 return train_path, unsup_path


def gzip(filepath):
 with open(filepath, 'rb') as f_in:
  with gzip_lib.open(filepath + '.gz', 'wb') as f_out:
   shutil.copyfileobj(f_in, f_out)
 os.remove(filepath)


def copy_tfrecords(input_uri, output_uri):
 for filename in os.listdir(input_uri):
  input_filename = os.path.join(input_uri, filename)
  output_filename = os.path.join(output_uri, filename)
  shutil.copyfile(input_filename, output_filename)


@component
def GraphAugmentation(identified_examples: InputArtifact[Examples],
           synthesized_graph: InputArtifact[SynthesizedGraph],
           augmented_examples: OutputArtifact[Examples],
           num_neighbors: Parameter[int],
           component_name: Parameter[str]) -> None:

 # Get a list of the splits in input_data
 splits_list = artifact_utils.decode_split_names(
   split_names=identified_examples.split_names)

 train_input_uri = os.path.join(identified_examples.uri, 'train')
 eval_input_uri = os.path.join(identified_examples.uri, 'eval')
 train_graph_uri = os.path.join(synthesized_graph.uri, 'train')
 train_output_uri = os.path.join(augmented_examples.uri, 'train')
 eval_output_uri = os.path.join(augmented_examples.uri, 'eval')

 os.mkdir(train_output_uri)
 os.mkdir(eval_output_uri)

 # Separate out the labeled and unlabeled examples from the 'train' split.
 train_path, unsup_path = split_train_and_unsup(train_input_uri)

 output_path = os.path.join(train_output_uri, 'nsl_train_data.tfr')
 pack_nbrs_args = dict(
   labeled_examples_path=train_path,
   unlabeled_examples_path=unsup_path,
   graph_path=os.path.join(train_graph_uri, 'graph.tfv'),
   output_training_data_path=output_path,
   add_undirected_edges=True,
   max_nbrs=num_neighbors)
 print('nsl.tools.pack_nbrs arguments:', pack_nbrs_args)
 nsl.tools.pack_nbrs(**pack_nbrs_args)

 # Downstream components expect gzip'ed TFRecords.
 gzip(output_path)

 # The test examples are left untouched and are simply copied over.
 copy_tfrecords(eval_input_uri, eval_output_uri)

 augmented_examples.split_names = identified_examples.split_names

 return
# Augments training data with graph neighbors.
graph_augmentation = GraphAugmentation(
  identified_examples=transform.outputs['transformed_examples'],
  synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
  component_name=u'GraphAugmentation',
  num_neighbors=3)
context.run(graph_augmentation, enable_cache=False)
nsl.tools.pack_nbrs arguments: {'labeled_examples_path': '/tmp/tfx-datajre7hdjd/train.tfrecord', 'unlabeled_examples_path': '/tmp/tfx-datajre7hdjd/unsup.tfrecord', 'graph_path': '/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/SynthesizeGraph/synthesized_graph/6/train/graph.tfv', 'output_training_data_path': '/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8/train/nsl_train_data.tfr', 'add_undirected_edges': True, 'max_nbrs': 3}
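
The component compresses the packed file because downstream components expect gzip'ed TFRecords. The following standard-library sketch exercises the same compress-then-delete pattern on a throwaway file, so the round trip can be checked without running the pipeline. The name gzip_in_place and the payload bytes are illustrative assumptions; the payload is not a real TFRecord.

```python
import gzip as gzip_lib
import os
import shutil
import tempfile


def gzip_in_place(filepath):
    """Write filepath + '.gz', delete the uncompressed original, return the new path."""
    with open(filepath, "rb") as f_in, gzip_lib.open(filepath + ".gz", "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.remove(filepath)
    return filepath + ".gz"


tmp_dir = tempfile.mkdtemp(prefix="gzip-demo")
raw_path = os.path.join(tmp_dir, "nsl_train_data.tfr")
payload = b"stand-in bytes, not a real TFRecord"
with open(raw_path, "wb") as f:
    f.write(payload)

gz_path = gzip_in_place(raw_path)
with gzip_lib.open(gz_path, "rb") as f:
    restored = f.read()
print(os.path.basename(gz_path), restored == payload, os.path.exists(raw_path))
```

Only the compressed file remains afterwards, which matches the nsl_train_data.tfr.gz filename listed in the component output.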

pprint_examples(graph_augmentation.outputs['augmented_examples'].get()[0], 6)
artifact: Artifact(artifact: id: 9
type_id: 5
uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8"
properties {
 key: "split_names"
 value {
  string_value: "[\"train\", \"eval\"]"
 }
}
custom_properties {
 key: "name"
 value {
  string_value: "augmented_examples"
 }
}
custom_properties {
 key: "pipeline_name"
 value {
  string_value: "interactive-2020-10-15T09_26_00.686186"
 }
}
custom_properties {
 key: "producer_component"
 value {
  string_value: "GraphAugmentation"
 }
}
custom_properties {
 key: "state"
 value {
  string_value: "published"
 }
}
, artifact_type: id: 5
name: "Examples"
properties {
 key: "span"
 value: INT
}
properties {
 key: "split_names"
 value: STRING
}
properties {
 key: "version"
 value: INT
}
)
uri: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8/train
tfrecord_filenames: ['/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8/train/nsl_train_data.tfr.gz']
features {
 feature {
  key: "NL_num_nbrs"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "id"
  value {
   bytes_list {
    value: "08903146-1233-49d7-ac8e-ac126c0a8b14"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 8
    value: 14
    value: 32
    value: 338
    value: 310
    value: 15
    value: 95
    value: 27
    value: 10001
    value: 9
    value: 31
    value: 1173
    value: 3153
    value: 43
    value: 495
    value: 10060
    value: 214
    value: 26
    value: 71
    value: 142
    value: 19
    value: 8
    value: 204
    value: 339
    value: 27
    value: 74
    value: 181
    value: 238
    value: 9
    value: 440
    value: 67
    value: 74
    value: 71
    value: 94
    value: 100
    value: 22
    value: 5442
    value: 8
    value: 1573
    value: 607
    value: 530
    value: 8
    value: 15
    value: 6
    value: 32
    value: 378
    value: 6292
    value: 207
    value: 2276
    value: 388
    value: 0
    value: 84
    value: 1023
    value: 154
    value: 65
    value: 155
    value: 52
    value: 0
    value: 10080
    value: 7871
    value: 65
    value: 250
    value: 74
    value: 3202
    value: 20
    value: 10000
    value: 3720
    value: 10020
    value: 10008
    value: 1282
    value: 3862
    value: 3
    value: 53
    value: 3952
    value: 110
    value: 1879
    value: 17
    value: 3153
    value: 14
    value: 166
    value: 19
    value: 2
    value: 1023
    value: 1007
    value: 9405
    value: 9
    value: 2
    value: 15
    value: 12
    value: 14
    value: 4504
    value: 4
    value: 109
    value: 158
    value: 1202
    value: 7
    value: 174
    value: 505
    value: 12
   }
  }
 }
}

features {
 feature {
  key: "NL_num_nbrs"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "id"
  value {
   bytes_list {
    value: "71e3f765-3bfd-4754-92fb-c258c43f78dc"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 7
    value: 23
    value: 75
    value: 494
    value: 5
    value: 748
    value: 2155
    value: 307
    value: 91
    value: 19
    value: 8
    value: 6
    value: 499
    value: 763
    value: 5
    value: 2
    value: 1690
    value: 4
    value: 200
    value: 593
    value: 57
    value: 1244
    value: 120
    value: 2364
    value: 3
    value: 4407
    value: 21
    value: 0
    value: 10081
    value: 3
    value: 263
    value: 42
    value: 6947
    value: 2
    value: 169
    value: 185
    value: 21
    value: 8
    value: 5143
    value: 7
    value: 1339
    value: 2155
    value: 81
    value: 0
    value: 18
    value: 14
    value: 1468
    value: 0
    value: 86
    value: 986
    value: 14
    value: 2259
    value: 1790
    value: 562
    value: 3
    value: 284
    value: 200
    value: 401
    value: 5
    value: 668
    value: 19
    value: 17
    value: 58
    value: 1934
    value: 4
    value: 45
    value: 14
    value: 4212
    value: 113
    value: 43
    value: 135
    value: 7
    value: 753
    value: 7
    value: 224
    value: 23
    value: 1155
    value: 179
    value: 4
    value: 0
    value: 18
    value: 19
    value: 7
    value: 191
    value: 0
    value: 2047
    value: 4
    value: 10
    value: 3
    value: 283
    value: 42
    value: 401
    value: 5
    value: 668
    value: 4
    value: 90
    value: 234
    value: 10023
    value: 227
   }
  }
 }
}

features {
 feature {
  key: "NL_num_nbrs"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "id"
  value {
   bytes_list {
    value: "eaad5638-befe-4556-8ef8-1b5061aaab34"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 4577
    value: 7158
    value: 0
    value: 10047
    value: 3778
    value: 3346
    value: 9
    value: 2
    value: 758
    value: 1915
    value: 3
    value: 2280
    value: 1511
    value: 3
    value: 2003
    value: 10020
    value: 225
    value: 786
    value: 382
    value: 16
    value: 39
    value: 203
    value: 361
    value: 5
    value: 93
    value: 11
    value: 11
    value: 19
    value: 220
    value: 21
    value: 341
    value: 2
    value: 10000
    value: 966
    value: 0
    value: 77
    value: 4
    value: 6677
    value: 464
    value: 10071
    value: 5
    value: 10042
    value: 630
    value: 2
    value: 10044
    value: 404
    value: 2
    value: 10044
    value: 3
    value: 5
    value: 10008
    value: 0
    value: 1259
    value: 630
    value: 106
    value: 10042
    value: 6721
    value: 10
    value: 49
    value: 21
    value: 0
    value: 2071
    value: 20
    value: 1292
    value: 4
    value: 0
    value: 431
    value: 11
    value: 11
    value: 166
    value: 67
    value: 2342
    value: 5815
    value: 12
    value: 575
    value: 21
    value: 0
    value: 1691
    value: 537
    value: 4
    value: 0
    value: 3605
    value: 307
    value: 0
    value: 10054
    value: 1563
    value: 3115
    value: 467
    value: 4577
    value: 3
    value: 1069
    value: 1158
    value: 5
    value: 23
    value: 4279
    value: 6677
    value: 464
    value: 20
    value: 10004
   }
  }
 }
}

features {
 feature {
  key: "NL_num_nbrs"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "id"
  value {
   bytes_list {
    value: "11ff10a2-1ea4-4b10-ba91-2ba633b8abd4"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 1
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 8
    value: 6
    value: 0
    value: 251
    value: 4
    value: 18
    value: 20
    value: 2
    value: 6783
    value: 2295
    value: 2338
    value: 52
    value: 0
    value: 468
    value: 4
    value: 0
    value: 189
    value: 73
    value: 153
    value: 1294
    value: 17
    value: 90
    value: 234
    value: 935
    value: 16
    value: 25
    value: 10024
    value: 92
    value: 2
    value: 192
    value: 4218
    value: 3317
    value: 3
    value: 10098
    value: 20
    value: 2
    value: 356
    value: 4
    value: 565
    value: 334
    value: 382
    value: 36
    value: 6989
    value: 3
    value: 6065
    value: 2510
    value: 16
    value: 203
    value: 7264
    value: 2849
    value: 0
    value: 86
    value: 346
    value: 50
    value: 26
    value: 58
    value: 10020
    value: 5
    value: 1464
    value: 58
    value: 2081
    value: 2969
    value: 42
    value: 2
    value: 2364
    value: 3
    value: 1402
    value: 10062
    value: 138
    value: 147
    value: 614
    value: 115
    value: 29
    value: 90
    value: 105
    value: 2
    value: 223
    value: 18
    value: 9
    value: 160
    value: 324
    value: 3
    value: 24
    value: 12
    value: 1252
    value: 0
    value: 2142
    value: 10
    value: 1832
    value: 111
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
   }
  }
 }
}

features {
 feature {
  key: "NL_num_nbrs"
  value {
   int64_list {
    value: 0
   }
  }
 }
 feature {
  key: "id"
  value {
   bytes_list {
    value: "ed3db659-5524-4410-a5d5-d2bbd550a01f"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 1
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 16
    value: 423
    value: 23
    value: 1367
    value: 30
    value: 0
    value: 363
    value: 12
    value: 153
    value: 3174
    value: 9
    value: 8
    value: 18
    value: 26
    value: 667
    value: 338
    value: 1372
    value: 0
    value: 86
    value: 46
    value: 9200
    value: 282
    value: 0
    value: 10091
    value: 4
    value: 0
    value: 694
    value: 10028
    value: 52
    value: 362
    value: 26
    value: 202
    value: 39
    value: 216
    value: 5
    value: 27
    value: 5822
    value: 19
    value: 52
    value: 58
    value: 362
    value: 26
    value: 202
    value: 39
    value: 474
    value: 0
    value: 10029
    value: 4
    value: 2
    value: 243
    value: 143
    value: 386
    value: 3
    value: 0
    value: 386
    value: 579
    value: 2
    value: 132
    value: 57
    value: 725
    value: 88
    value: 140
    value: 30
    value: 27
    value: 33
    value: 1359
    value: 29
    value: 8
    value: 567
    value: 35
    value: 106
    value: 230
    value: 60
    value: 0
    value: 3041
    value: 5
    value: 7879
    value: 28
    value: 281
    value: 110
    value: 111
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
    value: 1
   }
  }
 }
}

features {
 feature {
  key: "NL_nbr_0_id"
  value {
   bytes_list {
    value: "daf1f061-ef48-4476-a047-b9022c372d4e"
   }
  }
 }
 feature {
  key: "NL_nbr_0_label_xf"
  value {
   int64_list {
    value: -1
   }
  }
 }
 feature {
  key: "NL_nbr_0_text_xf"
  value {
   int64_list {
    value: 13
    value: 7
    value: 174
    value: 2
    value: 1525
    value: 4
    value: 440
    value: 3
    value: 1260
    value: 91
    value: 108
    value: 19
    value: 10095
    value: 10004
    value: 40
    value: 2
    value: 169
    value: 4
    value: 4594
    value: 84
    value: 4
    value: 30
    value: 8
    value: 15
    value: 1063
    value: 9
    value: 54
    value: 966
    value: 31
    value: 926
    value: 757
    value: 104
    value: 3
    value: 757
    value: 86
    value: 986
    value: 0
    value: 68
    value: 4769
    value: 9
    value: 69
    value: 8
    value: 18
    value: 1252
    value: 0
    value: 375
    value: 31
    value: 103
    value: 1558
    value: 9
    value: 9
    value: 640
    value: 876
    value: 3
    value: 2551
    value: 24
    value: 1946
    value: 1097
    value: 8
    value: 15
    value: 5
    value: 2
    value: 2351
    value: 1779
    value: 19
    value: 7
    value: 95
    value: 118
    value: 4
    value: 109
    value: 2351
    value: 9899
    value: 12
    value: 23
    value: 4876
    value: 16
    value: 63
    value: 16
    value: 8
    value: 24
    value: 0
    value: 68
    value: 104
    value: 12
    value: 361
    value: 5
    value: 2257
    value: 9
    value: 2
    value: 1092
    value: 97
    value: 26
    value: 0
    value: 2114
    value: 10044
    value: 10025
    value: 3
    value: 28
    value: 343
    value: 6595
   }
  }
 }
 feature {
  key: "NL_nbr_0_weight"
  value {
   float_list {
    value: 0.9909949898719788
   }
  }
 }
 feature {
  key: "NL_num_nbrs"
  value {
   int64_list {
    value: 1
   }
  }
 }
 feature {
  key: "id"
  value {
   bytes_list {
    value: "c6c89c93-e2e9-4c4a-9f52-1221b4467499"
   }
  }
 }
 feature {
  key: "label_xf"
  value {
   int64_list {
    value: 1
   }
  }
 }
 feature {
  key: "text_xf"
  value {
   int64_list {
    value: 13
    value: 8
    value: 6
    value: 2
    value: 18
    value: 69
    value: 140
    value: 27
    value: 83
    value: 31
    value: 1877
    value: 905
    value: 9
    value: 10057
    value: 31
    value: 43
    value: 2115
    value: 36
    value: 32
    value: 2057
    value: 6133
    value: 10
    value: 6
    value: 32
    value: 2474
    value: 1614
    value: 3
    value: 2707
    value: 990
    value: 4
    value: 10067
    value: 9
    value: 2
    value: 1532
    value: 242
    value: 90
    value: 3757
    value: 3
    value: 90
    value: 10026
    value: 0
    value: 242
    value: 6
    value: 260
    value: 31
    value: 24
    value: 4
    value: 0
    value: 84
    value: 497
    value: 177
    value: 1151
    value: 777
    value: 9
    value: 397
    value: 552
    value: 7726
    value: 10051
    value: 34
    value: 14
    value: 379
    value: 33
    value: 1829
    value: 9
    value: 123
    value: 0
    value: 916
    value: 10028
    value: 7
    value: 64
    value: 571
    value: 12
    value: 8
    value: 18
    value: 27
    value: 687
    value: 9
    value: 30
    value: 5609
    value: 16
    value: 25
    value: 99
    value: 117
    value: 66
    value: 2
    value: 130
    value: 21
    value: 8
    value: 842
    value: 7726
    value: 10051
    value: 6
    value: 338
    value: 1107
    value: 3
    value: 24
    value: 10020
    value: 29
    value: 53
    value: 1476
   }
  }
 }
}
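The neighbor entries in the record above follow a simple naming pattern: a prefix (`NL_nbr_`), the neighbor index, and then either the original feature name or a `_weight` suffix. As a rough sketch, the keys for a given number of neighbors can be enumerated as follows. The helper name `neighbor_feature_keys` is hypothetical and is not part of NSL or TFX.

```python
# Illustrative only: mirrors the key-naming pattern visible in the record above.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'


def neighbor_feature_keys(feature_names, num_neighbors):
  """Returns the neighbor feature keys and weight key for each neighbor.

  `neighbor_feature_keys` is a hypothetical helper, not an NSL or TFX API.
  """
  keys = []
  for i in range(num_neighbors):
    # One key per copied feature, e.g. 'NL_nbr_0_text_xf'.
    for name in feature_names:
      keys.append('{}{}_{}'.format(NBR_FEATURE_PREFIX, i, name))
    # One edge-weight key per neighbor, e.g. 'NL_nbr_0_weight'.
    keys.append('{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX))
  return keys


keys = neighbor_feature_keys(['id', 'text_xf'], num_neighbors=1)
print(keys)
```

With one neighbor, this yields the `NL_nbr_0_...` keys that sit alongside the sample's own `id`, `label_xf`, and `text_xf` features.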


Trainer Component

The Trainer component trains models using TensorFlow.

Create a Python module containing a trainer_fn function, which must return an estimator. If you prefer to create a Keras model, you can do so and then convert it to an estimator using keras.model_to_estimator() .

# Setup paths.
_trainer_module_file = 'imdb_trainer.py'
%%writefile {_trainer_module_file}

import neural_structured_learning as nsl

import tensorflow as tf

import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils


NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
LABEL_KEY = 'label'
ID_FEATURE_KEY = 'id'

def _transformed_name(key):
 return key + '_xf'


def _transformed_names(keys):
 return [_transformed_name(key) for key in keys]


# Hyperparameters:
#
# We will use an instance of `HParams` to include various hyperparameters and
# constants used for training and evaluation. We briefly describe each of them
# below:
#
# -  max_seq_length: This is the maximum number of words considered from each
#           movie review in this example.
# -  vocab_size: This is the size of the vocabulary considered for this
#         example.
# -  oov_size: This is the out-of-vocabulary size considered for this example.
# -  distance_type: This is the distance metric used to regularize the sample
#          with its neighbors.
# -  graph_regularization_multiplier: This controls the relative weight of the
#                   graph regularization term in the overall
#                   loss function.
# -  num_neighbors: The number of neighbors used for graph regularization. This
#          value has to be less than or equal to the `num_neighbors`
#          argument used above in the GraphAugmentation component when
#          invoking `nsl.tools.pack_nbrs`.
# -  num_fc_units: The number of units in the fully connected layer of the
#          neural network. Note that `feed_forward_model` below
#          currently hardcodes 16 units and does not read this value.
class HParams(object):
 """Hyperparameters used for training."""
 def __init__(self):
  ### dataset parameters
  # The following 3 values should match those defined in the Transform
  # Component.
  self.max_seq_length = 100
  self.vocab_size = 10000
  self.oov_size = 100
  ### Neural Graph Learning parameters
  self.distance_type = nsl.configs.DistanceType.L2
  self.graph_regularization_multiplier = 0.1
  # The following value has to be at most the value of 'num_neighbors' used
  # in the GraphAugmentation component.
  self.num_neighbors = 1
  ### Model Architecture
  self.num_embedding_dims = 16
  self.num_fc_units = 64

HPARAMS = HParams()


def optimizer_fn():
 """Returns an instance of `tf.Optimizer`."""
 return tf.compat.v1.train.RMSPropOptimizer(
  learning_rate=0.0001, decay=1e-6)


def build_train_op(loss, global_step):
 """Builds a train op to optimize the given loss using gradient descent."""
 with tf.name_scope('train'):
  optimizer = optimizer_fn()
  train_op = optimizer.minimize(loss=loss, global_step=global_step)
 return train_op


# Building the model:
#
# A neural network is created by stacking layers—this requires two main
# architectural decisions:
# * How many layers to use in the model?
# * How many *hidden units* to use for each layer?
#
# In this example, the input data consists of an array of word-indices. The
# labels to predict are either 0 or 1. We will use a feed-forward neural network
# as our base model in this tutorial.
def feed_forward_model(features, is_training, reuse=tf.compat.v1.AUTO_REUSE):
 """Builds a simple 2 layer feed forward neural network.

 The layers are effectively stacked sequentially to build the classifier. The
 first layer is an Embedding layer, which takes the integer-encoded vocabulary
 and looks up the embedding vector for each word-index. These vectors are
 learned as the model trains. The vectors add a dimension to the output array.
 The resulting dimensions are: (batch, sequence, embedding). Next is a global
 average pooling 1D layer, which reduces the dimensionality of its inputs from
 3D to 2D. This fixed-length output vector is piped through a fully-connected
 (Dense) layer with 16 hidden units. The last layer is densely connected with a
 single output node. Using the sigmoid activation function, this value is a
 float between 0 and 1, representing a probability, or confidence level.

 Args:
  features: A dictionary containing batch features returned from the
   `input_fn`, that include sample features, corresponding neighbor features,
   and neighbor weights.
  is_training: a Python Boolean value or a Boolean scalar Tensor, indicating
   whether to apply dropout.
  reuse: a Python Boolean value for reusing variable scope.

 Returns:
  logits: Tensor of shape [batch_size, 1].
  representations: Tensor of shape [batch_size, _] for graph regularization.
   This is the representation of each example at the graph regularization
   layer.
 """

 with tf.compat.v1.variable_scope('ff', reuse=reuse):
  inputs = features[_transformed_name('text')]
  embeddings = tf.compat.v1.get_variable(
    'embeddings',
    shape=[
      HPARAMS.vocab_size + HPARAMS.oov_size, HPARAMS.num_embedding_dims
    ])
  embedding_layer = tf.nn.embedding_lookup(embeddings, inputs)

  pooling_layer = tf.compat.v1.layers.AveragePooling1D(
    pool_size=HPARAMS.max_seq_length, strides=HPARAMS.max_seq_length)(
      embedding_layer)
  # Shape of pooling_layer is now [batch_size, 1, HPARAMS.num_embedding_dims]
  pooling_layer = tf.reshape(pooling_layer, [-1, HPARAMS.num_embedding_dims])

  dense_layer = tf.compat.v1.layers.Dense(
    16, activation='relu')(
      pooling_layer)

  output_layer = tf.compat.v1.layers.Dense(
    1, activation='sigmoid')(
      dense_layer)

  # Graph regularization will be done on the penultimate (dense) layer
  # because the output layer is a single floating point number.
  return output_layer, dense_layer


# A note on hidden units:
#
# The above model has two intermediate or "hidden" layers, between the input and
# output, and excluding the Embedding layer. The number of outputs (units,
# nodes, or neurons) is the dimension of the representational space for the
# layer. In other words, the amount of freedom the network is allowed when
# learning an internal representation. If a model has more hidden units
# (a higher-dimensional representation space), and/or more layers, then the
# network can learn more complex representations. However, it makes the network
# more computationally expensive and may lead to learning unwanted
# patterns—patterns that improve performance on training data but not on the
# test data. This is called overfitting.


# This function will be used to generate the embeddings for samples and their
# corresponding neighbors, which will then be used for graph regularization.
def embedding_fn(features, mode):
 """Returns the embedding corresponding to the given features.

 Args:
  features: A dictionary containing batch features returned from the
   `input_fn`, that include sample features, corresponding neighbor features,
   and neighbor weights.
  mode: Specifies if this is training, evaluation, or prediction. See
   tf.estimator.ModeKeys.

 Returns:
  The embedding that will be used for graph regularization.
 """
 is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 _, embedding = feed_forward_model(features, is_training)
 return embedding


def feed_forward_model_fn(features, labels, mode, params, config):
 """Implementation of the model_fn for the base feed-forward model.

 Args:
  features: This is the first item returned from the `input_fn` passed to
   `train`, `evaluate`, and `predict`. This should be a single `Tensor` or
   `dict` of same.
  labels: This is the second item returned from the `input_fn` passed to
   `train`, `evaluate`, and `predict`. This should be a single `Tensor` or
   `dict` of same (for multi-head models). If mode is `ModeKeys.PREDICT`,
   `labels=None` will be passed. If the `model_fn`'s signature does not
   accept `mode`, the `model_fn` must still be able to handle `labels=None`.
  mode: Optional. Specifies if this is training, evaluation or prediction. See
   `ModeKeys`.
  params: An HParams instance as returned by get_hyper_parameters().
  config: Optional configuration object. Will receive what is passed to
   Estimator in `config` parameter, or the default `config`. Allows updating
   things in your model_fn based on configuration such as `num_ps_replicas`,
   or `model_dir`. Unused currently.

 Returns:
   A `tf.estimator.EstimatorSpec` for the base feed-forward model. This does
   not include graph-based regularization.
 """

 is_training = mode == tf.estimator.ModeKeys.TRAIN

 # Build the computation graph.
 probabilities, _ = feed_forward_model(features, is_training)
 predictions = tf.round(probabilities)

 if mode == tf.estimator.ModeKeys.PREDICT:
  # labels will be None, and no loss to compute.
  cross_entropy_loss = None
  eval_metric_ops = None
 else:
  # Loss is required in train and eval modes.
  # Flatten 'probabilities' to 1-D.
  probabilities = tf.reshape(probabilities, shape=[-1])
  cross_entropy_loss = tf.compat.v1.keras.losses.binary_crossentropy(
    labels, probabilities)
  eval_metric_ops = {
    'accuracy': tf.compat.v1.metrics.accuracy(labels, predictions)
  }

 if is_training:
  global_step = tf.compat.v1.train.get_or_create_global_step()
  train_op = build_train_op(cross_entropy_loss, global_step)
 else:
  train_op = None

 return tf.estimator.EstimatorSpec(
   mode=mode,
   predictions={
     'probabilities': probabilities,
     'predictions': predictions
   },
   loss=cross_entropy_loss,
   train_op=train_op,
   eval_metric_ops=eval_metric_ops)


# Tf.Transform considers these features as "raw"
def _get_raw_feature_spec(schema):
 return schema_utils.schema_as_feature_spec(schema).feature_spec


def _gzip_reader_fn(filenames):
 """Small utility returning a record reader that can read gzip'ed files."""
 return tf.data.TFRecordDataset(
   filenames,
   compression_type='GZIP')


def _example_serving_receiver_fn(tf_transform_output, schema):
 """Build the serving inputs.

 Args:
  tf_transform_output: A TFTransformOutput.
  schema: the schema of the input data.

 Returns:
  Tensorflow graph which parses examples, applying tf-transform to them.
 """
 raw_feature_spec = _get_raw_feature_spec(schema)
 raw_feature_spec.pop(LABEL_KEY)

 # We don't need the ID feature for serving.
 raw_feature_spec.pop(ID_FEATURE_KEY)

 raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
   raw_feature_spec, default_batch_size=None)
 serving_input_receiver = raw_input_fn()

 transformed_features = tf_transform_output.transform_raw_features(
   serving_input_receiver.features)

 # Even though LABEL_KEY was removed from 'raw_feature_spec', the transform
 # operation would have injected the transformed LABEL_KEY feature with a
 # default value.
 transformed_features.pop(_transformed_name(LABEL_KEY))
 return tf.estimator.export.ServingInputReceiver(
   transformed_features, serving_input_receiver.receiver_tensors)


def _eval_input_receiver_fn(tf_transform_output, schema):
 """Build everything needed for the tf-model-analysis to run the model.

 Args:
  tf_transform_output: A TFTransformOutput.
  schema: the schema of the input data.

 Returns:
  EvalInputReceiver function, which contains:

   - Tensorflow graph which parses raw untransformed features, applies the
    tf-transform preprocessing operators.
   - Set of raw, untransformed features.
   - Label against which predictions will be compared.
 """
 # Notice that the inputs are raw features, not transformed features here.
 raw_feature_spec = _get_raw_feature_spec(schema)

 # We don't need the ID feature for TFMA.
 raw_feature_spec.pop(ID_FEATURE_KEY)

 raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
   raw_feature_spec, default_batch_size=None)
 serving_input_receiver = raw_input_fn()

 transformed_features = tf_transform_output.transform_raw_features(
   serving_input_receiver.features)

 labels = transformed_features.pop(_transformed_name(LABEL_KEY))
 return tfma.export.EvalInputReceiver(
   features=transformed_features,
   receiver_tensors=serving_input_receiver.receiver_tensors,
   labels=labels)


def _augment_feature_spec(feature_spec, num_neighbors):
 """Augments `feature_spec` to include neighbor features.
  Args:
   feature_spec: Dictionary of feature keys mapping to TF feature types.
   num_neighbors: Number of neighbors to use for feature key augmentation.
  Returns:
   An augmented `feature_spec` that includes neighbor feature keys.
 """
 for i in range(num_neighbors):
  feature_spec['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'id')] = \
    tf.io.VarLenFeature(dtype=tf.string)
  # We don't care about the neighbor features corresponding to
  # _transformed_name(LABEL_KEY) because the LABEL_KEY feature will be
  # removed from the feature spec during training/evaluation.
  feature_spec['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'text_xf')] = \
    tf.io.FixedLenFeature(shape=[HPARAMS.max_seq_length], dtype=tf.int64,
               default_value=tf.constant(0, dtype=tf.int64,
                            shape=[HPARAMS.max_seq_length]))
  # The 'NL_num_nbrs' feature is currently not used.

 # Set the neighbor weight feature keys.
 for i in range(num_neighbors):
  feature_spec['{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)] = \
    tf.io.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=[0.0])

 return feature_spec


def _input_fn(filenames, tf_transform_output, is_training, batch_size=200):
 """Generates features and labels for training or evaluation.

 Args:
  filenames: [str] list of gzipped TFRecord files to read data from.
  tf_transform_output: A TFTransformOutput.
  is_training: Boolean indicating if we are in training mode.
  batch_size: int First dimension size of the Tensors returned by input_fn

 Returns:
  A (features, indices) tuple where features is a dictionary of
   Tensors, and indices is a single Tensor of label indices.
 """
 transformed_feature_spec = (
   tf_transform_output.transformed_feature_spec().copy())

 # During training, NSL uses augmented training data (which includes features
 # from graph neighbors). So, update the feature spec accordingly. This needs
 # to be done because we are using different schemas for NSL training and eval,
 # but the Trainer Component only accepts a single schema.
 if is_training:
  transformed_feature_spec = _augment_feature_spec(
    transformed_feature_spec, HPARAMS.num_neighbors)

 dataset = tf.data.experimental.make_batched_features_dataset(
   filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)

 transformed_features = tf.compat.v1.data.make_one_shot_iterator(
   dataset).get_next()
 # We pop the label because we do not want to use it as a feature while we're
 # training.
 return transformed_features, transformed_features.pop(
   _transformed_name(LABEL_KEY))


# TFX will call this function
def trainer_fn(hparams, schema):
 """Build the estimator using the high level API.
 Args:
  hparams: Holds hyperparameters used to train the model as name/value pairs.
  schema: Holds the schema of the training examples.
 Returns:
  A dict of the following:

   - estimator: The estimator that will be used for training and eval.
   - train_spec: Spec for training.
   - eval_spec: Spec for eval.
   - eval_input_receiver_fn: Input function for eval.
 """
 train_batch_size = 40
 eval_batch_size = 40

 tf_transform_output = tft.TFTransformOutput(hparams.transform_output)

 train_input_fn = lambda: _input_fn(
   hparams.train_files,
   tf_transform_output,
   is_training=True,
   batch_size=train_batch_size)

 eval_input_fn = lambda: _input_fn(
   hparams.eval_files,
   tf_transform_output,
   is_training=False,
   batch_size=eval_batch_size)

 train_spec = tf.estimator.TrainSpec(
   train_input_fn,
   max_steps=hparams.train_steps)

 serving_receiver_fn = lambda: _example_serving_receiver_fn(
   tf_transform_output, schema)

 exporter = tf.estimator.FinalExporter('imdb', serving_receiver_fn)
 eval_spec = tf.estimator.EvalSpec(
   eval_input_fn,
   steps=hparams.eval_steps,
   exporters=[exporter],
   name='imdb-eval')

 run_config = tf.estimator.RunConfig(
   save_checkpoints_steps=999, keep_checkpoint_max=1)

 run_config = run_config.replace(model_dir=hparams.serving_model_dir)

 estimator = tf.estimator.Estimator(
   model_fn=feed_forward_model_fn, config=run_config, params=HPARAMS)

 # Create a graph regularization config.
 graph_reg_config = nsl.configs.make_graph_reg_config(
   max_neighbors=HPARAMS.num_neighbors,
   multiplier=HPARAMS.graph_regularization_multiplier,
   distance_type=HPARAMS.distance_type,
   sum_over_axis=-1)

 # Invoke the Graph Regularization Estimator wrapper to incorporate
 # graph-based regularization for training.
 graph_nsl_estimator = nsl.estimator.add_graph_regularization(
   estimator,
   embedding_fn,
   optimizer_fn=optimizer_fn,
   graph_reg_config=graph_reg_config)

 # Create an input receiver for TFMA processing
 receiver_fn = lambda: _eval_input_receiver_fn(
   tf_transform_output, schema)

 return {
   'estimator': graph_nsl_estimator,
   'train_spec': train_spec,
   'eval_spec': eval_spec,
   'eval_input_receiver_fn': receiver_fn
 }
Writing imdb_trainer.py
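
To make the role of `graph_regularization_multiplier` and the neighbor weights more concrete, here is a simplified, pure-Python sketch of how a graph-regularized loss can be assembled for a single sample. It assumes a squared L2 distance between embeddings; the exact distance and reduction that NSL applies internally may differ, and the function names are hypothetical.

```python
def squared_l2_distance(a, b):
  """Sum of squared differences between two equal-length vectors."""
  return sum((x - y) ** 2 for x, y in zip(a, b))


def graph_regularized_loss(supervised_loss, sample_embedding,
                           neighbor_embeddings, neighbor_weights,
                           multiplier):
  """Adds a weighted neighbor-distance penalty to the supervised loss.

  Hypothetical helper for illustration; this is not the NSL implementation.
  """
  graph_loss = sum(
      w * squared_l2_distance(sample_embedding, nbr)
      for nbr, w in zip(neighbor_embeddings, neighbor_weights))
  return supervised_loss + multiplier * graph_loss


# One sample with a single neighbor, matching num_neighbors = 1 in HParams.
total = graph_regularized_loss(
    supervised_loss=0.5,
    sample_embedding=[1.0, 2.0],
    neighbor_embeddings=[[1.0, 4.0]],
    neighbor_weights=[0.5],
    multiplier=0.1)
print(total)
```

In this sketch, a padded neighbor whose weight is 0.0 (the default weight set in `_augment_feature_spec`) contributes nothing to the penalty, so samples without neighbors fall back to the plain supervised loss.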

Create and run the Trainer component, passing it the file we created above.

# Uses user-provided Python function that implements a model using TensorFlow's
# Estimators API.
trainer = Trainer(
  module_file=_trainer_module_file,
  transformed_examples=graph_augmentation.outputs['augmented_examples'],
  schema=schema_gen.outputs['schema'],
  transform_graph=transform.outputs['transform_graph'],
  train_args=trainer_pb2.TrainArgs(num_steps=10000),
  eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 999, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
 rewrite_options {
  meta_optimizer_iterations: ONE
 }
}
, '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 999 or save_checkpoints_secs None.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/rmsprop.py:123: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.69318736, step = 0
INFO:tensorflow:global_step/sec: 222.546
INFO:tensorflow:loss = 0.6928638, step = 100 (0.450 sec)
INFO:tensorflow:global_step/sec: 291.344
INFO:tensorflow:loss = 0.69281894, step = 200 (0.343 sec)
INFO:tensorflow:global_step/sec: 296.443
INFO:tensorflow:loss = 0.6927313, step = 300 (0.337 sec)
INFO:tensorflow:global_step/sec: 291.965
INFO:tensorflow:loss = 0.6917414, step = 400 (0.342 sec)
INFO:tensorflow:global_step/sec: 298.269
INFO:tensorflow:loss = 0.6905616, step = 500 (0.335 sec)
INFO:tensorflow:global_step/sec: 292.315
INFO:tensorflow:loss = 0.6894297, step = 600 (0.342 sec)
INFO:tensorflow:global_step/sec: 295.769
INFO:tensorflow:loss = 0.6896509, step = 700 (0.338 sec)
INFO:tensorflow:global_step/sec: 296.858
INFO:tensorflow:loss = 0.68861306, step = 800 (0.337 sec)
INFO:tensorflow:global_step/sec: 292.735
INFO:tensorflow:loss = 0.68658316, step = 900 (0.342 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 999...
INFO:tensorflow:Saving checkpoints for 999 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/saver.py:971: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 999...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T09:32:00Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-999
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [500/5000]
INFO:tensorflow:Evaluation [1000/5000]
INFO:tensorflow:Evaluation [1500/5000]
INFO:tensorflow:Evaluation [2000/5000]
INFO:tensorflow:Evaluation [2500/5000]
INFO:tensorflow:Evaluation [3000/5000]
INFO:tensorflow:Evaluation [3500/5000]
INFO:tensorflow:Evaluation [4000/5000]
INFO:tensorflow:Evaluation [4500/5000]
INFO:tensorflow:Evaluation [5000/5000]
INFO:tensorflow:Inference Time : 5.29909s
INFO:tensorflow:Finished evaluation at 2020-10-15-09:32:05
INFO:tensorflow:Saving dict for global step 999: accuracy = 0.7035, global_step = 999, loss = 0.68670774
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 999: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-999
INFO:tensorflow:global_step/sec: 17.0767
INFO:tensorflow:loss = 0.68894106, step = 1000 (5.855 sec)
INFO:tensorflow:global_step/sec: 299.602
INFO:tensorflow:loss = 0.6814944, step = 1100 (0.334 sec)
INFO:tensorflow:global_step/sec: 300.889
INFO:tensorflow:loss = 0.6839364, step = 1200 (0.333 sec)
INFO:tensorflow:global_step/sec: 302.256
INFO:tensorflow:loss = 0.6763433, step = 1300 (0.331 sec)
INFO:tensorflow:global_step/sec: 299.199
INFO:tensorflow:loss = 0.6769841, step = 1400 (0.334 sec)
INFO:tensorflow:global_step/sec: 299.279
INFO:tensorflow:loss = 0.67444175, step = 1500 (0.334 sec)
INFO:tensorflow:global_step/sec: 307.62
INFO:tensorflow:loss = 0.67098206, step = 1600 (0.325 sec)
INFO:tensorflow:global_step/sec: 304.262
INFO:tensorflow:loss = 0.665629, step = 1700 (0.329 sec)
INFO:tensorflow:global_step/sec: 297.873
INFO:tensorflow:loss = 0.6719124, step = 1800 (0.336 sec)
INFO:tensorflow:global_step/sec: 306.605
INFO:tensorflow:loss = 0.65660954, step = 1900 (0.326 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1998...
INFO:tensorflow:Saving checkpoints for 1998 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1998...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 254.265
INFO:tensorflow:loss = 0.6726355, step = 2000 (0.393 sec)
INFO:tensorflow:global_step/sec: 290.351
INFO:tensorflow:loss = 0.6551316, step = 2100 (0.345 sec)
INFO:tensorflow:global_step/sec: 298.852
INFO:tensorflow:loss = 0.67447, step = 2200 (0.335 sec)
INFO:tensorflow:global_step/sec: 295.696
INFO:tensorflow:loss = 0.64570725, step = 2300 (0.338 sec)
INFO:tensorflow:global_step/sec: 301.494
INFO:tensorflow:loss = 0.6464771, step = 2400 (0.332 sec)
INFO:tensorflow:global_step/sec: 304.472
INFO:tensorflow:loss = 0.6501285, step = 2500 (0.329 sec)
INFO:tensorflow:global_step/sec: 302.118
INFO:tensorflow:loss = 0.6361262, step = 2600 (0.331 sec)
INFO:tensorflow:global_step/sec: 307.043
INFO:tensorflow:loss = 0.64034796, step = 2700 (0.325 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 2748 vs previous value: 2748. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 298.63
INFO:tensorflow:loss = 0.62189335, step = 2800 (0.335 sec)
INFO:tensorflow:global_step/sec: 293.917
INFO:tensorflow:loss = 0.6147873, step = 2900 (0.340 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2997...
INFO:tensorflow:Saving checkpoints for 2997 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2997...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 254.499
INFO:tensorflow:loss = 0.61259216, step = 3000 (0.393 sec)
INFO:tensorflow:global_step/sec: 298.886
INFO:tensorflow:loss = 0.6229025, step = 3100 (0.335 sec)
INFO:tensorflow:global_step/sec: 305.197
INFO:tensorflow:loss = 0.60436034, step = 3200 (0.328 sec)
INFO:tensorflow:global_step/sec: 299.399
INFO:tensorflow:loss = 0.62933403, step = 3300 (0.334 sec)
INFO:tensorflow:global_step/sec: 301.028
INFO:tensorflow:loss = 0.60902774, step = 3400 (0.332 sec)
INFO:tensorflow:global_step/sec: 300.191
INFO:tensorflow:loss = 0.64181244, step = 3500 (0.333 sec)
INFO:tensorflow:global_step/sec: 290.434
INFO:tensorflow:loss = 0.57052743, step = 3600 (0.344 sec)
INFO:tensorflow:global_step/sec: 299.378
INFO:tensorflow:loss = 0.60267526, step = 3700 (0.334 sec)
INFO:tensorflow:global_step/sec: 307.013
INFO:tensorflow:loss = 0.6107319, step = 3800 (0.326 sec)
INFO:tensorflow:global_step/sec: 304.692
INFO:tensorflow:loss = 0.56591743, step = 3900 (0.328 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3996...
INFO:tensorflow:Saving checkpoints for 3996 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3996...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 255.208
INFO:tensorflow:loss = 0.56774515, step = 4000 (0.392 sec)
INFO:tensorflow:global_step/sec: 309.924
INFO:tensorflow:loss = 0.59160006, step = 4100 (0.323 sec)
INFO:tensorflow:global_step/sec: 306.066
INFO:tensorflow:loss = 0.5484713, step = 4200 (0.327 sec)
INFO:tensorflow:global_step/sec: 301.846
INFO:tensorflow:loss = 0.63335776, step = 4300 (0.332 sec)
INFO:tensorflow:global_step/sec: 299.014
INFO:tensorflow:loss = 0.5656133, step = 4400 (0.334 sec)
INFO:tensorflow:global_step/sec: 306.259
INFO:tensorflow:loss = 0.5533817, step = 4500 (0.326 sec)
INFO:tensorflow:global_step/sec: 300.019
INFO:tensorflow:loss = 0.56391084, step = 4600 (0.333 sec)
INFO:tensorflow:global_step/sec: 304.165
INFO:tensorflow:loss = 0.5910115, step = 4700 (0.329 sec)
INFO:tensorflow:global_step/sec: 295.489
INFO:tensorflow:loss = 0.5945301, step = 4800 (0.338 sec)
INFO:tensorflow:global_step/sec: 297.313
INFO:tensorflow:loss = 0.61218303, step = 4900 (0.336 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4995...
INFO:tensorflow:Saving checkpoints for 4995 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4995...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 260.352
INFO:tensorflow:loss = 0.5332743, step = 5000 (0.385 sec)
INFO:tensorflow:global_step/sec: 304.608
INFO:tensorflow:loss = 0.56679493, step = 5100 (0.328 sec)
INFO:tensorflow:global_step/sec: 311.855
INFO:tensorflow:loss = 0.54229665, step = 5200 (0.321 sec)
INFO:tensorflow:global_step/sec: 305.253
INFO:tensorflow:loss = 0.52315617, step = 5300 (0.328 sec)
INFO:tensorflow:global_step/sec: 299.658
INFO:tensorflow:loss = 0.5793217, step = 5400 (0.334 sec)
INFO:tensorflow:global_step/sec: 304.107
INFO:tensorflow:loss = 0.5486561, step = 5500 (0.329 sec)
INFO:tensorflow:global_step/sec: 308.079
INFO:tensorflow:loss = 0.49263632, step = 5600 (0.325 sec)
INFO:tensorflow:global_step/sec: 313.378
INFO:tensorflow:loss = 0.5385544, step = 5700 (0.319 sec)
INFO:tensorflow:global_step/sec: 302.781
INFO:tensorflow:loss = 0.5010498, step = 5800 (0.330 sec)
INFO:tensorflow:global_step/sec: 296.805
INFO:tensorflow:loss = 0.47667298, step = 5900 (0.337 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5994...
INFO:tensorflow:Saving checkpoints for 5994 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5994...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 257.488
INFO:tensorflow:loss = 0.55798185, step = 6000 (0.388 sec)
INFO:tensorflow:global_step/sec: 305.947
INFO:tensorflow:loss = 0.43396345, step = 6100 (0.327 sec)
INFO:tensorflow:global_step/sec: 296.299
INFO:tensorflow:loss = 0.43670568, step = 6200 (0.338 sec)
INFO:tensorflow:global_step/sec: 303.445
INFO:tensorflow:loss = 0.46067405, step = 6300 (0.330 sec)
INFO:tensorflow:global_step/sec: 310.182
INFO:tensorflow:loss = 0.5060933, step = 6400 (0.322 sec)
INFO:tensorflow:global_step/sec: 298.273
INFO:tensorflow:loss = 0.4996158, step = 6500 (0.335 sec)
INFO:tensorflow:global_step/sec: 300.567
INFO:tensorflow:loss = 0.396133, step = 6600 (0.333 sec)
INFO:tensorflow:global_step/sec: 297.986
INFO:tensorflow:loss = 0.42002386, step = 6700 (0.336 sec)
INFO:tensorflow:global_step/sec: 304.359
INFO:tensorflow:loss = 0.4611571, step = 6800 (0.328 sec)
INFO:tensorflow:global_step/sec: 302.25
INFO:tensorflow:loss = 0.44177708, step = 6900 (0.331 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6993...
INFO:tensorflow:Saving checkpoints for 6993 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6993...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 256.018
INFO:tensorflow:loss = 0.46849436, step = 7000 (0.390 sec)
INFO:tensorflow:global_step/sec: 291.076
INFO:tensorflow:loss = 0.41983128, step = 7100 (0.344 sec)
INFO:tensorflow:global_step/sec: 296.444
INFO:tensorflow:loss = 0.35345578, step = 7200 (0.337 sec)
INFO:tensorflow:global_step/sec: 293.356
INFO:tensorflow:loss = 0.41871148, step = 7300 (0.341 sec)
INFO:tensorflow:global_step/sec: 303.596
INFO:tensorflow:loss = 0.47682336, step = 7400 (0.329 sec)
INFO:tensorflow:global_step/sec: 303.782
INFO:tensorflow:loss = 0.55223024, step = 7500 (0.329 sec)
INFO:tensorflow:global_step/sec: 296.762
INFO:tensorflow:loss = 0.42545128, step = 7600 (0.337 sec)
INFO:tensorflow:global_step/sec: 309.21
INFO:tensorflow:loss = 0.43023503, step = 7700 (0.323 sec)
INFO:tensorflow:global_step/sec: 306.462
INFO:tensorflow:loss = 0.5604722, step = 7800 (0.326 sec)
INFO:tensorflow:global_step/sec: 303.329
INFO:tensorflow:loss = 0.5337108, step = 7900 (0.330 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7992...
INFO:tensorflow:Saving checkpoints for 7992 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7992...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 255.136
INFO:tensorflow:loss = 0.4013764, step = 8000 (0.392 sec)
INFO:tensorflow:global_step/sec: 301.602
INFO:tensorflow:loss = 0.4093078, step = 8100 (0.332 sec)
INFO:tensorflow:global_step/sec: 299.313
INFO:tensorflow:loss = 0.41223857, step = 8200 (0.334 sec)
INFO:tensorflow:global_step/sec: 296.211
INFO:tensorflow:loss = 0.4117222, step = 8300 (0.338 sec)
INFO:tensorflow:global_step/sec: 299.752
INFO:tensorflow:loss = 0.39056668, step = 8400 (0.334 sec)
INFO:tensorflow:global_step/sec: 302.187
INFO:tensorflow:loss = 0.391355, step = 8500 (0.331 sec)
INFO:tensorflow:global_step/sec: 295.599
INFO:tensorflow:loss = 0.46732607, step = 8600 (0.338 sec)
INFO:tensorflow:global_step/sec: 297.524
INFO:tensorflow:loss = 0.44837368, step = 8700 (0.336 sec)
INFO:tensorflow:global_step/sec: 298.751
INFO:tensorflow:loss = 0.5095719, step = 8800 (0.335 sec)
INFO:tensorflow:global_step/sec: 300.476
INFO:tensorflow:loss = 0.3573585, step = 8900 (0.333 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8991...
INFO:tensorflow:Saving checkpoints for 8991 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8991...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 259.517
INFO:tensorflow:loss = 0.38418576, step = 9000 (0.385 sec)
INFO:tensorflow:global_step/sec: 300.71
INFO:tensorflow:loss = 0.3826803, step = 9100 (0.333 sec)
INFO:tensorflow:global_step/sec: 301.991
INFO:tensorflow:loss = 0.36049247, step = 9200 (0.331 sec)
INFO:tensorflow:global_step/sec: 298.252
INFO:tensorflow:loss = 0.31363297, step = 9300 (0.335 sec)
INFO:tensorflow:global_step/sec: 297.207
INFO:tensorflow:loss = 0.3982248, step = 9400 (0.337 sec)
INFO:tensorflow:global_step/sec: 301.999
INFO:tensorflow:loss = 0.34949106, step = 9500 (0.331 sec)
INFO:tensorflow:global_step/sec: 301.815
INFO:tensorflow:loss = 0.40354735, step = 9600 (0.331 sec)
INFO:tensorflow:global_step/sec: 300.948
INFO:tensorflow:loss = 0.47522005, step = 9700 (0.333 sec)
INFO:tensorflow:global_step/sec: 299.78
INFO:tensorflow:loss = 0.4353662, step = 9800 (0.333 sec)
INFO:tensorflow:global_step/sec: 300.752
INFO:tensorflow:loss = 0.45311904, step = 9900 (0.333 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9990...
INFO:tensorflow:Saving checkpoints for 9990 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9990...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10000...
INFO:tensorflow:Saving checkpoints for 10000 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10000...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T09:32:36Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [500/5000]
INFO:tensorflow:Evaluation [1000/5000]
INFO:tensorflow:Evaluation [1500/5000]
INFO:tensorflow:Evaluation [2000/5000]
INFO:tensorflow:Evaluation [2500/5000]
INFO:tensorflow:Evaluation [3000/5000]
INFO:tensorflow:Evaluation [3500/5000]
INFO:tensorflow:Evaluation [4000/5000]
INFO:tensorflow:Evaluation [4500/5000]
INFO:tensorflow:Evaluation [5000/5000]
INFO:tensorflow:Inference Time : 5.22927s
INFO:tensorflow:Finished evaluation at 2020-10-15-09:32:41
INFO:tensorflow:Saving dict for global step 10000: accuracy = 0.8, global_step = 10000, loss = 0.4427957
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10000: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000
INFO:tensorflow:Performing the final export in the end of training.
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary"

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/export/imdb/temp-1602754361/assets
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/export/imdb/temp-1602754361/saved_model.pb
INFO:tensorflow:Loss for final step: 0.4515194.
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary"

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']
WARNING:tensorflow:Export includes no default signature!
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/eval_model_dir/temp-1602754361/assets
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/eval_model_dir/temp-1602754361/saved_model.pb

WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/serving_model_dir/saved_model.pb"
WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/eval_model_dir/saved_model.pb"

Take a look at the trained model that was exported from Trainer.

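# Path of the serving model directory inside the Trainer's 'model' output artifact.
train_uri = trainer.outputs['model'].get()[0].uri
serving_model_path = os.path.join(train_uri, 'serving_model_dir')
# Load the exported SavedModel and list the first ten operations in its graph.
exported_model = tf.saved_model.load(serving_model_path)
exported_model.graph.get_operations()[:10] + ["..."]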
[<tf.Operation 'global_step/Initializer/zeros' type=Const>,
 <tf.Operation 'global_step' type=VarHandleOp>,
 <tf.Operation 'global_step/IsInitialized/VarIsInitializedOp' type=VarIsInitializedOp>,
 <tf.Operation 'global_step/Assign' type=AssignVariableOp>,
 <tf.Operation 'global_step/Read/ReadVariableOp' type=ReadVariableOp>,
 <tf.Operation 'input_example_tensor' type=Placeholder>,
 <tf.Operation 'ParseExample/ParseExampleV2/names' type=Const>,
 <tf.Operation 'ParseExample/ParseExampleV2/sparse_keys' type=Const>,
 <tf.Operation 'ParseExample/ParseExampleV2/dense_keys' type=Const>,
 <tf.Operation 'ParseExample/ParseExampleV2/ragged_keys' type=Const>,
 '...']

Let's visualize the model's metrics using TensorBoard.


# Get the URI of the output artifact representing the training logs,
# which is a directory
model_run_dir = trainer.outputs['model_run'].get()[0].uri

%load_ext tensorboard
%tensorboard --logdir {model_run_dir}

Model Serving

Graph regularization affects only the training workflow, by adding a regularization term to the loss function. As a result, the model evaluation and serving workflows remain unchanged. For the same reason, we have also omitted the downstream TFX components that typically come after the Trainer component, such as Evaluator, Pusher, etc.
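
To illustrate why only training is affected, here is a minimal pure-Python sketch of a graph-regularized loss for a single example. It is not NSL's actual implementation: the function names, the squared-L2 distance, and the default multiplier of 0.1 are assumptions chosen for illustration. The point it shows is that the neighbor term is added only to the training objective, while the prediction itself never depends on neighbors.

```python
import math


def binary_cross_entropy(label, prob):
    """Supervised loss for one example; label is 0 or 1, prob is P(label == 1)."""
    eps = 1e-7  # clip to avoid log(0)
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def squared_l2(u, v):
    """Squared L2 distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(label, prob, embedding, neighbors, multiplier=0.1):
    """Supervised loss plus a weighted neighbor-distance penalty.

    neighbors is a list of (neighbor_embedding, edge_weight) pairs.
    All names and the default multiplier are hypothetical.
    """
    supervised = binary_cross_entropy(label, prob)
    if not neighbors:
        # Without neighbors (e.g. at evaluation or serving time),
        # only the supervised term remains.
        return supervised
    graph_term = sum(
        weight * squared_l2(embedding, neighbor)
        for neighbor, weight in neighbors
    ) / len(neighbors)
    return supervised + multiplier * graph_term


# Same example, scored without and with two graph neighbors.
loss_plain = graph_regularized_loss(1, 0.8, [0.2, 0.4], [])
loss_graph = graph_regularized_loss(
    1, 0.8, [0.2, 0.4], [([0.1, 0.5], 1.0), ([0.9, 0.4], 0.5)]
)
print(loss_plain, loss_graph)
```

In this sketch the second loss is larger than the first by the multiplier times the average weighted neighbor distance, which is the extra signal that pushes the embeddings of similar reviews toward each other during training.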

Conclusion

We have demonstrated the use of graph regularization with the Neural Structured Learning (NSL) framework in a TFX pipeline, even when the input does not contain an explicit graph. We considered the task of sentiment classification of IMDB movie reviews, for which we synthesized a similarity graph based on review embeddings. We encourage users to experiment further by using different embeddings for graph construction, varying the hyperparameters, changing the amount of supervision, and defining different model architectures.