
Graph-based Neural Structured Learning in TFX

This tutorial describes graph regularization from the Neural Structured Learning framework and demonstrates an end-to-end workflow for sentiment classification in a TFX pipeline.

Overview

This notebook classifies movie reviews as positive or negative using the text of the review. This is an example of binary classification, an important and widely applicable kind of machine learning problem.

We will demonstrate the use of graph regularization in this notebook by building a graph from the given input. The general recipe for building a graph-regularized model using the Neural Structured Learning (NSL) framework when the input does not contain an explicit graph is as follows:

  1. Create embeddings for each text sample in the input. This can be done using pre-trained models such as word2vec, Swivel, or BERT.
  2. Build a graph based on these embeddings by using a similarity metric such as the "L2" distance, "cosine" distance, etc. Nodes in the graph correspond to samples and edges in the graph correspond to the similarity between pairs of samples.
  3. Generate training data from the above synthesized graph and sample features. The resulting training data will contain neighbor features in addition to the original node features.
  4. Create a neural network as a base model using Estimators.
  5. Wrap the base model with the add_graph_regularization wrapper function, which is provided by the NSL framework, to create a new graph Estimator model. This new model will include a graph regularization loss as the regularization term in its training objective.
  6. Train and evaluate the graph Estimator model.
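Steps 1-3 of the recipe above can be illustrated in isolation. The following is a minimal sketch (not the NSL implementation, which the pipeline below uses via `nsl.tools.build_graph_from_config`): it compares toy embedding vectors with cosine similarity and keeps an edge for every pair above a similarity threshold.

```python
import numpy as np


def cosine_similarity(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def build_similarity_graph(embeddings, threshold):
    """Returns (i, j, similarity) edges for all pairs above `threshold`."""
    edges = []
    for i in range(len(embeddings)):
        for j in range(i + 1, len(embeddings)):
            sim = cosine_similarity(embeddings[i], embeddings[j])
            if sim >= threshold:
                edges.append((i, j, sim))
    return edges


# Toy embeddings: samples 0 and 1 point in nearly the same direction,
# while sample 2 is orthogonal to both, so only the (0, 1) edge survives.
embeddings = [
    np.array([1.0, 0.0, 0.1]),
    np.array([0.9, 0.0, 0.1]),
    np.array([0.0, 1.0, 0.0]),
]
edges = build_similarity_graph(embeddings, threshold=0.99)
print(edges)  # a single edge connecting samples 0 and 1
```

In a real pipeline the embeddings would come from a pretrained encoder such as Swivel, and the NSL graph builder uses locality-sensitive hashing to avoid this sketch's O(n²) pairwise comparison.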

In this tutorial, we integrate the above workflow into a TFX pipeline using several custom TFX components as well as a custom graph-regularized trainer component.

Below is a schematic of the TFX pipeline. Orange boxes represent off-the-shelf TFX components and pink boxes represent custom TFX components.

TFX Pipeline

Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we're running in Colab. Local systems can of course be upgraded separately.

try:
  import colab
  !pip install --upgrade pip
except:
  pass

Install required packages

!pip install -q -U \
  tfx==0.30.0 \
  neural-structured-learning \
  tensorflow-hub \
  tensorflow-datasets

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.

Dependencies and imports

import apache_beam as beam
import gzip as gzip_lib
import numpy as np
import os
import pprint
import shutil
import tempfile
import urllib
import uuid
pp = pprint.PrettyPrinter()

import tensorflow as tf
import neural_structured_learning as nsl

import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.import_example_gen.component import ImportExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer import executor as trainer_executor
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.dsl.components.base import executor_spec
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2

from tfx.types import artifact
from tfx.types import artifact_utils
from tfx.types import channel
from tfx.types import standard_artifacts
from tfx.types.standard_artifacts import Examples

from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component

from tensorflow_metadata.proto.v0 import anomalies_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2

import tensorflow_data_validation as tfdv
import tensorflow_transform as tft
import tensorflow_model_analysis as tfma
import tensorflow_hub as hub
import tensorflow_datasets as tfds

print("TF Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
print("NSL Version: ", nsl.__version__)
print("TFX Version: ", tfx.__version__)
print("TFDV version: ", tfdv.__version__)
print("TFT version: ", tft.__version__)
print("TFMA version: ", tfma.__version__)
print("Hub version: ", hub.__version__)
print("Beam version: ", beam.__version__)
TF Version:  2.4.1
Eager mode:  True
GPU is NOT AVAILABLE
NSL Version:  1.3.1
TFX Version:  0.30.0
TFDV version:  0.30.0
TFT version:  0.30.0
TFMA version:  0.30.0
Hub version:  0.9.0
Beam version:  2.29.0

The IMDB dataset

The IMDB dataset contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews. Moreover, there are 50,000 additional unlabeled movie reviews.

Download the preprocessed IMDB dataset

The following code downloads the IMDB dataset using TFDS (or uses a cached copy if it has already been downloaded). To speed up this notebook, we will use only 10,000 labeled reviews and 10,000 unlabeled reviews for training, and 10,000 test reviews for evaluation.

train_set, eval_set = tfds.load(
    "imdb_reviews:1.0.0",
    split=["train[:10000]+unsupervised[:10000]", "test[:10000]"],
    shuffle_files=False)

Let's look at a few reviews from the training set:

for tfrecord in train_set.take(4):
  print("Review: {}".format(tfrecord["text"].numpy().decode("utf-8")[:300]))
  print("Label: {}\n".format(tfrecord["label"].numpy()))
Review: This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda pi
Label: 0

Review: I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Cons
Label: 0

Review: Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to e
Label: 0

Review: This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cr
Label: 1
def _dict_to_example(instance):
  """Decoded CSV to tf example."""
  feature = {}
  for key, value in instance.items():
    if value is None:
      feature[key] = tf.train.Feature()
    elif np.issubdtype(value.dtype, np.integer):
      feature[key] = tf.train.Feature(
          int64_list=tf.train.Int64List(value=value.tolist()))
    elif value.dtype == np.float32:
      feature[key] = tf.train.Feature(
          float_list=tf.train.FloatList(value=value.tolist()))
    else:
      feature[key] = tf.train.Feature(
          bytes_list=tf.train.BytesList(value=value.tolist()))
  return tf.train.Example(features=tf.train.Features(feature=feature))


examples_path = tempfile.mkdtemp(prefix="tfx-data")
train_path = os.path.join(examples_path, "train.tfrecord")
eval_path = os.path.join(examples_path, "eval.tfrecord")

for path, dataset in [(train_path, train_set), (eval_path, eval_set)]:
  with tf.io.TFRecordWriter(path) as writer:
    for example in dataset:
      writer.write(
          _dict_to_example({
              "label": np.array([example["label"].numpy()]),
              "text": np.array([example["text"].numpy()]),
          }).SerializeToString())

Run TFX components interactively

In the cells that follow, we create TFX components and run each one interactively within an InteractiveContext to obtain an ExecutionResult object. This mirrors the process of an orchestrator running components in a TFX DAG as each component's dependencies are satisfied.

context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/metadata.sqlite.

The ExampleGen component

In any ML development process, the first step when starting code development is to ingest the training and test datasets. The ExampleGen component brings data into the TFX pipeline.

Create an ExampleGen component and run it.

input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval.tfrecord')
])

example_gen = ImportExampleGen(input_base=examples_path, input_config=input_config)

context.run(example_gen, enable_cache=True)
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
for artifact in example_gen.outputs['examples'].get():
  print(artifact)

print('\nexample_gen.outputs is a {}'.format(type(example_gen.outputs)))
print(example_gen.outputs)

print(example_gen.outputs['examples'].get()[0].split_names)
Artifact(artifact: id: 1
type_id: 5
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/ImportExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:train,num_files:1,total_bytes:27706811,xor_checksum:1621934361,sum_checksum:1621934361\nsplit:eval,num_files:1,total_bytes:13374744,xor_checksum:1621934364,sum_checksum:1621934364"
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 5
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
)

example_gen.outputs is a <class 'tfx.types.node_common._PropertyDictWrapper'>
{'examples': Channel(
    type_name: Examples
    artifacts: [Artifact(artifact: id: 1
type_id: 5
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/ImportExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:train,num_files:1,total_bytes:27706811,xor_checksum:1621934361,sum_checksum:1621934361\nsplit:eval,num_files:1,total_bytes:13374744,xor_checksum:1621934364,sum_checksum:1621934364"
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 5
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
)]
    additional_properties: {}
    additional_custom_properties: {}
)}
["train", "eval"]

The component's output includes two artifacts:

  • the training examples (10,000 labeled reviews + 10,000 unlabeled reviews)
  • the eval examples (10,000 labeled reviews)

The IdentifyExamples custom component

To use NSL, we will need each instance to have a unique ID. We create a custom component that adds such a unique ID to all instances across all splits. We leverage Apache Beam to be able to easily scale to large datasets if needed.

def make_example_with_unique_id(example, id_feature_name):
  """Adds a unique ID to the given `tf.train.Example` proto.

  This function uses Python's 'uuid' module to generate a universally unique
  identifier for each example.

  Args:
    example: An instance of a `tf.train.Example` proto.
    id_feature_name: The name of the feature in the resulting `tf.train.Example`
      that will contain the unique identifier.

  Returns:
    A new `tf.train.Example` proto that includes a unique identifier as an
    additional feature.
  """
  result = tf.train.Example()
  result.CopyFrom(example)
  unique_id = uuid.uuid4()
  result.features.feature.get_or_create(
      id_feature_name).bytes_list.MergeFrom(
          tf.train.BytesList(value=[str(unique_id).encode('utf-8')]))
  return result


@component
def IdentifyExamples(orig_examples: InputArtifact[Examples],
                     identified_examples: OutputArtifact[Examples],
                     id_feature_name: Parameter[str],
                     component_name: Parameter[str]) -> None:

  # Get a list of the splits in input_data
  splits_list = artifact_utils.decode_split_names(
      split_names=orig_examples.split_names)
  # For completeness, encode the splits names and payload_format.
  # We could also just use input_data.split_names.
  identified_examples.split_names = artifact_utils.encode_split_names(
      splits=splits_list)
  # TODO(b/168616829): Remove populating payload_format after tfx 0.25.0.
  identified_examples.set_string_custom_property(
      "payload_format",
      orig_examples.get_string_custom_property("payload_format"))


  for split in splits_list:
    input_dir = artifact_utils.get_split_uri([orig_examples], split)
    output_dir = artifact_utils.get_split_uri([identified_examples], split)
    os.mkdir(output_dir)
    with beam.Pipeline() as pipeline:
      (pipeline
       | 'ReadExamples' >> beam.io.ReadFromTFRecord(
           os.path.join(input_dir, '*'),
           coder=beam.coders.coders.ProtoCoder(tf.train.Example))
       | 'AddUniqueId' >> beam.Map(make_example_with_unique_id, id_feature_name)
       | 'WriteIdentifiedExamples' >> beam.io.WriteToTFRecord(
           file_path_prefix=os.path.join(output_dir, 'data_tfrecord'),
           coder=beam.coders.coders.ProtoCoder(tf.train.Example),
           file_name_suffix='.gz'))

  return
identify_examples = IdentifyExamples(
    orig_examples=example_gen.outputs['examples'],
    component_name=u'IdentifyExamples',
    id_feature_name=u'id')
context.run(identify_examples, enable_cache=False)

The StatisticsGen component

The StatisticsGen component computes descriptive statistics for your dataset. The statistics that it generates can be visualized for review, and are used for example validation and to infer a schema, among other things.

Create a StatisticsGen component and run it.

# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
    examples=identify_examples.outputs["identified_examples"])
context.run(statistics_gen, enable_cache=True)

The SchemaGen component

The SchemaGen component generates a schema for your data based on the statistics from StatisticsGen. It tries to infer the data type of each of your features and the ranges of legal values for categorical features.

Create a SchemaGen component and run it.

# Generates schema based on statistics files.
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'], infer_feature_shape=False)
context.run(schema_gen, enable_cache=True)

The generated artifact is just a schema.pbtxt containing a text representation of a schema_pb2.Schema protobuf:

train_uri = schema_gen.outputs['schema'].get()[0].uri
schema_filename = os.path.join(train_uri, 'schema.pbtxt')
schema = tfx.utils.io_utils.parse_pbtxt_file(
    file_name=schema_filename, message=schema_pb2.Schema())

It can be visualized using tfdv.display_schema() (we will look at this in more detail in a subsequent lab):

tfdv.display_schema(schema)

The ExampleValidator component

The ExampleValidator performs anomaly detection based on the statistics from StatisticsGen and the schema from SchemaGen. It looks for problems such as missing values, values of the wrong type, or categorical values outside the domain of acceptable values.

Create an ExampleValidator component and run it.

# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(validate_stats, enable_cache=False)

The SynthesizeGraph component

Graph construction involves creating embeddings for text samples and then using a similarity function to compare the embeddings.

We will use pre-trained Swivel embeddings to create embeddings in the tf.train.Example format for each sample in the input. We will store the resulting embeddings in the TFRecord format along with each sample's ID. This is important, as it will allow us to match sample embeddings with corresponding nodes in the graph later.

Once we have the sample embeddings, we will use them to build a similarity graph: nodes in this graph will correspond to samples, and edges in this graph will correspond to the similarity between pairs of nodes.

Neural Structured Learning provides a graph-building library to build a graph based on sample embeddings. It uses cosine similarity as the similarity measure to compare embeddings and build edges between them. It also allows us to specify a similarity threshold, which can be used to discard dissimilar edges from the final graph. In the following example, using 0.99 as the similarity threshold results in a graph that has 111,066 bidirectional edges.

swivel_url = 'https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1'
hub_layer = hub.KerasLayer(swivel_url, input_shape=[], dtype=tf.string)


def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_embedding_example(example):
  """Create tf.Example containing the sample's embedding and its ID."""
  sentence_embedding = hub_layer(tf.sparse.to_dense(example['text']))

  # Flatten the sentence embedding back to 1-D.
  sentence_embedding = tf.reshape(sentence_embedding, shape=[-1])

  feature_dict = {
      'id': _bytes_feature(tf.sparse.to_dense(example['id']).numpy()),
      'embedding': _float_feature(sentence_embedding.numpy().tolist())
  }

  return tf.train.Example(features=tf.train.Features(feature=feature_dict))


def create_dataset(uri):
  tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
  return tf.data.TFRecordDataset(tfrecord_filenames, compression_type='GZIP')


def create_embeddings(train_path, output_path):
  dataset = create_dataset(train_path)
  embeddings_path = os.path.join(output_path, 'embeddings.tfr')

  feature_map = {
      'label': tf.io.FixedLenFeature([], tf.int64),
      'id': tf.io.VarLenFeature(tf.string),
      'text': tf.io.VarLenFeature(tf.string)
  }

  with tf.io.TFRecordWriter(embeddings_path) as writer:
    for tfrecord in dataset:
      tensor_dict = tf.io.parse_single_example(tfrecord, feature_map)
      embedding_example = create_embedding_example(tensor_dict)
      writer.write(embedding_example.SerializeToString())


def build_graph(output_path, similarity_threshold):
  embeddings_path = os.path.join(output_path, 'embeddings.tfr')
  graph_path = os.path.join(output_path, 'graph.tsv')
  graph_builder_config = nsl.configs.GraphBuilderConfig(
      similarity_threshold=similarity_threshold,
      lsh_splits=32,
      lsh_rounds=15,
      random_seed=12345)
  nsl.tools.build_graph_from_config([embeddings_path], graph_path,
                                    graph_builder_config)
"""Custom Artifact type"""


class SynthesizedGraph(tfx.types.artifact.Artifact):
  """Output artifact of the SynthesizeGraph component"""
  TYPE_NAME = 'SynthesizedGraphPath'
  PROPERTIES = {
      'span': standard_artifacts.SPAN_PROPERTY,
      'split_names': standard_artifacts.SPLIT_NAMES_PROPERTY,
  }


@component
def SynthesizeGraph(identified_examples: InputArtifact[Examples],
                    synthesized_graph: OutputArtifact[SynthesizedGraph],
                    similarity_threshold: Parameter[float],
                    component_name: Parameter[str]) -> None:

  # Get a list of the splits in input_data
  splits_list = artifact_utils.decode_split_names(
      split_names=identified_examples.split_names)

  # We build a graph only based on the 'Split-train' split which includes both
  # labeled and unlabeled examples.
  train_input_examples_uri = os.path.join(identified_examples.uri,
                                          'Split-train')
  output_graph_uri = os.path.join(synthesized_graph.uri, 'Split-train')
  os.mkdir(output_graph_uri)

  print('Creating embeddings...')
  create_embeddings(train_input_examples_uri, output_graph_uri)

  print('Synthesizing graph...')
  build_graph(output_graph_uri, similarity_threshold)

  synthesized_graph.split_names = artifact_utils.encode_split_names(
      splits=['Split-train'])

  return
synthesize_graph = SynthesizeGraph(
    identified_examples=identify_examples.outputs['identified_examples'],
    component_name=u'SynthesizeGraph',
    similarity_threshold=0.99)
context.run(synthesize_graph, enable_cache=False)
Creating embeddings...
Synthesizing graph...
train_uri = synthesize_graph.outputs["synthesized_graph"].get()[0].uri
os.listdir(train_uri)
['Split-train']
graph_path = os.path.join(train_uri, "Split-train", "graph.tsv")
print("node 1\t\t\t\t\tnode 2\t\t\t\t\tsimilarity")
!head {graph_path}
print("...")
!tail {graph_path}
node 1                  node 2                  similarity
05438126-5e0e-4671-8e0b-86457b8867b7    99bb5cd3-7f66-482d-892a-a6654f5ee7a8    0.990020
99bb5cd3-7f66-482d-892a-a6654f5ee7a8    05438126-5e0e-4671-8e0b-86457b8867b7    0.990020
24cac202-6405-4efe-8125-88e2ddeab0b5    531eb828-b955-44ab-ba4d-0eddad3e25b6    0.992586
531eb828-b955-44ab-ba4d-0eddad3e25b6    24cac202-6405-4efe-8125-88e2ddeab0b5    0.992586
24cac202-6405-4efe-8125-88e2ddeab0b5    abf3e373-8357-4ce0-8135-4b014a802fcf    0.990616
abf3e373-8357-4ce0-8135-4b014a802fcf    24cac202-6405-4efe-8125-88e2ddeab0b5    0.990616
65eb9eeb-387a-43c5-b60e-43c5db06f539    531eb828-b955-44ab-ba4d-0eddad3e25b6    0.992505
531eb828-b955-44ab-ba4d-0eddad3e25b6    65eb9eeb-387a-43c5-b60e-43c5db06f539    0.992505
1aa3c2e1-297d-4b22-9135-d77a3c33772b    adf74b60-53a6-4fc5-a362-c98b68d292c0    0.992471
adf74b60-53a6-4fc5-a362-c98b68d292c0    1aa3c2e1-297d-4b22-9135-d77a3c33772b    0.992471
...
a06d3600-2b1b-43dc-81a9-649228fd0345    0ba7f695-3c60-495d-9ea5-1ec5c0df6190    0.990002
0ba7f695-3c60-495d-9ea5-1ec5c0df6190    a06d3600-2b1b-43dc-81a9-649228fd0345    0.990002
fb5318b2-8642-438b-8e78-8ec73612d900    6c492c67-5e1c-4923-aad6-6ed669fceaf5    0.991046
6c492c67-5e1c-4923-aad6-6ed669fceaf5    fb5318b2-8642-438b-8e78-8ec73612d900    0.991046
b7cfb62d-fe05-4c87-a8a4-ceae6b45ca32    3559d799-7647-48ee-a04c-bd5b4d7cf0d8    0.991198
3559d799-7647-48ee-a04c-bd5b4d7cf0d8    b7cfb62d-fe05-4c87-a8a4-ceae6b45ca32    0.991198
0c8dd107-92df-449d-a30b-170cc9c97071    c8dbe83c-3545-4eca-9588-087c58b54b3e    0.990260
c8dbe83c-3545-4eca-9588-087c58b54b3e    0c8dd107-92df-449d-a30b-170cc9c97071    0.990260
f7bf919b-175b-48b0-8dab-f0b1c5f5e13b    a09eccae-9461-405b-b61a-381123881216    0.991317
a09eccae-9461-405b-b61a-381123881216    f7bf919b-175b-48b0-8dab-f0b1c5f5e13b    0.991317
!wc -l {graph_path}
222132 /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/SynthesizeGraph/synthesized_graph/6/Split-train/graph.tsv
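The line count is consistent with the edge count reported earlier: graph.tsv stores each bidirectional edge as two directed lines (one per direction), so 111,066 bidirectional edges account for all 222,132 lines.

```python
# Each bidirectional edge is serialized as two directed TSV lines,
# so the file has exactly twice as many lines as bidirectional edges.
bidirectional_edges = 111066
tsv_lines = 2 * bidirectional_edges
print(tsv_lines)  # 222132
```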

The Transform component

The Transform component performs data transformations and feature engineering. The results include an input TensorFlow graph which is used during both training and serving to preprocess the data before training or inference. This graph becomes part of the SavedModel that is the result of model training. Since the same input graph is used for both training and serving, the preprocessing is always the same and only needs to be written once.

The Transform component requires more code than many other components because of the arbitrary complexity of the feature engineering that you may need for the data and/or model that you're working with. It requires code files to be available which define the processing needed.

Each sample will include the following three features:

  1. id: the node ID of the sample.
  2. text_xf: an int64 list containing word IDs.
  3. label_xf: a singleton int64 identifying the target class of the review: 0 = negative, 1 = positive.

Let's define a module containing the preprocessing_fn() function that we will pass to the Transform component:

_transform_module_file = 'imdb_transform.py'
%%writefile {_transform_module_file}

import tensorflow as tf

import tensorflow_transform as tft

SEQUENCE_LENGTH = 100
VOCAB_SIZE = 10000
OOV_SIZE = 100

def tokenize_reviews(reviews, sequence_length=SEQUENCE_LENGTH):
  reviews = tf.strings.lower(reviews)
  reviews = tf.strings.regex_replace(reviews, r" '| '|^'|'$", " ")
  reviews = tf.strings.regex_replace(reviews, "[^a-z' ]", " ")
  tokens = tf.strings.split(reviews)[:, :sequence_length]
  start_tokens = tf.fill([tf.shape(reviews)[0], 1], "<START>")
  end_tokens = tf.fill([tf.shape(reviews)[0], 1], "<END>")
  tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)
  tokens = tokens[:, :sequence_length]
  tokens = tokens.to_tensor(default_value="<PAD>")
  pad = sequence_length - tf.shape(tokens)[1]
  tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values="<PAD>")
  return tf.reshape(tokens, [-1, sequence_length])

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}
  outputs["id"] = inputs["id"]
  tokens = tokenize_reviews(_fill_in_missing(inputs["text"], ''))
  outputs["text_xf"] = tft.compute_and_apply_vocabulary(
      tokens,
      top_k=VOCAB_SIZE,
      num_oov_buckets=OOV_SIZE)
  outputs["label_xf"] = _fill_in_missing(inputs["label"], -1)
  return outputs

def _fill_in_missing(x, default_value):
  """Replace missing values in a SparseTensor.

  Fills in missing values of `x` with the default_value.

  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.
    default_value: the value with which to replace the missing values.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  if not isinstance(x, tf.sparse.SparseTensor):
    return x
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)
Writing imdb_transform.py

Create and run the Transform component, referring to the file that was created above.

# Performs transformations and feature engineering in training and serving.
transform = Transform(
    examples=identify_examples.outputs['identified_examples'],
    schema=schema_gen.outputs['schema'],
    module_file=_transform_module_file)
context.run(transform, enable_cache=True)
ERROR:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'imdb_transform@/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/_wheels/tfx_user_code_Transform-0.0+074f608d1f54105225e2fee77ebe4b6159a009eca01b5a0791099840a2185d50-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_transform/tf_utils.py:266: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_transform/tf_utils.py:266: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[str, Union[NoneType, _Dataset]], Union[Dict[str, Dict[str, PCollection]], NoneType]] instead.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[str, Union[NoneType, _Dataset]], Union[Dict[str, Dict[str, PCollection]], NoneType]] instead.
WARNING:tensorflow:Tensorflow version (2.4.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'>
WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'>
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transform_graph/7/.temp_path/tftransform_tmp/c374e97e97564f87aa83fcd06dd226cf/assets
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transform_graph/7/.temp_path/tftransform_tmp/eebacfe42c6f48248a40db4b138bde22/assets
WARNING:tensorflow:5 out of the last 20004 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153b17dd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 20005 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153b17050> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:7 out of the last 20006 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153b364d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:8 out of the last 20007 calls to <function recreate_function.<locals>.restored_function_body at 0x7f614b726170> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:9 out of the last 20008 calls to <function recreate_function.<locals>.restored_function_body at 0x7f614b7243b0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:10 out of the last 20009 calls to <function recreate_function.<locals>.restored_function_body at 0x7f614b726830> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

The Transform component has two types of output:

  • transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
  • transformed_examples represents the preprocessed training and evaluation data.
transform.outputs
{'transform_graph': Channel(
    type_name: TransformGraph
    artifacts: [Artifact(artifact: id: 7
type_id: 16
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transform_graph/7"
custom_properties {
  key: "name"
  value {
    string_value: "transform_graph"
  }
}
custom_properties {
  key: "producer_component"
  value {
    string_value: "Transform"
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 16
name: "TransformGraph"
)]
    additional_properties: {}
    additional_custom_properties: {}
), 'transformed_examples': Channel(
    type_name: Examples
    artifacts: [Artifact(artifact: id: 8
type_id: 5
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transformed_examples/7"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "transformed_examples"
  }
}
custom_properties {
  key: "producer_component"
  value {
    string_value: "Transform"
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 5
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
)]
    additional_properties: {}
    additional_custom_properties: {}
), 'updated_analyzer_cache': Channel(
    type_name: TransformCache
    artifacts: [Artifact(artifact: id: 9
type_id: 17
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/updated_analyzer_cache/7"
custom_properties {
  key: "name"
  value {
    string_value: "updated_analyzer_cache"
  }
}
custom_properties {
  key: "producer_component"
  value {
    string_value: "Transform"
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 17
name: "TransformCache"
)]
    additional_properties: {}
    additional_custom_properties: {}
)}

Let's take a peek at the transform_graph artifact: it points to a directory containing three subdirectories.

transform_graph_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(transform_graph_uri)
['transform_fn', 'transformed_metadata', 'metadata']

The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original data, and the transformed_metadata subdirectory contains the schema of the preprocessed data.

Take a look at a few of the transformed examples to check that they are indeed processed as intended.

def pprint_examples(artifact, n_examples=3):
  print("artifact:", artifact)
  uri = os.path.join(artifact.uri, "Split-train")
  print("uri:", uri)
  tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
  print("tfrecord_filenames:", tfrecord_filenames)
  dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
  for tfrecord in dataset.take(n_examples):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example.FromString(serialized_example)
    pp.pprint(example)
pprint_examples(transform.outputs['transformed_examples'].get()[0])
artifact: Artifact(artifact: id: 8
type_id: 5
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transformed_examples/7"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "transformed_examples"
  }
}
custom_properties {
  key: "producer_component"
  value {
    string_value: "Transform"
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 5
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
)
uri: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transformed_examples/7/Split-train
tfrecord_filenames: ['/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Transform/transformed_examples/7/Split-train/transformed_examples-00000-of-00001.gz']
features {
  feature {
    key: "id"
    value {
      bytes_list {
        value: "a795f9a2-a87a-44e3-80cf-0d477229af71"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 8
        value: 14
        value: 32
        value: 338
        value: 310
        value: 15
        value: 95
        value: 27
        value: 10001
        value: 9
        value: 31
        value: 1173
        value: 3153
        value: 43
        value: 495
        value: 10060
        value: 214
        value: 26
        value: 71
        value: 142
        value: 19
        value: 8
        value: 204
        value: 339
        value: 27
        value: 74
        value: 181
        value: 238
        value: 9
        value: 440
        value: 67
        value: 74
        value: 71
        value: 94
        value: 100
        value: 22
        value: 5442
        value: 8
        value: 1573
        value: 607
        value: 530
        value: 8
        value: 15
        value: 6
        value: 32
        value: 378
        value: 6292
        value: 207
        value: 2276
        value: 388
        value: 0
        value: 84
        value: 1023
        value: 154
        value: 65
        value: 155
        value: 52
        value: 0
        value: 10080
        value: 7871
        value: 65
        value: 250
        value: 74
        value: 3202
        value: 20
        value: 10000
        value: 3720
        value: 10020
        value: 10008
        value: 1282
        value: 3862
        value: 3
        value: 53
        value: 3952
        value: 110
        value: 1879
        value: 17
        value: 3153
        value: 14
        value: 166
        value: 19
        value: 2
        value: 1023
        value: 1007
        value: 9405
        value: 9
        value: 2
        value: 15
        value: 12
        value: 14
        value: 4504
        value: 4
        value: 109
        value: 158
        value: 1202
        value: 7
        value: 174
        value: 505
        value: 12
      }
    }
  }
}

features {
  feature {
    key: "id"
    value {
      bytes_list {
        value: "c626e175-8a31-4919-bbcb-f8d745e74d6c"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 7
        value: 23
        value: 75
        value: 494
        value: 5
        value: 748
        value: 2155
        value: 307
        value: 91
        value: 19
        value: 8
        value: 6
        value: 499
        value: 763
        value: 5
        value: 2
        value: 1690
        value: 4
        value: 200
        value: 593
        value: 57
        value: 1244
        value: 120
        value: 2364
        value: 3
        value: 4407
        value: 21
        value: 0
        value: 10081
        value: 3
        value: 263
        value: 42
        value: 6947
        value: 2
        value: 169
        value: 185
        value: 21
        value: 8
        value: 5143
        value: 7
        value: 1339
        value: 2155
        value: 81
        value: 0
        value: 18
        value: 14
        value: 1468
        value: 0
        value: 86
        value: 986
        value: 14
        value: 2259
        value: 1790
        value: 562
        value: 3
        value: 284
        value: 200
        value: 401
        value: 5
        value: 668
        value: 19
        value: 17
        value: 58
        value: 1934
        value: 4
        value: 45
        value: 14
        value: 4212
        value: 113
        value: 43
        value: 135
        value: 7
        value: 753
        value: 7
        value: 224
        value: 23
        value: 1155
        value: 179
        value: 4
        value: 0
        value: 18
        value: 19
        value: 7
        value: 191
        value: 0
        value: 2047
        value: 4
        value: 10
        value: 3
        value: 283
        value: 42
        value: 401
        value: 5
        value: 668
        value: 4
        value: 90
        value: 234
        value: 10023
        value: 227
      }
    }
  }
}

features {
  feature {
    key: "id"
    value {
      bytes_list {
        value: "d3c2198b-8adb-4f72-9184-a29f87276a80"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 4577
        value: 7158
        value: 0
        value: 10047
        value: 3778
        value: 3346
        value: 9
        value: 2
        value: 758
        value: 1915
        value: 3
        value: 2280
        value: 1511
        value: 3
        value: 2003
        value: 10020
        value: 225
        value: 786
        value: 382
        value: 16
        value: 39
        value: 203
        value: 361
        value: 5
        value: 93
        value: 11
        value: 11
        value: 19
        value: 220
        value: 21
        value: 341
        value: 2
        value: 10000
        value: 966
        value: 0
        value: 77
        value: 4
        value: 6677
        value: 464
        value: 10071
        value: 5
        value: 10042
        value: 630
        value: 2
        value: 10044
        value: 404
        value: 2
        value: 10044
        value: 3
        value: 5
        value: 10008
        value: 0
        value: 1259
        value: 630
        value: 106
        value: 10042
        value: 6721
        value: 10
        value: 49
        value: 21
        value: 0
        value: 2071
        value: 20
        value: 1292
        value: 4
        value: 0
        value: 431
        value: 11
        value: 11
        value: 166
        value: 67
        value: 2342
        value: 5815
        value: 12
        value: 575
        value: 21
        value: 0
        value: 1691
        value: 537
        value: 4
        value: 0
        value: 3605
        value: 307
        value: 0
        value: 10054
        value: 1563
        value: 3115
        value: 467
        value: 4577
        value: 3
        value: 1069
        value: 1158
        value: 5
        value: 23
        value: 4279
        value: 6677
        value: 464
        value: 20
        value: 10004
      }
    }
  }
}

The GraphAugmentation component

Now that we have sample features and a synthesized graph, we can generate the augmented training data for Neural Structured Learning. The NSL framework provides a library to combine the graph with the sample features to produce the final training data for graph regularization. The resulting training data includes the original sample features as well as the features of their corresponding neighbors.

In this tutorial, we consider undirected edges and use a maximum of 3 neighbors per sample to augment the training data with graph neighbors.
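The neighbor-joining step can be sketched in plain Python. This is a hypothetical, simplified stand-in for nsl.tools.pack_nbrs (not its real implementation): it joins each sample with the features of up to max_nbrs neighbors over undirected edges, using NSL's neighbor-feature naming convention (NL_num_nbrs, NL_nbr_<i>_<feature>, NL_nbr_<i>_weight). The function and variable names here are illustrative only.

```python
# Hypothetical sketch of graph augmentation (NOT the real nsl.tools.pack_nbrs):
# joins each sample with up to `max_nbrs` highest-weight neighbors.

def augment_with_neighbors(samples, edges, max_nbrs=3):
    """samples: {sample_id: {feature_name: value}};
    edges: iterable of (source_id, target_id, weight), treated as undirected."""
    # Build an undirected adjacency list keyed by sample id.
    adjacency = {sample_id: [] for sample_id in samples}
    for source, target, weight in edges:
        adjacency[source].append((weight, target))
        adjacency[target].append((weight, source))

    augmented = {}
    for sample_id, features in samples.items():
        # Keep the `max_nbrs` highest-weight neighbors (like pack_nbrs' max_nbrs).
        neighbors = sorted(adjacency[sample_id], reverse=True)[:max_nbrs]
        record = dict(features)
        record['NL_num_nbrs'] = len(neighbors)
        for i, (weight, neighbor_id) in enumerate(neighbors):
            # Copy each neighbor feature under the NL_nbr_<i>_<feature> name.
            for name, value in samples[neighbor_id].items():
                record[f'NL_nbr_{i}_{name}'] = value
            record[f'NL_nbr_{i}_weight'] = weight
        augmented[sample_id] = record
    return augmented
```

The real nsl.tools.pack_nbrs additionally merges in unlabeled examples and writes TFRecord output; this sketch only illustrates the shape of the joined records, e.g. the NL_num_nbrs feature that appears in the augmented examples printed further down.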

def split_train_and_unsup(input_uri):
  """Separate the labeled and unlabeled instances."""

  tmp_dir = tempfile.mkdtemp(prefix='tfx-data')
  tfrecord_filenames = [
      os.path.join(input_uri, filename) for filename in os.listdir(input_uri)
  ]
  train_path = os.path.join(tmp_dir, 'train.tfrecord')
  unsup_path = os.path.join(tmp_dir, 'unsup.tfrecord')
  with tf.io.TFRecordWriter(train_path) as train_writer, \
       tf.io.TFRecordWriter(unsup_path) as unsup_writer:
    for tfrecord in tf.data.TFRecordDataset(
        tfrecord_filenames, compression_type='GZIP'):
      example = tf.train.Example()
      example.ParseFromString(tfrecord.numpy())
      if ('label_xf' not in example.features.feature or
          example.features.feature['label_xf'].int64_list.value[0] == -1):
        writer = unsup_writer
      else:
        writer = train_writer
      writer.write(tfrecord.numpy())
  return train_path, unsup_path


def gzip(filepath):
  with open(filepath, 'rb') as f_in:
    with gzip_lib.open(filepath + '.gz', 'wb') as f_out:
      shutil.copyfileobj(f_in, f_out)
  os.remove(filepath)


def copy_tfrecords(input_uri, output_uri):
  for filename in os.listdir(input_uri):
    input_filename = os.path.join(input_uri, filename)
    output_filename = os.path.join(output_uri, filename)
    shutil.copyfile(input_filename, output_filename)


@component
def GraphAugmentation(identified_examples: InputArtifact[Examples],
                      synthesized_graph: InputArtifact[SynthesizedGraph],
                      augmented_examples: OutputArtifact[Examples],
                      num_neighbors: Parameter[int],
                      component_name: Parameter[str]) -> None:

  # Get a list of the splits in input_data
  splits_list = artifact_utils.decode_split_names(
      split_names=identified_examples.split_names)

  train_input_uri = os.path.join(identified_examples.uri, 'Split-train')
  eval_input_uri = os.path.join(identified_examples.uri, 'Split-eval')
  train_graph_uri = os.path.join(synthesized_graph.uri, 'Split-train')
  train_output_uri = os.path.join(augmented_examples.uri, 'Split-train')
  eval_output_uri = os.path.join(augmented_examples.uri, 'Split-eval')

  os.mkdir(train_output_uri)
  os.mkdir(eval_output_uri)

  # Separate the labeled and unlabeled examples from the 'Split-train' split.
  train_path, unsup_path = split_train_and_unsup(train_input_uri)

  output_path = os.path.join(train_output_uri, 'nsl_train_data.tfr')
  pack_nbrs_args = dict(
      labeled_examples_path=train_path,
      unlabeled_examples_path=unsup_path,
      graph_path=os.path.join(train_graph_uri, 'graph.tsv'),
      output_training_data_path=output_path,
      add_undirected_edges=True,
      max_nbrs=num_neighbors)
  print('nsl.tools.pack_nbrs arguments:', pack_nbrs_args)
  nsl.tools.pack_nbrs(**pack_nbrs_args)

  # Downstream components expect gzip'ed TFRecords.
  gzip(output_path)

  # The test examples are left untouched and are simply copied over.
  copy_tfrecords(eval_input_uri, eval_output_uri)

  augmented_examples.split_names = identified_examples.split_names

  return
# Augments training data with graph neighbors.
graph_augmentation = GraphAugmentation(
    identified_examples=transform.outputs['transformed_examples'],
    synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
    component_name='GraphAugmentation',
    num_neighbors=3)
context.run(graph_augmentation, enable_cache=False)
nsl.tools.pack_nbrs arguments: {'labeled_examples_path': '/tmp/tfx-data0n009kzy/train.tfrecord', 'unlabeled_examples_path': '/tmp/tfx-data0n009kzy/unsup.tfrecord', 'graph_path': '/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/SynthesizeGraph/synthesized_graph/6/Split-train/graph.tsv', 'output_training_data_path': '/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/GraphAugmentation/augmented_examples/8/Split-train/nsl_train_data.tfr', 'add_undirected_edges': True, 'max_nbrs': 3}
pprint_examples(graph_augmentation.outputs['augmented_examples'].get()[0], 6)
artifact: Artifact(artifact: id: 10
type_id: 5
uri: "/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/GraphAugmentation/augmented_examples/8"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "augmented_examples"
  }
}
custom_properties {
  key: "producer_component"
  value {
    string_value: "GraphAugmentation"
  }
}
custom_properties {
  key: "state"
  value {
    string_value: "published"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "0.30.0"
  }
}
state: LIVE
, artifact_type: id: 5
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
)
uri: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/GraphAugmentation/augmented_examples/8/Split-train
tfrecord_filenames: ['/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/GraphAugmentation/augmented_examples/8/Split-train/nsl_train_data.tfr.gz']
features {
  feature {
    key: "NL_num_nbrs"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "id"
    value {
      bytes_list {
        value: "a795f9a2-a87a-44e3-80cf-0d477229af71"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 8
        value: 14
        value: 32
        value: 338
        value: 310
        value: 15
        value: 95
        value: 27
        value: 10001
        value: 9
        value: 31
        value: 1173
        value: 3153
        value: 43
        value: 495
        value: 10060
        value: 214
        value: 26
        value: 71
        value: 142
        value: 19
        value: 8
        value: 204
        value: 339
        value: 27
        value: 74
        value: 181
        value: 238
        value: 9
        value: 440
        value: 67
        value: 74
        value: 71
        value: 94
        value: 100
        value: 22
        value: 5442
        value: 8
        value: 1573
        value: 607
        value: 530
        value: 8
        value: 15
        value: 6
        value: 32
        value: 378
        value: 6292
        value: 207
        value: 2276
        value: 388
        value: 0
        value: 84
        value: 1023
        value: 154
        value: 65
        value: 155
        value: 52
        value: 0
        value: 10080
        value: 7871
        value: 65
        value: 250
        value: 74
        value: 3202
        value: 20
        value: 10000
        value: 3720
        value: 10020
        value: 10008
        value: 1282
        value: 3862
        value: 3
        value: 53
        value: 3952
        value: 110
        value: 1879
        value: 17
        value: 3153
        value: 14
        value: 166
        value: 19
        value: 2
        value: 1023
        value: 1007
        value: 9405
        value: 9
        value: 2
        value: 15
        value: 12
        value: 14
        value: 4504
        value: 4
        value: 109
        value: 158
        value: 1202
        value: 7
        value: 174
        value: 505
        value: 12
      }
    }
  }
}

features {
  feature {
    key: "NL_num_nbrs"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "id"
    value {
      bytes_list {
        value: "c626e175-8a31-4919-bbcb-f8d745e74d6c"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 7
        value: 23
        value: 75
        value: 494
        value: 5
        value: 748
        value: 2155
        value: 307
        value: 91
        value: 19
        value: 8
        value: 6
        value: 499
        value: 763
        value: 5
        value: 2
        value: 1690
        value: 4
        value: 200
        value: 593
        value: 57
        value: 1244
        value: 120
        value: 2364
        value: 3
        value: 4407
        value: 21
        value: 0
        value: 10081
        value: 3
        value: 263
        value: 42
        value: 6947
        value: 2
        value: 169
        value: 185
        value: 21
        value: 8
        value: 5143
        value: 7
        value: 1339
        value: 2155
        value: 81
        value: 0
        value: 18
        value: 14
        value: 1468
        value: 0
        value: 86
        value: 986
        value: 14
        value: 2259
        value: 1790
        value: 562
        value: 3
        value: 284
        value: 200
        value: 401
        value: 5
        value: 668
        value: 19
        value: 17
        value: 58
        value: 1934
        value: 4
        value: 45
        value: 14
        value: 4212
        value: 113
        value: 43
        value: 135
        value: 7
        value: 753
        value: 7
        value: 224
        value: 23
        value: 1155
        value: 179
        value: 4
        value: 0
        value: 18
        value: 19
        value: 7
        value: 191
        value: 0
        value: 2047
        value: 4
        value: 10
        value: 3
        value: 283
        value: 42
        value: 401
        value: 5
        value: 668
        value: 4
        value: 90
        value: 234
        value: 10023
        value: 227
      }
    }
  }
}

features {
  feature {
    key: "NL_num_nbrs"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "id"
    value {
      bytes_list {
        value: "d3c2198b-8adb-4f72-9184-a29f87276a80"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 4577
        value: 7158
        value: 0
        value: 10047
        value: 3778
        value: 3346
        value: 9
        value: 2
        value: 758
        value: 1915
        value: 3
        value: 2280
        value: 1511
        value: 3
        value: 2003
        value: 10020
        value: 225
        value: 786
        value: 382
        value: 16
        value: 39
        value: 203
        value: 361
        value: 5
        value: 93
        value: 11
        value: 11
        value: 19
        value: 220
        value: 21
        value: 341
        value: 2
        value: 10000
        value: 966
        value: 0
        value: 77
        value: 4
        value: 6677
        value: 464
        value: 10071
        value: 5
        value: 10042
        value: 630
        value: 2
        value: 10044
        value: 404
        value: 2
        value: 10044
        value: 3
        value: 5
        value: 10008
        value: 0
        value: 1259
        value: 630
        value: 106
        value: 10042
        value: 6721
        value: 10
        value: 49
        value: 21
        value: 0
        value: 2071
        value: 20
        value: 1292
        value: 4
        value: 0
        value: 431
        value: 11
        value: 11
        value: 166
        value: 67
        value: 2342
        value: 5815
        value: 12
        value: 575
        value: 21
        value: 0
        value: 1691
        value: 537
        value: 4
        value: 0
        value: 3605
        value: 307
        value: 0
        value: 10054
        value: 1563
        value: 3115
        value: 467
        value: 4577
        value: 3
        value: 1069
        value: 1158
        value: 5
        value: 23
        value: 4279
        value: 6677
        value: 464
        value: 20
        value: 10004
      }
    }
  }
}

features {
  feature {
    key: "NL_num_nbrs"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "id"
    value {
      bytes_list {
        value: "8a055682-fcbf-4654-ab15-887aec98d76b"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 8
        value: 6
        value: 0
        value: 251
        value: 4
        value: 18
        value: 20
        value: 2
        value: 6783
        value: 2295
        value: 2338
        value: 52
        value: 0
        value: 468
        value: 4
        value: 0
        value: 189
        value: 73
        value: 153
        value: 1294
        value: 17
        value: 90
        value: 234
        value: 935
        value: 16
        value: 25
        value: 10024
        value: 92
        value: 2
        value: 192
        value: 4218
        value: 3317
        value: 3
        value: 10098
        value: 20
        value: 2
        value: 356
        value: 4
        value: 565
        value: 334
        value: 382
        value: 36
        value: 6989
        value: 3
        value: 6065
        value: 2510
        value: 16
        value: 203
        value: 7264
        value: 2849
        value: 0
        value: 86
        value: 346
        value: 50
        value: 26
        value: 58
        value: 10020
        value: 5
        value: 1464
        value: 58
        value: 2081
        value: 2969
        value: 42
        value: 2
        value: 2364
        value: 3
        value: 1402
        value: 10062
        value: 138
        value: 147
        value: 614
        value: 115
        value: 29
        value: 90
        value: 105
        value: 2
        value: 223
        value: 18
        value: 9
        value: 160
        value: 324
        value: 3
        value: 24
        value: 12
        value: 1252
        value: 0
        value: 2142
        value: 10
        value: 1832
        value: 111
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
      }
    }
  }
}

features {
  feature {
    key: "NL_num_nbrs"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "id"
    value {
      bytes_list {
        value: "cde97cf0-2454-40e0-86e9-246f2a560a7e"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 16
        value: 423
        value: 23
        value: 1367
        value: 30
        value: 0
        value: 363
        value: 12
        value: 153
        value: 3174
        value: 9
        value: 8
        value: 18
        value: 26
        value: 667
        value: 338
        value: 1372
        value: 0
        value: 86
        value: 46
        value: 9200
        value: 282
        value: 0
        value: 10091
        value: 4
        value: 0
        value: 694
        value: 10028
        value: 52
        value: 362
        value: 26
        value: 202
        value: 39
        value: 216
        value: 5
        value: 27
        value: 5822
        value: 19
        value: 52
        value: 58
        value: 362
        value: 26
        value: 202
        value: 39
        value: 474
        value: 0
        value: 10029
        value: 4
        value: 2
        value: 243
        value: 143
        value: 386
        value: 3
        value: 0
        value: 386
        value: 579
        value: 2
        value: 132
        value: 57
        value: 725
        value: 88
        value: 140
        value: 30
        value: 27
        value: 33
        value: 1359
        value: 29
        value: 8
        value: 567
        value: 35
        value: 106
        value: 230
        value: 60
        value: 0
        value: 3041
        value: 5
        value: 7879
        value: 28
        value: 281
        value: 110
        value: 111
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
      }
    }
  }
}

features {
  feature {
    key: "NL_num_nbrs"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "id"
    value {
      bytes_list {
        value: "ac0cb1cf-bcb3-4eea-917c-96acc3d8f362"
      }
    }
  }
  feature {
    key: "label_xf"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "text_xf"
    value {
      int64_list {
        value: 13
        value: 8
        value: 6
        value: 2
        value: 18
        value: 69
        value: 140
        value: 27
        value: 83
        value: 31
        value: 1877
        value: 905
        value: 9
        value: 10057
        value: 31
        value: 43
        value: 2115
        value: 36
        value: 32
        value: 2057
        value: 6133
        value: 10
        value: 6
        value: 32
        value: 2474
        value: 1614
        value: 3
        value: 2707
        value: 990
        value: 4
        value: 10067
        value: 9
        value: 2
        value: 1532
        value: 242
        value: 90
        value: 3757
        value: 3
        value: 90
        value: 10026
        value: 0
        value: 242
        value: 6
        value: 260
        value: 31
        value: 24
        value: 4
        value: 0
        value: 84
        value: 497
        value: 177
        value: 1151
        value: 777
        value: 9
        value: 397
        value: 552
        value: 7726
        value: 10051
        value: 34
        value: 14
        value: 379
        value: 33
        value: 1829
        value: 9
        value: 123
        value: 0
        value: 916
        value: 10028
        value: 7
        value: 64
        value: 571
        value: 12
        value: 8
        value: 18
        value: 27
        value: 687
        value: 9
        value: 30
        value: 5609
        value: 16
        value: 25
        value: 99
        value: 117
        value: 66
        value: 2
        value: 130
        value: 21
        value: 8
        value: 842
        value: 7726
        value: 10051
        value: 6
        value: 338
        value: 1107
        value: 3
        value: 24
        value: 10020
        value: 29
        value: 53
        value: 1476
      }
    }
  }
}
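
Notice the trailing runs of `value: 1` in the `text_xf` features above: each review is truncated or right-padded so that every example has exactly `max_seq_length` (100 in this pipeline) token ids. A minimal sketch of that fixed-length padding, assuming (as the output above suggests) that `1` is the padding id:

```python
def pad_sequence(token_ids, max_seq_length=100, pad_id=1):
  """Truncates or right-pads a token-id sequence to a fixed length."""
  return (token_ids + [pad_id] * max_seq_length)[:max_seq_length]

print(pad_sequence([13, 8, 6], max_seq_length=5))  # [13, 8, 6, 1, 1]
```

Fixed-length sequences are what allow the Trainer below to declare `text_xf` as a `FixedLenFeature` of shape `[HPARAMS.max_seq_length]`.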

Trainer Component

The Trainer component trains a model using TensorFlow.

Create a Python module containing a trainer_fn function, which must return an estimator. If you prefer to create a Keras model, you can do so and then convert it to an estimator using keras.model_to_estimator().
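
The graph-regularized estimator built in this module minimizes the supervised loss plus `graph_regularization_multiplier` times a weighted sum of distances between each sample's embedding and its neighbors' embeddings. A pure-Python sketch of that combined objective with the L2 distance (illustrative only, with made-up values; `nsl.estimator.add_graph_regularization` computes this inside the training graph using the configured `DistanceType`):

```python
def graph_regularized_loss(supervised_loss, sample_emb, nbr_embs, nbr_weights,
                           multiplier=0.1):
  """Supervised loss plus a weighted L2 penalty on neighbor distances."""
  graph_loss = 0.0
  for nbr_emb, weight in zip(nbr_embs, nbr_weights):
    sq_dist = sum((a - b) ** 2 for a, b in zip(sample_emb, nbr_emb))
    graph_loss += weight * sq_dist
  return supervised_loss + multiplier * graph_loss

# An identical neighbor adds no penalty; a distant neighbor increases the loss.
print(graph_regularized_loss(0.5, [1.0, 0.0], [[1.0, 0.0]], [1.0]))  # 0.5
print(graph_regularized_loss(0.5, [1.0, 0.0], [[0.0, 0.0]], [1.0]))  # 0.6 = 0.5 + 0.1 * 1.0
```

This is why the multiplier is described below as controlling the relative weight of the graph regularization term in the overall loss.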

# Setup paths.
_trainer_module_file = 'imdb_trainer.py'
%%writefile {_trainer_module_file}

import neural_structured_learning as nsl

import tensorflow as tf

import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils


NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
LABEL_KEY = 'label'
ID_FEATURE_KEY = 'id'

def _transformed_name(key):
  return key + '_xf'


def _transformed_names(keys):
  return [_transformed_name(key) for key in keys]


# Hyperparameters:
#
# We will use an instance of `HParams` to include various hyperparameters and
# constants used for training and evaluation. We briefly describe each of them
# below:
#
# -   max_seq_length: This is the maximum number of words considered from each
#                     movie review in this example.
# -   vocab_size: This is the size of the vocabulary considered for this
#                 example.
# -   oov_size: This is the out-of-vocabulary size considered for this example.
# -   distance_type: This is the distance metric used to regularize the sample
#                    with its neighbors.
# -   graph_regularization_multiplier: This controls the relative weight of the
#                                      graph regularization term in the overall
#                                      loss function.
# -   num_neighbors: The number of neighbors used for graph regularization. This
#                    value has to be less than or equal to the `num_neighbors`
#                    argument used above in the GraphAugmentation component when
#                    invoking `nsl.tools.pack_nbrs`.
# -   num_fc_units: The number of units in the fully connected layer of the
#                   neural network.
class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    # The following 3 values should match those defined in the Transform
    # Component.
    self.max_seq_length = 100
    self.vocab_size = 10000
    self.oov_size = 100
    ### Neural Graph Learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    # The following value has to be at most the value of 'num_neighbors' used
    # in the GraphAugmentation component.
    self.num_neighbors = 1
    ### Model Architecture
    self.num_embedding_dims = 16
    self.num_fc_units = 64

HPARAMS = HParams()


def optimizer_fn():
  """Returns an instance of `tf.Optimizer`."""
  return tf.compat.v1.train.RMSPropOptimizer(
    learning_rate=0.0001, decay=1e-6)


def build_train_op(loss, global_step):
  """Builds a train op to optimize the given loss using gradient descent."""
  with tf.name_scope('train'):
    optimizer = optimizer_fn()
    train_op = optimizer.minimize(loss=loss, global_step=global_step)
  return train_op


# Building the model:
#
# A neural network is created by stacking layers—this requires two main
# architectural decisions:
# * How many layers to use in the model?
# * How many *hidden units* to use for each layer?
#
# In this example, the input data consists of an array of word-indices. The
# labels to predict are either 0 or 1. We will use a feed-forward neural network
# as our base model in this tutorial.
def feed_forward_model(features, is_training, reuse=tf.compat.v1.AUTO_REUSE):
  """Builds a simple 2 layer feed forward neural network.

  The layers are effectively stacked sequentially to build the classifier. The
  first layer is an Embedding layer, which takes the integer-encoded vocabulary
  and looks up the embedding vector for each word-index. These vectors are
  learned as the model trains. The vectors add a dimension to the output array.
  The resulting dimensions are: (batch, sequence, embedding). Next is a global
  average pooling 1D layer, which reduces the dimensionality of its inputs from
  3D to 2D. This fixed-length output vector is piped through a fully-connected
  (Dense) layer with 16 hidden units. The last layer is densely connected with a
  single output node. Using the sigmoid activation function, this value is a
  float between 0 and 1, representing a probability, or confidence level.

  Args:
    features: A dictionary containing batch features returned from the
      `input_fn`, that include sample features, corresponding neighbor features,
      and neighbor weights.
    is_training: a Python Boolean value or a Boolean scalar Tensor, indicating
      whether to apply dropout.
    reuse: a Python Boolean value for reusing variable scope.

  Returns:
    logits: Tensor of shape [batch_size, 1].
    representations: Tensor of shape [batch_size, _] for graph regularization.
      This is the representation of each example at the graph regularization
      layer.
  """

  with tf.compat.v1.variable_scope('ff', reuse=reuse):
    inputs = features[_transformed_name('text')]
    embeddings = tf.compat.v1.get_variable(
        'embeddings',
        shape=[
            HPARAMS.vocab_size + HPARAMS.oov_size, HPARAMS.num_embedding_dims
        ])
    embedding_layer = tf.nn.embedding_lookup(embeddings, inputs)

    pooling_layer = tf.compat.v1.layers.AveragePooling1D(
        pool_size=HPARAMS.max_seq_length, strides=HPARAMS.max_seq_length)(
            embedding_layer)
    # Shape of pooling_layer is now [batch_size, 1, HPARAMS.num_embedding_dims]
    pooling_layer = tf.reshape(pooling_layer, [-1, HPARAMS.num_embedding_dims])

    dense_layer = tf.compat.v1.layers.Dense(
        16, activation='relu')(
            pooling_layer)

    output_layer = tf.compat.v1.layers.Dense(
        1, activation='sigmoid')(
            dense_layer)

    # Graph regularization will be done on the penultimate (dense) layer
    # because the output layer is a single floating point number.
    return output_layer, dense_layer


# A note on hidden units:
#
# The above model has two intermediate or "hidden" layers, between the input and
# output, and excluding the Embedding layer. The number of outputs (units,
# nodes, or neurons) is the dimension of the representational space for the
# layer. In other words, the amount of freedom the network is allowed when
# learning an internal representation. If a model has more hidden units
# (a higher-dimensional representation space), and/or more layers, then the
# network can learn more complex representations. However, it makes the network
# more computationally expensive and may lead to learning unwanted
# patterns—patterns that improve performance on training data but not on the
# test data. This is called overfitting.


# This function will be used to generate the embeddings for samples and their
# corresponding neighbors, which will then be used for graph regularization.
def embedding_fn(features, mode):
  """Returns the embedding corresponding to the given features.

  Args:
    features: A dictionary containing batch features returned from the
      `input_fn`, that include sample features, corresponding neighbor features,
      and neighbor weights.
    mode: Specifies if this is training, evaluation, or prediction. See
      tf.estimator.ModeKeys.

  Returns:
    The embedding that will be used for graph regularization.
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  _, embedding = feed_forward_model(features, is_training)
  return embedding


def feed_forward_model_fn(features, labels, mode, params, config):
  """Implementation of the model_fn for the base feed-forward model.

  Args:
    features: This is the first item returned from the `input_fn` passed to
      `train`, `evaluate`, and `predict`. This should be a single `Tensor` or
      `dict` of same.
    labels: This is the second item returned from the `input_fn` passed to
      `train`, `evaluate`, and `predict`. This should be a single `Tensor` or
      `dict` of same (for multi-head models). If mode is `ModeKeys.PREDICT`,
      `labels=None` will be passed. If the `model_fn`'s signature does not
      accept `mode`, the `model_fn` must still be able to handle `labels=None`.
    mode: Optional. Specifies if this is training, evaluation, or prediction. See
      `ModeKeys`.
    params: An HParams instance as returned by get_hyper_parameters().
    config: Optional configuration object. Will receive what is passed to
      Estimator in `config` parameter, or the default `config`. Allows updating
      things in your model_fn based on configuration such as `num_ps_replicas`,
      or `model_dir`. Unused currently.

  Returns:
     A `tf.estimator.EstimatorSpec` for the base feed-forward model. This does
     not include graph-based regularization.
  """

  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Build the computation graph.
  probabilities, _ = feed_forward_model(features, is_training)
  predictions = tf.round(probabilities)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # labels will be None, and no loss to compute.
    cross_entropy_loss = None
    eval_metric_ops = None
  else:
    # Loss is required in train and eval modes.
    # Flatten 'probabilities' to 1-D.
    probabilities = tf.reshape(probabilities, shape=[-1])
    cross_entropy_loss = tf.compat.v1.keras.losses.binary_crossentropy(
        labels, probabilities)
    eval_metric_ops = {
        'accuracy': tf.compat.v1.metrics.accuracy(labels, predictions)
    }

  if is_training:
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = build_train_op(cross_entropy_loss, global_step)
  else:
    train_op = None

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions={
          'probabilities': probabilities,
          'predictions': predictions
      },
      loss=cross_entropy_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)


# Tf.Transform considers these features as "raw"
def _get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec


def _gzip_reader_fn(filenames):
  """Small utility returning a record reader that can read gzip'ed files."""
  return tf.data.TFRecordDataset(
      filenames,
      compression_type='GZIP')


def _example_serving_receiver_fn(tf_transform_output, schema):
  """Build the serving in inputs.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    Tensorflow graph which parses examples, applying tf-transform to them.
  """
  raw_feature_spec = _get_raw_feature_spec(schema)
  raw_feature_spec.pop(LABEL_KEY)

  # We don't need the ID feature for serving.
  raw_feature_spec.pop(ID_FEATURE_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)

  # Even though LABEL_KEY was removed from 'raw_feature_spec', the transform
  # operation would have injected the transformed LABEL_KEY feature with a
  # default value.
  transformed_features.pop(_transformed_name(LABEL_KEY))
  return tf.estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)


def _eval_input_receiver_fn(tf_transform_output, schema):
  """Build everything needed for the tf-model-analysis to run the model.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    EvalInputReceiver function, which contains:

      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = _get_raw_feature_spec(schema)

  # We don't need the ID feature for TFMA.
  raw_feature_spec.pop(ID_FEATURE_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)

  labels = transformed_features.pop(_transformed_name(LABEL_KEY))
  return tfma.export.EvalInputReceiver(
      features=transformed_features,
      receiver_tensors=serving_input_receiver.receiver_tensors,
      labels=labels)


def _augment_feature_spec(feature_spec, num_neighbors):
  """Augments `feature_spec` to include neighbor features.
    Args:
      feature_spec: Dictionary of feature keys mapping to TF feature types.
      num_neighbors: Number of neighbors to use for feature key augmentation.
    Returns:
      An augmented `feature_spec` that includes neighbor feature keys.
  """
  for i in range(num_neighbors):
    feature_spec['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'id')] = \
        tf.io.VarLenFeature(dtype=tf.string)
    # We don't care about the neighbor features corresponding to
    # _transformed_name(LABEL_KEY) because the LABEL_KEY feature will be
    # removed from the feature spec during training/evaluation.
    feature_spec['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'text_xf')] = \
        tf.io.FixedLenFeature(shape=[HPARAMS.max_seq_length], dtype=tf.int64,
                              default_value=tf.constant(0, dtype=tf.int64,
                                                        shape=[HPARAMS.max_seq_length]))
    # The 'NL_num_nbrs' feature is currently not used.

  # Set the neighbor weight feature keys.
  for i in range(num_neighbors):
    feature_spec['{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)] = \
        tf.io.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=[0.0])

  return feature_spec


def _input_fn(filenames, tf_transform_output, is_training, batch_size=200):
  """Generates features and labels for training or evaluation.

  Args:
    filenames: [str] list of TFRecord files to read data from.
    tf_transform_output: A TFTransformOutput.
    is_training: Boolean indicating if we are in training mode.
    batch_size: int First dimension size of the Tensors returned by input_fn

  Returns:
    A (features, indices) tuple where features is a dictionary of
      Tensors, and indices is a single Tensor of label indices.
  """
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())

  # During training, NSL uses augmented training data (which includes features
  # from graph neighbors). So, update the feature spec accordingly. This needs
  # to be done because we are using different schemas for NSL training and eval,
  # but the Trainer Component only accepts a single schema.
  if is_training:
    transformed_feature_spec = _augment_feature_spec(transformed_feature_spec,
                                                    HPARAMS.num_neighbors)

  dataset = tf.data.experimental.make_batched_features_dataset(
      filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)

  transformed_features = tf.compat.v1.data.make_one_shot_iterator(
      dataset).get_next()
  # We pop the label because we do not want to use it as a feature while we're
  # training.
  return transformed_features, transformed_features.pop(
      _transformed_name(LABEL_KEY))


# TFX will call this function
def trainer_fn(hparams, schema):
  """Build the estimator using the high level API.
  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.
  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(hparams.transform_output)

  train_input_fn = lambda: _input_fn(
      hparams.train_files,
      tf_transform_output,
      is_training=True,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(
      hparams.eval_files,
      tf_transform_output,
      is_training=False,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn,
      max_steps=hparams.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('imdb', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='imdb-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=hparams.serving_model_dir)

  estimator = tf.estimator.Estimator(
      model_fn=feed_forward_model_fn, config=run_config, params=HPARAMS)

  # Create a graph regularization config.
  graph_reg_config = nsl.configs.make_graph_reg_config(
      max_neighbors=HPARAMS.num_neighbors,
      multiplier=HPARAMS.graph_regularization_multiplier,
      distance_type=HPARAMS.distance_type,
      sum_over_axis=-1)

  # Invoke the Graph Regularization Estimator wrapper to incorporate
  # graph-based regularization for training.
  graph_nsl_estimator = nsl.estimator.add_graph_regularization(
      estimator,
      embedding_fn,
      optimizer_fn=optimizer_fn,
      graph_reg_config=graph_reg_config)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(
      tf_transform_output, schema)

  return {
      'estimator': graph_nsl_estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }
Writing imdb_trainer.py
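
Note how `_augment_feature_spec` in the module above derives the neighbor feature keys: each neighbor `i` contributes feature keys prefixed with `NL_nbr_{i}_`, plus an `NL_nbr_{i}_weight` key for the edge weight. A small standalone sketch of that key construction (a hypothetical helper mirroring `NBR_FEATURE_PREFIX` and `NBR_WEIGHT_SUFFIX` from the module):

```python
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

def neighbor_feature_keys(feature_name, num_neighbors):
  """Returns the neighbor feature and weight keys for `feature_name`."""
  keys = []
  for i in range(num_neighbors):
    keys.append('{}{}_{}'.format(NBR_FEATURE_PREFIX, i, feature_name))
    keys.append('{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX))
  return keys

print(neighbor_feature_keys('text_xf', 1))
# ['NL_nbr_0_text_xf', 'NL_nbr_0_weight']
```

These are exactly the keys that `nsl.tools.pack_nbrs` writes into the augmented examples, which is why the training-time feature spec must be augmented while the eval-time spec is not.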

Create and run the Trainer component, passing it the file that we created above.

# Uses user-provided Python function that implements a model using TensorFlow's
# Estimators API.
trainer = Trainer(
    module_file=_trainer_module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        trainer_executor.Executor),
    transformed_examples=graph_augmentation.outputs['augmented_examples'],
    schema=schema_gen.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)
WARNING:absl:`custom_executor_spec` is going to be deprecated.
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
ERROR:absl:udf_utils.get_fn {'train_args': '{\n  "num_steps": 10000\n}', 'eval_args': '{\n  "num_steps": 5000\n}', 'module_file': None, 'run_fn': None, 'trainer_fn': None, 'custom_config': 'null', 'module_path': 'imdb_trainer@/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/_wheels/tfx_user_code_Trainer-0.0+b990a2c6a4f23081880867efa3bd3c38db9d7bd0a87a0c9b277ae63714defc8d-py3-none-any.whl'} 'trainer_fn'
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 999, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 999, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 999 or save_checkpoints_secs None.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 999 or save_checkpoints_secs None.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/rmsprop.py:123: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6928506, step = 0
INFO:tensorflow:global_step/sec: 371.062
INFO:tensorflow:loss = 0.69304836, step = 100 (0.271 sec)
INFO:tensorflow:global_step/sec: 533.677
INFO:tensorflow:loss = 0.6928582, step = 200 (0.189 sec)
INFO:tensorflow:global_step/sec: 521.762
INFO:tensorflow:loss = 0.69261754, step = 300 (0.190 sec)
INFO:tensorflow:global_step/sec: 525.457
INFO:tensorflow:loss = 0.6919486, step = 400 (0.190 sec)
INFO:tensorflow:global_step/sec: 524.356
INFO:tensorflow:loss = 0.69098777, step = 500 (0.191 sec)
INFO:tensorflow:global_step/sec: 520.17
INFO:tensorflow:loss = 0.6909203, step = 600 (0.192 sec)
INFO:tensorflow:global_step/sec: 523.913
INFO:tensorflow:loss = 0.68905544, step = 700 (0.191 sec)
INFO:tensorflow:global_step/sec: 530.831
INFO:tensorflow:loss = 0.6896232, step = 800 (0.188 sec)
INFO:tensorflow:global_step/sec: 515.854
INFO:tensorflow:loss = 0.6868899, step = 900 (0.194 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 999...
INFO:tensorflow:Saving checkpoints for 999 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:970: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 999...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-05-25T09:21:26Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-999
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [500/5000]
INFO:tensorflow:Evaluation [1000/5000]
INFO:tensorflow:Evaluation [1500/5000]
INFO:tensorflow:Evaluation [2000/5000]
INFO:tensorflow:Evaluation [2500/5000]
INFO:tensorflow:Evaluation [3000/5000]
INFO:tensorflow:Evaluation [3500/5000]
INFO:tensorflow:Evaluation [4000/5000]
INFO:tensorflow:Evaluation [4500/5000]
INFO:tensorflow:Evaluation [5000/5000]
INFO:tensorflow:Inference Time : 3.93100s
INFO:tensorflow:Finished evaluation at 2021-05-25-09:21:30
INFO:tensorflow:Saving dict for global step 999: accuracy = 0.7047, global_step = 999, loss = 0.68683666
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 999: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-999
INFO:tensorflow:global_step/sec: 23.0109
INFO:tensorflow:loss = 0.68343806, step = 1000 (4.345 sec)
INFO:tensorflow:global_step/sec: 499.123
INFO:tensorflow:loss = 0.682852, step = 1100 (0.201 sec)
INFO:tensorflow:global_step/sec: 519.179
INFO:tensorflow:loss = 0.6818341, step = 1200 (0.193 sec)
INFO:tensorflow:global_step/sec: 530.08
INFO:tensorflow:loss = 0.6773318, step = 1300 (0.188 sec)
INFO:tensorflow:global_step/sec: 532.471
INFO:tensorflow:loss = 0.67901975, step = 1400 (0.188 sec)
INFO:tensorflow:global_step/sec: 526.228
INFO:tensorflow:loss = 0.67296153, step = 1500 (0.190 sec)
INFO:tensorflow:global_step/sec: 546.989
INFO:tensorflow:loss = 0.67858636, step = 1600 (0.183 sec)
INFO:tensorflow:global_step/sec: 519.426
INFO:tensorflow:loss = 0.67118084, step = 1700 (0.193 sec)
INFO:tensorflow:global_step/sec: 539.933
INFO:tensorflow:loss = 0.66567874, step = 1800 (0.185 sec)
INFO:tensorflow:global_step/sec: 534.835
INFO:tensorflow:loss = 0.6710329, step = 1900 (0.187 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1998...
INFO:tensorflow:Saving checkpoints for 1998 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1998...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 413.855
INFO:tensorflow:loss = 0.65162957, step = 2000 (0.241 sec)
INFO:tensorflow:global_step/sec: 496.204
INFO:tensorflow:loss = 0.65566516, step = 2100 (0.202 sec)
INFO:tensorflow:global_step/sec: 516.348
INFO:tensorflow:loss = 0.6617956, step = 2200 (0.194 sec)
INFO:tensorflow:global_step/sec: 540.354
INFO:tensorflow:loss = 0.66066074, step = 2300 (0.184 sec)
INFO:tensorflow:global_step/sec: 534.857
INFO:tensorflow:loss = 0.64779204, step = 2400 (0.187 sec)
INFO:tensorflow:global_step/sec: 536.3
INFO:tensorflow:loss = 0.6503115, step = 2500 (0.186 sec)
INFO:tensorflow:global_step/sec: 531.572
INFO:tensorflow:loss = 0.6608962, step = 2600 (0.188 sec)
INFO:tensorflow:global_step/sec: 534.767
INFO:tensorflow:loss = 0.6526867, step = 2700 (0.187 sec)
INFO:tensorflow:global_step/sec: 540.119
INFO:tensorflow:loss = 0.6412292, step = 2800 (0.185 sec)
INFO:tensorflow:global_step/sec: 525.283
INFO:tensorflow:loss = 0.60136825, step = 2900 (0.191 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2997...
INFO:tensorflow:Saving checkpoints for 2997 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2997...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 418.01
INFO:tensorflow:loss = 0.5930721, step = 3000 (0.239 sec)
INFO:tensorflow:global_step/sec: 531.614
INFO:tensorflow:loss = 0.64433926, step = 3100 (0.188 sec)
INFO:tensorflow:global_step/sec: 538.866
INFO:tensorflow:loss = 0.64364636, step = 3200 (0.186 sec)
INFO:tensorflow:global_step/sec: 524.68
INFO:tensorflow:loss = 0.60354906, step = 3300 (0.191 sec)
INFO:tensorflow:global_step/sec: 537.055
INFO:tensorflow:loss = 0.6163613, step = 3400 (0.186 sec)
INFO:tensorflow:global_step/sec: 507.556
INFO:tensorflow:loss = 0.571828, step = 3500 (0.197 sec)
INFO:tensorflow:global_step/sec: 529.114
INFO:tensorflow:loss = 0.6008686, step = 3600 (0.189 sec)
INFO:tensorflow:global_step/sec: 541.027
INFO:tensorflow:loss = 0.61156744, step = 3700 (0.185 sec)
INFO:tensorflow:global_step/sec: 535.604
INFO:tensorflow:loss = 0.63452137, step = 3800 (0.187 sec)
INFO:tensorflow:global_step/sec: 529.595
INFO:tensorflow:loss = 0.5998771, step = 3900 (0.189 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3996...
INFO:tensorflow:Saving checkpoints for 3996 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3996...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 411.953
INFO:tensorflow:loss = 0.6018996, step = 4000 (0.242 sec)
INFO:tensorflow:global_step/sec: 521.844
INFO:tensorflow:loss = 0.5382768, step = 4100 (0.192 sec)
INFO:tensorflow:global_step/sec: 539.501
INFO:tensorflow:loss = 0.6149534, step = 4200 (0.185 sec)
INFO:tensorflow:global_step/sec: 537.315
INFO:tensorflow:loss = 0.5436077, step = 4300 (0.186 sec)
INFO:tensorflow:global_step/sec: 536.308
INFO:tensorflow:loss = 0.5910145, step = 4400 (0.186 sec)
INFO:tensorflow:global_step/sec: 538.57
INFO:tensorflow:loss = 0.5471341, step = 4500 (0.186 sec)
INFO:tensorflow:global_step/sec: 537.228
INFO:tensorflow:loss = 0.5353176, step = 4600 (0.186 sec)
INFO:tensorflow:global_step/sec: 537.485
INFO:tensorflow:loss = 0.5134242, step = 4700 (0.186 sec)
INFO:tensorflow:global_step/sec: 522.625
INFO:tensorflow:loss = 0.54486126, step = 4800 (0.191 sec)
INFO:tensorflow:global_step/sec: 534.207
INFO:tensorflow:loss = 0.55618936, step = 4900 (0.188 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4995...
INFO:tensorflow:Saving checkpoints for 4995 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4995...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 413.154
INFO:tensorflow:loss = 0.53982574, step = 5000 (0.241 sec)
INFO:tensorflow:global_step/sec: 531.524
INFO:tensorflow:loss = 0.51829904, step = 5100 (0.189 sec)
INFO:tensorflow:global_step/sec: 533.169
INFO:tensorflow:loss = 0.5065013, step = 5200 (0.188 sec)
INFO:tensorflow:global_step/sec: 535.152
INFO:tensorflow:loss = 0.5239268, step = 5300 (0.187 sec)
INFO:tensorflow:global_step/sec: 542.196
INFO:tensorflow:loss = 0.6251954, step = 5400 (0.184 sec)
INFO:tensorflow:global_step/sec: 538.087
INFO:tensorflow:loss = 0.50621057, step = 5500 (0.186 sec)
INFO:tensorflow:global_step/sec: 529.196
INFO:tensorflow:loss = 0.44370502, step = 5600 (0.190 sec)
INFO:tensorflow:global_step/sec: 538.522
INFO:tensorflow:loss = 0.48304006, step = 5700 (0.184 sec)
INFO:tensorflow:global_step/sec: 530.869
INFO:tensorflow:loss = 0.56802607, step = 5800 (0.188 sec)
INFO:tensorflow:global_step/sec: 540.313
INFO:tensorflow:loss = 0.50465614, step = 5900 (0.185 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5994...
INFO:tensorflow:Saving checkpoints for 5994 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5994...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 414.475
INFO:tensorflow:loss = 0.5280973, step = 6000 (0.241 sec)
INFO:tensorflow:global_step/sec: 530.259
INFO:tensorflow:loss = 0.47920552, step = 6100 (0.189 sec)
INFO:tensorflow:global_step/sec: 540.715
INFO:tensorflow:loss = 0.61536604, step = 6200 (0.185 sec)
INFO:tensorflow:global_step/sec: 529.91
INFO:tensorflow:loss = 0.49844682, step = 6300 (0.189 sec)
INFO:tensorflow:global_step/sec: 540.25
INFO:tensorflow:loss = 0.48115006, step = 6400 (0.185 sec)
INFO:tensorflow:global_step/sec: 546.463
INFO:tensorflow:loss = 0.4472222, step = 6500 (0.183 sec)
INFO:tensorflow:global_step/sec: 534.192
INFO:tensorflow:loss = 0.38770705, step = 6600 (0.188 sec)
INFO:tensorflow:global_step/sec: 543.924
INFO:tensorflow:loss = 0.4829008, step = 6700 (0.184 sec)
INFO:tensorflow:global_step/sec: 537.401
INFO:tensorflow:loss = 0.45408913, step = 6800 (0.186 sec)
INFO:tensorflow:global_step/sec: 523.727
INFO:tensorflow:loss = 0.51433593, step = 6900 (0.190 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6993...
INFO:tensorflow:Saving checkpoints for 6993 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6993...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 414.601
INFO:tensorflow:loss = 0.42705825, step = 7000 (0.241 sec)
INFO:tensorflow:global_step/sec: 534.449
INFO:tensorflow:loss = 0.46791628, step = 7100 (0.187 sec)
INFO:tensorflow:global_step/sec: 538.796
INFO:tensorflow:loss = 0.42690745, step = 7200 (0.186 sec)
INFO:tensorflow:global_step/sec: 540.411
INFO:tensorflow:loss = 0.4788864, step = 7300 (0.185 sec)
INFO:tensorflow:global_step/sec: 522.263
INFO:tensorflow:loss = 0.46483028, step = 7400 (0.192 sec)
INFO:tensorflow:global_step/sec: 538.673
INFO:tensorflow:loss = 0.50668573, step = 7500 (0.185 sec)
INFO:tensorflow:global_step/sec: 541.609
INFO:tensorflow:loss = 0.5154858, step = 7600 (0.185 sec)
INFO:tensorflow:global_step/sec: 533.354
INFO:tensorflow:loss = 0.5880207, step = 7700 (0.187 sec)
INFO:tensorflow:global_step/sec: 518.6
INFO:tensorflow:loss = 0.39300022, step = 7800 (0.193 sec)
INFO:tensorflow:global_step/sec: 540.13
INFO:tensorflow:loss = 0.37214977, step = 7900 (0.185 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7992...
INFO:tensorflow:Saving checkpoints for 7992 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7992...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 417.028
INFO:tensorflow:loss = 0.42874864, step = 8000 (0.240 sec)
INFO:tensorflow:global_step/sec: 528.743
INFO:tensorflow:loss = 0.41675007, step = 8100 (0.189 sec)
INFO:tensorflow:global_step/sec: 537.473
INFO:tensorflow:loss = 0.44912827, step = 8200 (0.186 sec)
INFO:tensorflow:global_step/sec: 540.336
INFO:tensorflow:loss = 0.3802678, step = 8300 (0.185 sec)
INFO:tensorflow:global_step/sec: 528.773
INFO:tensorflow:loss = 0.3766276, step = 8400 (0.189 sec)
INFO:tensorflow:global_step/sec: 529.9
INFO:tensorflow:loss = 0.4270187, step = 8500 (0.189 sec)
INFO:tensorflow:global_step/sec: 549.284
INFO:tensorflow:loss = 0.42446175, step = 8600 (0.183 sec)
INFO:tensorflow:global_step/sec: 528.564
INFO:tensorflow:loss = 0.40740764, step = 8700 (0.189 sec)
INFO:tensorflow:global_step/sec: 538.921
INFO:tensorflow:loss = 0.39799905, step = 8800 (0.185 sec)
INFO:tensorflow:global_step/sec: 542.929
INFO:tensorflow:loss = 0.38929462, step = 8900 (0.184 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8991...
INFO:tensorflow:Saving checkpoints for 8991 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8991...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 413.639
INFO:tensorflow:loss = 0.4797022, step = 9000 (0.242 sec)
INFO:tensorflow:global_step/sec: 534.34
INFO:tensorflow:loss = 0.45945722, step = 9100 (0.187 sec)
INFO:tensorflow:global_step/sec: 545.106
INFO:tensorflow:loss = 0.43185732, step = 9200 (0.183 sec)
INFO:tensorflow:global_step/sec: 546.473
INFO:tensorflow:loss = 0.41618124, step = 9300 (0.183 sec)
INFO:tensorflow:global_step/sec: 497.628
INFO:tensorflow:loss = 0.46784154, step = 9400 (0.201 sec)
INFO:tensorflow:global_step/sec: 493.506
INFO:tensorflow:loss = 0.36964092, step = 9500 (0.202 sec)
INFO:tensorflow:global_step/sec: 534.347
INFO:tensorflow:loss = 0.3630741, step = 9600 (0.187 sec)
INFO:tensorflow:global_step/sec: 504.866
INFO:tensorflow:loss = 0.39720613, step = 9700 (0.199 sec)
INFO:tensorflow:global_step/sec: 513.867
INFO:tensorflow:loss = 0.39687905, step = 9800 (0.194 sec)
INFO:tensorflow:global_step/sec: 520.659
INFO:tensorflow:loss = 0.43222266, step = 9900 (0.192 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9990...
INFO:tensorflow:Saving checkpoints for 9990 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9990...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10000...
INFO:tensorflow:Saving checkpoints for 10000 into /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10000...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-05-25T09:21:48Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-10000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [500/5000]
INFO:tensorflow:Evaluation [1000/5000]
INFO:tensorflow:Evaluation [1500/5000]
INFO:tensorflow:Evaluation [2000/5000]
INFO:tensorflow:Evaluation [2500/5000]
INFO:tensorflow:Evaluation [3000/5000]
INFO:tensorflow:Evaluation [3500/5000]
INFO:tensorflow:Evaluation [4000/5000]
INFO:tensorflow:Evaluation [4500/5000]
INFO:tensorflow:Evaluation [5000/5000]
INFO:tensorflow:Inference Time : 3.87687s
INFO:tensorflow:Finished evaluation at 2021-05-25-09:21:52
INFO:tensorflow:Saving dict for global step 10000: accuracy = 0.8008, global_step = 10000, loss = 0.4371457
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10000: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-10000
INFO:tensorflow:Performing the final export in the end of training.
WARNING:tensorflow:Unused features are always dropped in the TF 2.x implementation. Ignoring value of drop_unused_features.
WARNING:tensorflow:Loading a TF2 SavedModel but eager mode seems disabled.
WARNING:tensorflow:Loading a TF2 SavedModel but eager mode seems disabled.
WARNING:tensorflow:11 out of the last 20010 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153e01a70> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 20010 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153e01a70> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 11 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153e01950> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 11 calls to <function recreate_function.<locals>.restored_function_body at 0x7f6153e01950> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Exception ignored in: <function CapturableResourceDeleter.__del__ at 0x7f60f3cace60>
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py", line 208, in __del__
    self._destroy_resource()
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 726, in _initialize
    *args, **kwds))
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 253, in restored_function_body
    return _call_concrete_function(function, inputs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 75, in _call_concrete_function
    result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 116, in _call_flat
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1932, in _call_flat
    flat_outputs = forward_function.call(ctx, args_with_tangents)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 589, in call
    executor_type=executor_type)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1206, in partitioned_call
    f.add_to_graph(graph)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 505, in add_to_graph
    g._add_function(self)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3396, in _add_function
    gradient)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 'func' argument to TF_GraphCopyFunction cannot be null
Exception ignored in: <function CapturableResourceDeleter.__del__ at 0x7f60f3cace60>
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py", line 208, in __del__
    self._destroy_resource()
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 726, in _initialize
    *args, **kwds))
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 253, in restored_function_body
    return _call_concrete_function(function, inputs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 75, in _call_concrete_function
    result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 116, in _call_flat
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1932, in _call_flat
    flat_outputs = forward_function.call(ctx, args_with_tangents)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 589, in call
    executor_type=executor_type)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1206, in partitioned_call
    f.add_to_graph(graph)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 505, in add_to_graph
    g._add_function(self)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3396, in _add_function
    gradient)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 'func' argument to TF_GraphCopyFunction cannot be null
WARNING:tensorflow:11 out of the last 11 calls to <function recreate_function.<locals>.restored_function_body at 0x7f614b726ef0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 11 calls to <function recreate_function.<locals>.restored_function_body at 0x7f614b726ef0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Exception ignored in: <function CapturableResourceDeleter.__del__ at 0x7f60f3cace60>
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py", line 208, in __del__
    self._destroy_resource()
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 726, in _initialize
    *args, **kwds))
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 253, in restored_function_body
    return _call_concrete_function(function, inputs)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 75, in _call_concrete_function
    result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 116, in _call_flat
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1932, in _call_flat
    flat_outputs = forward_function.call(ctx, args_with_tangents)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 589, in call
    executor_type=executor_type)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1206, in partitioned_call
    f.add_to_graph(graph)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 505, in add_to_graph
    g._add_function(self)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3396, in _add_function
    gradient)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 'func' argument to TF_GraphCopyFunction cannot be null
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-10000
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-10000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/export/imdb/temp-1621934512/assets
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/export/imdb/temp-1621934512/assets
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/export/imdb/temp-1621934512/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/export/imdb/temp-1621934512/saved_model.pb
INFO:tensorflow:Loss for final step: 0.5107416.
INFO:tensorflow:Loss for final step: 0.5107416.
WARNING:tensorflow:Unused features are always dropped in the TF 2.x implementation. Ignoring value of drop_unused_features.
WARNING:tensorflow:Unused features are always dropped in the TF 2.x implementation. Ignoring value of drop_unused_features.
WARNING:tensorflow:Loading a TF2 SavedModel but eager mode seems disabled.
WARNING:tensorflow:Loading a TF2 SavedModel but eager mode seems disabled.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']
INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']
WARNING:tensorflow:Export includes no default signature!
WARNING:tensorflow:Export includes no default signature!
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-10000
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-Serving/model.ckpt-10000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-TFMA/temp-1621934514/assets
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-TFMA/temp-1621934514/assets
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-TFMA/temp-1621934514/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2021-05-25T09_19_24.660931-lt26ehlo/Trainer/model_run/9/Format-TFMA/temp-1621934514/saved_model.pb
WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/serving_model_dir/saved_model.pb"
WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/eval_model_dir/saved_model.pb"

Let's take a peek at the trained model that was exported from the Trainer.

train_uri = trainer.outputs['model'].get()[0].uri
serving_model_path = os.path.join(train_uri, 'Format-Serving')
exported_model = tf.saved_model.load(serving_model_path)
exported_model.graph.get_operations()[:10] + ["..."]
[<tf.Operation 'global_step/Initializer/zeros' type=Const>,
 <tf.Operation 'global_step' type=VarHandleOp>,
 <tf.Operation 'global_step/IsInitialized/VarIsInitializedOp' type=VarIsInitializedOp>,
 <tf.Operation 'global_step/Assign' type=AssignVariableOp>,
 <tf.Operation 'global_step/Read/ReadVariableOp' type=ReadVariableOp>,
 <tf.Operation 'input_example_tensor' type=Placeholder>,
 <tf.Operation 'ParseExample/ParseExampleV2/names' type=Const>,
 <tf.Operation 'ParseExample/ParseExampleV2/sparse_keys' type=Const>,
 <tf.Operation 'ParseExample/ParseExampleV2/dense_keys' type=Const>,
 <tf.Operation 'ParseExample/ParseExampleV2/ragged_keys' type=Const>,
 '...']

Let's visualize the model's metrics using TensorBoard.


# Get the URI of the output artifact representing the training logs,
# which is a directory
model_run_dir = trainer.outputs['model_run'].get()[0].uri

%load_ext tensorboard
%tensorboard --logdir {model_run_dir}

Model serving

Graph regularization affects only the training workflow, by adding a regularization term to the loss function. As a result, the model evaluation and serving workflows remain unchanged. For the same reason, we have also omitted the downstream TFX components that typically follow the Trainer component, such as the Evaluator and the Pusher.
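To make this concrete: graph regularization only adds a penalty term to the training objective, so at serving time the model computes exactly the same predictions as the base model. Below is a minimal numeric sketch of this loss composition. The multiplier, logits, and neighbor values are made-up illustrations; in NSL the neighbor term is computed from the graph-augmented training examples, not hand-supplied like this.

```python
def graph_reg_loss(logits, neighbor_logits, multiplier=0.1):
    """Illustrative graph regularization term: a scaled mean squared
    distance between a sample's logits and its neighbors' logits."""
    diffs = [(a - b) ** 2 for a, b in zip(logits, neighbor_logits)]
    return multiplier * sum(diffs) / len(diffs)

supervised_loss = 0.5                      # e.g. cross-entropy on the labels
reg_term = graph_reg_loss([0.8], [0.6])    # 0.1 * (0.8 - 0.6)**2, about 0.004
total_training_loss = supervised_loss + reg_term

# Only the training loss changes; inference still uses the base model alone.
print(total_training_loss)
```

Because the extra term vanishes from the computation at inference time, the exported SavedModel can be evaluated and served exactly like any other model.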

Conclusion

We have demonstrated the use of graph regularization with the Neural Structured Learning (NSL) framework in a TFX pipeline, even when the input does not contain an explicit graph. We considered the task of sentiment classification of IMDB movie reviews, for which we synthesized a similarity graph based on embeddings of the reviews. We encourage users to experiment further by using different embeddings for graph construction, varying hyperparameters, changing the amount of supervision, and defining different model architectures.
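As a recap of the graph-construction step, here is a minimal, self-contained sketch of building a similarity graph by thresholded cosine similarity. The 2-D embeddings and the 0.8 threshold are made up for illustration; the pipeline itself works with high-dimensional pre-trained embeddings and NSL's graph-building utilities.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def build_graph(embeddings, threshold=0.8):
    """Add an edge between every pair of samples whose embeddings
    are at least `threshold`-similar; nodes are sample indices."""
    edges = []
    for i in range(len(embeddings)):
        for j in range(i + 1, len(embeddings)):
            sim = cosine_similarity(embeddings[i], embeddings[j])
            if sim >= threshold:
                edges.append((i, j, sim))
    return edges

# Toy embeddings: samples 0 and 1 point in nearly the same direction,
# sample 2 is orthogonal to them, so only one edge is created.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
print(build_graph(embeddings))
```

The resulting edge list is what the graph-augmentation step consumes to pair each training example with its neighbors' features.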