Preprocessing data with TensorFlow Transform

The feature engineering component of TensorFlow Extended (TFX)

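This example Colab notebook provides a somewhat more advanced example of how TensorFlow Transform (tf.Transform) can be used to preprocess data using exactly the same code for both training a model and serving inferences in production.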

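TensorFlow Transform is a library for preprocessing input data for TensorFlow, including creating features that require a full pass over the training dataset. For example, using TensorFlow Transform you could:

  • Normalize an input value by using the mean and standard deviation
  • Convert strings to integers by generating a vocabulary over all of the input values
  • Convert floats to integers by assigning them to buckets, based on the observed data distribution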
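
To make the three operations above concrete, here is a minimal pure-Python sketch of the full-pass computations behind them. It does not use the tf.Transform API; the helper names (`z_score_stats`, `build_vocabulary`, `bucket_boundaries`) and the tiny sample values are assumptions made for illustration only.

```python
import bisect
import statistics
from collections import Counter


def z_score_stats(values):
    """Full pass: mean and population standard deviation of a column."""
    return statistics.fmean(values), statistics.pstdev(values)


def build_vocabulary(tokens):
    """Full pass: rank tokens by descending frequency (ties alphabetical)."""
    counts = Counter(tokens)
    ordered = sorted(counts, key=lambda t: (-counts[t], t))
    return {token: index for index, token in enumerate(ordered)}


def bucket_boundaries(values, num_buckets):
    """Full pass: quantile cut points splitting values into num_buckets."""
    return statistics.quantiles(values, n=num_buckets)


# Hypothetical sample data, not taken from the census dataset.
ages = [20.0, 30.0, 40.0, 50.0, 60.0]

mean, std = z_score_stats(ages)
normalized = [(a - mean) / std for a in ages]

vocab = build_vocabulary(['Private', 'State-gov', 'Private', 'Private'])
ids = [vocab[t] for t in ['State-gov', 'Private']]

cuts = bucket_boundaries(ages, num_buckets=2)
buckets = [bisect.bisect_right(cuts, a) for a in ages]
```

Each helper needs to see every training value before any single value can be transformed, which is what sets these features apart from ordinary per-example operations.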

TensorFlow has built-in support for manipulations on a single example or a batch of examples. tf.Transform extends these capabilities to support full passes over the entire training dataset.

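The output of tf.Transform is exported as a TensorFlow graph which you can use for both training and serving. Using the same graph for both training and serving can prevent skew, since the same transformations are applied in both stages.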
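
The skew mentioned above can be illustrated without TensorFlow. The sketch below is an assumption-laden toy: `analyze` and `transform` are hypothetical helpers, and the numbers are made up. It compares reusing constants computed on the training data with recomputing them from a serving batch.

```python
def analyze(train_values):
    """Runs once over the training data and returns frozen constants."""
    return {'min': min(train_values), 'max': max(train_values)}


def transform(value, constants):
    """Per-example scaling to [0, 1] using the given constants."""
    span = constants['max'] - constants['min']
    return (value - constants['min']) / span


train_hours = [10.0, 40.0, 60.0]
constants = analyze(train_hours)

# Training and serving share the same function and the same constants.
train_scaled = [transform(v, constants) for v in train_hours]
served = transform(40.0, constants)

# Skewed alternative: statistics recomputed from whatever arrives at serving.
serving_batch = [40.0, 50.0]
skewed = transform(40.0, analyze(serving_batch))
```

Here `served` matches the value the model saw for 40.0 during training, while `skewed` does not, even though the raw input is identical.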

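What we're doing in this example

In this example we'll be processing a widely used dataset containing census data, and training a model to do classification. Along the way we'll be transforming the data using tf.Transform.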

Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we're running in Colab. Local systems can of course be upgraded separately.

try:
  import colab
  !pip install --upgrade pip
except:
  pass

Install TensorFlow Transform

pip install tensorflow-transform

Python check, imports, and globals

First we'll make sure that we're using Python 3, and then go ahead and install and import the stuff we need.

import sys

# Confirm that we're using Python 3
assert sys.version_info.major == 3, 'Oops, not running Python 3. Use Runtime > Change runtime type'
import math
import os
import pprint
import tempfile

import tensorflow as tf
print('TF: {}'.format(tf.__version__))

import apache_beam as beam
print('Beam: {}'.format(beam.__version__))

import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
print('Transform: {}'.format(tft.__version__))

from tfx_bsl.public import tfxio
from tfx_bsl.coders.example_coder import RecordBatchToExamples

!wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data
!wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.test

train = './adult.data'
test = './adult.test'
TF: 2.4.4
Beam: 2.34.0
Transform: 0.29.0

Name our columns

We'll create some handy lists for referencing the columns in our dataset.

CATEGORICAL_FEATURE_KEYS = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
]
NUMERIC_FEATURE_KEYS = [
    'age',
    'capital-gain',
    'capital-loss',
    'hours-per-week',
]
OPTIONAL_NUMERIC_FEATURE_KEYS = [
    'education-num',
]
ORDERED_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num',
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'
]
LABEL_KEY = 'label'
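
As a quick sanity check on these lists, the sketch below compares the feature keys against the CSV column order. It repeats the lists so that it runs on its own; the `used`, `unused`, and `missing` names are introduced here for illustration.

```python
CATEGORICAL_FEATURE_KEYS = [
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'sex', 'native-country',
]
NUMERIC_FEATURE_KEYS = ['age', 'capital-gain', 'capital-loss', 'hours-per-week']
OPTIONAL_NUMERIC_FEATURE_KEYS = ['education-num']
ORDERED_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num',
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'
]
LABEL_KEY = 'label'

used = set(CATEGORICAL_FEATURE_KEYS + NUMERIC_FEATURE_KEYS +
           OPTIONAL_NUMERIC_FEATURE_KEYS + [LABEL_KEY])

# Columns present in the CSV but not used as a feature or label.
unused = [c for c in ORDERED_CSV_COLUMNS if c not in used]
# Keys we reference that do not exist in the CSV (should be empty).
missing = sorted(used - set(ORDERED_CSV_COLUMNS))
```

The only CSV column that no list refers to is `fnlwgt`; it has to appear in `ORDERED_CSV_COLUMNS` so the file can be parsed positionally, but it is not used as a feature.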

Define our features and schema

Let's define a schema based on what types the columns are in our input. Among other things this will help with importing them correctly.

RAW_DATA_FEATURE_SPEC = dict(
    [(name, tf.io.FixedLenFeature([], tf.string))
     for name in CATEGORICAL_FEATURE_KEYS] +
    [(name, tf.io.FixedLenFeature([], tf.float32))
     for name in NUMERIC_FEATURE_KEYS] +
    [(name, tf.io.VarLenFeature(tf.float32))
     for name in OPTIONAL_NUMERIC_FEATURE_KEYS] +
    [(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))]
)

SCHEMA = tft.tf_metadata.dataset_metadata.DatasetMetadata(
    tft.tf_metadata.schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC)).schema
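
The optional column `education-num` is declared as a `VarLenFeature`, so each row carries zero or one value and the batch arrives in sparse form. The following pure-Python sketch shows, under simplifying assumptions, what turning such a sparse column into a dense one with a default value looks like. `sparse_to_dense` is a hypothetical helper and the values are made up; it mirrors the idea, not the TensorFlow implementation.

```python
def sparse_to_dense(indices, values, num_rows, default_value=0.0):
    """Fill a dense column of length num_rows from (row, col) -> value pairs."""
    dense = [default_value] * num_rows
    for (row, _col), value in zip(indices, values):
        dense[row] = value
    return dense


# Three rows; the middle row has no 'education-num' value.
indices = [(0, 0), (2, 0)]
values = [13.0, 9.0]
dense = sparse_to_dense(indices, values, num_rows=3)
```

Rows without a value receive the default, which is the same role `default_value=0.` plays in the preprocessing function of this example.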

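Setting hyperparameters and basic housekeeping

Constants and hyperparameters used for training. The bucket size includes all listed categories in the dataset description as well as one extra for "?" which represents unknown.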

testing = os.getenv("WEB_TEST_BROWSER", False)
NUM_OOV_BUCKETS = 1
if testing:
  TRAIN_NUM_EPOCHS = 1
  NUM_TRAIN_INSTANCES = 1
  TRAIN_BATCH_SIZE = 1
  NUM_TEST_INSTANCES = 1
else:
  TRAIN_NUM_EPOCHS = 16
  NUM_TRAIN_INSTANCES = 32561
  TRAIN_BATCH_SIZE = 128
  NUM_TEST_INSTANCES = 16281

# Names of temp files
TRANSFORMED_TRAIN_DATA_FILEBASE = 'train_transformed'
TRANSFORMED_TEST_DATA_FILEBASE = 'test_transformed'
EXPORTED_MODEL_DIR = 'exported_model_dir'
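
These constants determine how many batches make up one pass over the data. As a worked example, the sketch below repeats the non-testing values and computes the step counts with the same `math.ceil(instances / batch_size)` expression that the training code in this example passes to `model.fit`. The variable names on the left are introduced here for illustration.

```python
import math

NUM_TRAIN_INSTANCES = 32561
NUM_TEST_INSTANCES = 16281
TRAIN_BATCH_SIZE = 128
TRAIN_NUM_EPOCHS = 16

# The last batch of an epoch is partial, so we round up.
steps_per_epoch = math.ceil(NUM_TRAIN_INSTANCES / TRAIN_BATCH_SIZE)
validation_steps = math.ceil(NUM_TEST_INSTANCES / TRAIN_BATCH_SIZE)
total_training_steps = steps_per_epoch * TRAIN_NUM_EPOCHS

print(steps_per_epoch, validation_steps, total_training_steps)
# → 255 128 4080
```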

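Preprocessing with tf.Transform

Create a tf.Transform preprocessing_fn

The preprocessing function is the most important concept of tf.Transform. The preprocessing function is where the transformation of the dataset really happens. It accepts and returns a dictionary of tensors, where a tensor means a Tensor or SparseTensor. There are two main groups of API calls that typically form the heart of a preprocessing function:

  1. TensorFlow Ops: Any function that accepts and returns tensors, which usually means TensorFlow ops. These add TensorFlow operations to the graph that transforms raw data into transformed data one feature vector at a time. These will run for every example, during both training and serving.
  2. TensorFlow Transform Analyzers: Any of the analyzers provided by tf.Transform. Analyzers also accept and return tensors, but unlike TensorFlow ops they only run once, during training, and typically make a full pass over the entire training dataset. They create tensor constants, which are added to your graph. For example, tft.min computes the minimum of a tensor over the training dataset. tf.Transform provides a fixed set of analyzers, but this will be extended in future versions.
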
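
The difference between the two groups can be sketched in plain Python. In the toy below, which is an illustration and not the tf.Transform API, `analyzer_min` plays the role of an analyzer such as `tft.min`: it runs once over the training values and its result is frozen into the returned function. The inner `transform` plays the role of an ordinary op that runs for every example. All names are hypothetical.

```python
def build_transform(train_values):
    """Analysis phase: one full pass, result frozen into the returned function."""
    stats = {'passes': 0}

    def analyzer_min(values):
        stats['passes'] += 1
        return min(values)

    # Computed once, then treated as a constant (like a constant in the graph).
    training_min = analyzer_min(train_values)

    def transform(value):
        # Instance-level op: plain arithmetic on one example at a time.
        return value - training_min

    return transform, stats


transform, stats = build_transform([7.0, 3.0, 9.0])
shifted = [transform(v) for v in [3.0, 5.0, 12.0]]
```

However many examples pass through `transform`, the analyzer has run only once, and new examples are shifted by the training minimum rather than by their own.
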
def preprocessing_fn(inputs):
  """Preprocess input columns into transformed columns."""
  # Since we are modifying some features and leaving others unchanged, we
  # start by setting `outputs` to a copy of `inputs`.
  outputs = inputs.copy()

  # Scale numeric columns to have range [0, 1].
  for key in NUMERIC_FEATURE_KEYS:
    outputs[key] = tft.scale_to_0_1(inputs[key])

  for key in OPTIONAL_NUMERIC_FEATURE_KEYS:
    # This is a SparseTensor because it is optional. Here we fill in a default
    # value when it is missing.
    sparse = tf.sparse.SparseTensor(inputs[key].indices, inputs[key].values,
                                    [inputs[key].dense_shape[0], 1])
    dense = tf.sparse.to_dense(sp_input=sparse, default_value=0.)
    # Reshaping from a batch of vectors of size 1 to a batch of scalars.
    dense = tf.squeeze(dense, axis=1)
    outputs[key] = tft.scale_to_0_1(dense)

  # For all categorical columns except the label column, we generate a
  # vocabulary and use it to map each string to an integer id. Values not
  # in the vocabulary fall into one of NUM_OOV_BUCKETS out-of-vocabulary
  # buckets.
  for key in CATEGORICAL_FEATURE_KEYS:
    outputs[key] = tft.compute_and_apply_vocabulary(
        tf.strings.strip(inputs[key]),
        num_oov_buckets=NUM_OOV_BUCKETS,
        vocab_filename=key)

  # For the label column we provide the mapping from string to index.
  table_keys = ['>50K', '<=50K']
  with tf.init_scope():
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=table_keys,
        values=tf.cast(tf.range(len(table_keys)), tf.int64),
        key_dtype=tf.string,
        value_dtype=tf.int64)
    table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  # Remove trailing periods for test data when the data is read with tf.data.
  label_str = tf.strings.regex_replace(inputs[LABEL_KEY], r'\.', '')
  label_str = tf.strings.strip(label_str)
  data_labels = table.lookup(label_str)
  transformed_label = tf.one_hot(
      indices=data_labels, depth=len(table_keys), on_value=1.0, off_value=0.0)
  outputs[LABEL_KEY] = tf.reshape(transformed_label, [-1, len(table_keys)])

  return outputs
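
The label handling at the end of `preprocessing_fn` can be hard to read in graph form, so here is a pure-Python sketch of the same idea: clean the string, look it up in a fixed table with a default of -1, then one-hot encode. `encode_label` is a hypothetical helper written for illustration; it mimics the behaviour for a single value rather than a batch.

```python
TABLE_KEYS = ['>50K', '<=50K']


def encode_label(raw_label):
    """Map a raw label string to a one-hot list of floats."""
    # Drop periods (present in the test file) and surrounding whitespace.
    cleaned = raw_label.replace('.', '').strip()
    # Unknown labels get index -1, like default_value=-1 in the lookup table.
    index = TABLE_KEYS.index(cleaned) if cleaned in TABLE_KEYS else -1
    # An index of -1 matches no position, so the result is all zeros.
    return [1.0 if i == index else 0.0 for i in range(len(TABLE_KEYS))]


high = encode_label(' >50K')
low_from_test_file = encode_label(' <=50K.')
unknown = encode_label('?')
```

An unrecognised label therefore yields an all-zero vector instead of raising an error.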

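Transform the data

Now we're ready to start transforming our data in an Apache Beam pipeline.

  1. Read in the data using the CSV reader
  2. Transform it using a preprocessing pipeline that scales numeric data and converts categorical data from strings to int64 value indices, by creating a vocabulary for each category
  3. Write out the result as a TFRecord of Example protos, which we will use for training a model later
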
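
Before the records reach the CSV parser, the pipeline in this example patches each line: it removes the space after every comma and, for the test file, drops the trailing period. The sketch below applies the same two byte-level fixes outside of Beam. `clean_line` and `parse` are hypothetical helpers, and the two sample rows are illustrative rows written in the format of the census files.

```python
import csv

ORDERED_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num',
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'
]


def clean_line(line, is_test=False):
    """Apply the same byte-level fixes as the Beam lambdas in this example."""
    line = line.replace(b', ', b',')
    if is_test:
        line = line[:-1]  # drop the trailing period
    return line


def parse(line):
    """Split a cleaned line and pair each field with its column name."""
    fields = next(csv.reader([line.decode('utf-8')]))
    return dict(zip(ORDERED_CSV_COLUMNS, fields))


train_line = (b'39, State-gov, 77516, Bachelors, 13, Never-married, '
              b'Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, '
              b'United-States, <=50K')
test_line = (b'25, Private, 226802, 11th, 7, Never-married, '
             b'Machine-op-inspct, Own-child, Black, Male, 0, 0, 40, '
             b'United-States, <=50K.')

train_row = parse(clean_line(train_line))
test_row = parse(clean_line(test_line, is_test=True))
```

After cleaning, both rows produce the same label string, so training and test data can share one vocabulary and one label table.
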
def transform_data(train_data_file, test_data_file, working_dir):
  """Transform the data and write out as a TFRecord of Example protos.

  Read in the data using the CSV reader, and transform it using a
  preprocessing pipeline that scales numeric data and converts categorical data
  from strings to int64 values indices, by creating a vocabulary for each
  category.

  Args:
    train_data_file: File containing training data
    test_data_file: File containing test data
    working_dir: Directory to write transformed data and metadata to
  """

  # The "with" block will create a pipeline, and run that pipeline at the exit
  # of the block.
  with beam.Pipeline() as pipeline:
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
      # Create a TFXIO to read the census data with the schema. To do this we
      # need to list all columns in order since the schema doesn't specify the
      # order of columns in the csv.
      # We first read CSV files and use BeamRecordCsvTFXIO whose .BeamSource()
      # accepts a PCollection[bytes] because we need to patch the records first
      # (see "FixCommasTrainData" below). Otherwise, tfxio.CsvTFXIO can be used
      # to both read the CSV files and parse them to TFT inputs:
      # csv_tfxio = tfxio.CsvTFXIO(...)
      # raw_data = (pipeline | 'ToRecordBatches' >> csv_tfxio.BeamSource())
      csv_tfxio = tfxio.BeamRecordCsvTFXIO(
          physical_format='text',
          column_names=ORDERED_CSV_COLUMNS,
          schema=SCHEMA)

      # Read in raw data and convert using CSV TFXIO.  Note that we apply
      # some Beam transformations here, which will not be encoded in the TF
      # graph since we don't do them from within tf.Transform's methods
      # (AnalyzeDataset, TransformDataset etc.).  These transformations are just
      # to get data into a format that the CSV TFXIO can read, in particular
      # removing spaces after commas.
      raw_data = (
          pipeline
          | 'ReadTrainData' >> beam.io.ReadFromText(
              train_data_file, coder=beam.coders.BytesCoder())
          | 'FixCommasTrainData' >> beam.Map(
              lambda line: line.replace(b', ', b','))
          | 'DecodeTrainData' >> csv_tfxio.BeamSource())

      # Combine data and schema into a dataset tuple.  Note that we already used
      # the schema to read the CSV data, but we also need it to interpret
      # raw_data.
      raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig())

      # The TFXIO output format is chosen for improved performance.
      transformed_dataset, transform_fn = (
          raw_dataset | tft_beam.AnalyzeAndTransformDataset(
              preprocessing_fn, output_record_batches=True))

      # Transformed metadata is not necessary for encoding.
      transformed_data, _ = transformed_dataset

      # Extract transformed RecordBatches, encode and write them to the given
      # directory.
      _ = (
          transformed_data
          | 'EncodeTrainData' >>
          beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))
          | 'WriteTrainData' >> beam.io.WriteToTFRecord(
              os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)))

      # Now apply transform function to test data.  In this case we remove the
      # trailing period at the end of each line, and also ignore the header line
      # that is present in the test data file.
      raw_test_data = (
          pipeline
          | 'ReadTestData' >> beam.io.ReadFromText(
              test_data_file, skip_header_lines=1,
              coder=beam.coders.BytesCoder())
          | 'FixCommasTestData' >> beam.Map(
              lambda line: line.replace(b', ', b','))
          | 'RemoveTrailingPeriodsTestData' >> beam.Map(lambda line: line[:-1])
          | 'DecodeTestData' >> csv_tfxio.BeamSource())

      raw_test_dataset = (raw_test_data, csv_tfxio.TensorAdapterConfig())

      # The TFXIO output format is chosen for improved performance.
      transformed_test_dataset = (
          (raw_test_dataset, transform_fn)
          | tft_beam.TransformDataset(output_record_batches=True))

      # Transformed metadata is not necessary for encoding.
      transformed_test_data, _ = transformed_test_dataset

      # Extract transformed RecordBatches, encode and write them to the given
      # directory.
      _ = (
          transformed_test_data
          | 'EncodeTestData' >>
          beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))
          | 'WriteTestData' >> beam.io.WriteToTFRecord(
              os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE)))

      # Will write a SavedModel and metadata to working_dir, which can then
      # be read by the tft.TFTransformOutput class.
      _ = (
          transform_fn
          | 'WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir))

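Using our preprocessed data to train a model using tf.keras

To show how tf.Transform enables us to use the same code for both training and serving, and thus prevent skew, we're going to train a model. To train our model and prepare our trained model for production we need to create input functions. The main difference between our training input function and our serving input function is that training data contains the labels, and production data does not. The arguments and return values are also somewhat different.

Create an input function for training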

def _make_training_input_fn(tf_transform_output, transformed_examples,
                            batch_size):
  """An input function reading from transformed data, converting to model input.

  Args:
    tf_transform_output: Wrapper around output of tf.Transform.
    transformed_examples: Base filename of examples.
    batch_size: Batch size.

  Returns:
    A zero-argument function that returns a batched `tf.data.Dataset` of
    (features, label) pairs for training or eval.
  """
  def input_fn():
    return tf.data.experimental.make_batched_features_dataset(
        file_pattern=transformed_examples,
        batch_size=batch_size,
        features=tf_transform_output.transformed_feature_spec(),
        reader=tf.data.TFRecordDataset,
        label_key=LABEL_KEY,
        shuffle=True).prefetch(tf.data.experimental.AUTOTUNE)

  return input_fn

Create an input function for serving

Let's create an input function that we could use in production, and prepare our trained model for serving.

def _make_serving_input_fn(tf_transform_output, raw_examples, batch_size):
  """An input function reading from raw data, converting to model input.

  Args:
    tf_transform_output: Wrapper around output of tf.Transform.
    raw_examples: Base filename of examples.
    batch_size: Batch size.

  Returns:
    A zero-argument function that returns a `tf.data.Dataset` of
    (transformed features, label) pairs built from the raw CSV data.
  """

  def get_ordered_raw_data_dtypes():
    result = []
    for col in ORDERED_CSV_COLUMNS:
      if col not in RAW_DATA_FEATURE_SPEC:
        result.append(0.0)
        continue
      spec = RAW_DATA_FEATURE_SPEC[col]
      if isinstance(spec, tf.io.FixedLenFeature):
        result.append(spec.dtype)
      else:
        result.append(0.0)
    return result

  def input_fn():
    dataset = tf.data.experimental.make_csv_dataset(
        file_pattern=raw_examples,
        batch_size=batch_size,
        column_names=ORDERED_CSV_COLUMNS,
        column_defaults=get_ordered_raw_data_dtypes(),
        prefetch_buffer_size=0,
        ignore_errors=True)

    tft_layer = tf_transform_output.transform_features_layer()

    def transform_dataset(data):
      raw_features = {}
      for key, val in data.items():
        if key not in RAW_DATA_FEATURE_SPEC:
          continue
        if isinstance(RAW_DATA_FEATURE_SPEC[key], tf.io.VarLenFeature):
          raw_features[key] = tf.RaggedTensor.from_tensor(
              tf.expand_dims(val, -1)).to_sparse()
          continue
        raw_features[key] = val
      transformed_features = tft_layer(raw_features)
      data_labels = transformed_features.pop(LABEL_KEY)
      return (transformed_features, data_labels)

    return dataset.map(
        transform_dataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(
            tf.data.experimental.AUTOTUNE)

  return input_fn
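
The helper `get_ordered_raw_data_dtypes` builds one default per CSV column, in file order. The sketch below reproduces its branching with plain Python stand-ins so the resulting list is easy to inspect: `SPEC` replaces `RAW_DATA_FEATURE_SPEC` with `(kind, dtype)` string pairs, which is an assumption made to avoid importing TensorFlow. As far as we understand the tf.data CSV reader, a dtype entry marks a column as required, while a concrete value such as `0.0` makes it optional with that default.

```python
ORDERED_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num',
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'
]
FIXED_STRING = ['workclass', 'education', 'marital-status', 'occupation',
                'relationship', 'race', 'sex', 'native-country', 'label']
FIXED_FLOAT = ['age', 'capital-gain', 'capital-loss', 'hours-per-week']
VAR_LEN_FLOAT = ['education-num']

# Stand-in for RAW_DATA_FEATURE_SPEC: (kind, dtype) instead of tf.io objects.
SPEC = {name: ('fixed', 'string') for name in FIXED_STRING}
SPEC.update({name: ('fixed', 'float32') for name in FIXED_FLOAT})
SPEC.update({name: ('varlen', 'float32') for name in VAR_LEN_FLOAT})


def ordered_column_defaults():
    """Return a dtype for fixed-length columns and 0.0 for everything else."""
    defaults = []
    for column in ORDERED_CSV_COLUMNS:
        kind, dtype = SPEC.get(column, ('absent', None))
        defaults.append(dtype if kind == 'fixed' else 0.0)
    return defaults


defaults = ordered_column_defaults()
```

Both the unused `fnlwgt` column and the optional `education-num` column end up with the float default `0.0`.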

Train, evaluate, and export our model

def export_serving_model(tf_transform_output, model, output_dir):
  """Exports a keras model for serving.

  Args:
    tf_transform_output: Wrapper around output of tf.Transform.
    model: A keras model to export for serving.
    output_dir: A directory where the model will be exported to.
  """
  # The layer has to be saved to the model for keras tracking purposes.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    """Serving tf.function model wrapper."""
    feature_spec = RAW_DATA_FEATURE_SPEC.copy()
    feature_spec.pop(LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    transformed_features = model.tft_layer(parsed_features)
    outputs = model(transformed_features)
    classes_names = tf.constant([['0', '1']])
    classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1])
    return {'classes': classes, 'scores': outputs}

  concrete_serving_fn = serve_tf_examples_fn.get_concrete_function(
      tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs'))
  signatures = {'serving_default': concrete_serving_fn}

  # This is required in order to make this model servable with model_server.
  versioned_output_dir = os.path.join(output_dir, '1')
  model.save(versioned_output_dir, save_format='tf', signatures=signatures)


def train_and_evaluate(working_dir,
                       num_train_instances=NUM_TRAIN_INSTANCES,
                       num_test_instances=NUM_TEST_INSTANCES):
  """Train the model on training data and evaluate on test data.

  Args:
    working_dir: The location of the Transform output.
    num_train_instances: Number of instances in train set
    num_test_instances: Number of instances in test set

  Returns:
    A dict mapping metric names to the values returned by the Keras model's
    `evaluate` method.
  """
  train_data_path_pattern = os.path.join(working_dir,
                                 TRANSFORMED_TRAIN_DATA_FILEBASE + '*')
  eval_data_path_pattern = os.path.join(working_dir,
                            TRANSFORMED_TEST_DATA_FILEBASE + '*')
  tf_transform_output = tft.TFTransformOutput(working_dir)

  train_input_fn = _make_training_input_fn(
      tf_transform_output, train_data_path_pattern, batch_size=TRAIN_BATCH_SIZE)
  train_dataset = train_input_fn()

  # Evaluate model on test dataset.
  eval_input_fn = _make_training_input_fn(
      tf_transform_output, eval_data_path_pattern, batch_size=TRAIN_BATCH_SIZE)
  validation_dataset = eval_input_fn()

  feature_spec = tf_transform_output.transformed_feature_spec().copy()
  feature_spec.pop(LABEL_KEY)

  inputs = {}
  for key, spec in feature_spec.items():
    if isinstance(spec, tf.io.VarLenFeature):
      inputs[key] = tf.keras.layers.Input(
          shape=[None], name=key, dtype=spec.dtype, sparse=True)
    elif isinstance(spec, tf.io.FixedLenFeature):
      inputs[key] = tf.keras.layers.Input(
          shape=spec.shape, name=key, dtype=spec.dtype)
    else:
      raise ValueError('Spec type is not supported: ', key, spec)

  encoded_inputs = {}
  for key in inputs:
    feature = tf.expand_dims(inputs[key], -1)
    if key in CATEGORICAL_FEATURE_KEYS:
      num_buckets = tf_transform_output.num_buckets_for_transformed_feature(key)
      encoding_layer = (
          tf.keras.layers.experimental.preprocessing.CategoryEncoding(
              max_tokens=num_buckets, output_mode='binary', sparse=False))
      encoded_inputs[key] = encoding_layer(feature)
    else:
      encoded_inputs[key] = feature

  stacked_inputs = tf.concat(tf.nest.flatten(encoded_inputs), axis=1)
  output = tf.keras.layers.Dense(100, activation='relu')(stacked_inputs)
  output = tf.keras.layers.Dense(70, activation='relu')(output)
  output = tf.keras.layers.Dense(50, activation='relu')(output)
  output = tf.keras.layers.Dense(20, activation='relu')(output)
  output = tf.keras.layers.Dense(2, activation='sigmoid')(output)
  model = tf.keras.Model(inputs=inputs, outputs=output)

  model.compile(optimizer='adam',
                loss='binary_crossentropy',
                metrics=['accuracy'])
  # model.summary() prints the summary itself and returns None.
  model.summary()

  model.fit(train_dataset, validation_data=validation_dataset,
            epochs=TRAIN_NUM_EPOCHS,
            steps_per_epoch=math.ceil(num_train_instances / TRAIN_BATCH_SIZE),
            validation_steps=math.ceil(num_test_instances / TRAIN_BATCH_SIZE))

  # Export the model.
  exported_model_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR)
  export_serving_model(tf_transform_output, model, exported_model_dir)

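  # `steps` counts batches, not instances, so divide by the batch size.
  metrics_values = model.evaluate(
      validation_dataset,
      steps=math.ceil(num_test_instances / TRAIN_BATCH_SIZE))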
  metrics_labels = model.metrics_names
  return {l: v for l, v in zip(metrics_labels, metrics_values)}
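
Inside `train_and_evaluate`, every categorical feature is expanded to a vector whose width is the number of buckets reported for that feature. The pure-Python sketch below illustrates where that width comes from: the vocabulary size plus the out-of-vocabulary buckets. `build_vocab`, `to_id`, and `one_hot` are hypothetical helpers, the frequency ordering with alphabetical tie-breaking is an assumption of this sketch, and the sample tokens are made up.

```python
from collections import Counter

NUM_OOV_BUCKETS = 1


def build_vocab(tokens):
    """Rank tokens by descending frequency (ties alphabetical)."""
    counts = Counter(tokens)
    return sorted(counts, key=lambda t: (-counts[t], t))


def to_id(token, vocab):
    # Known tokens get their rank; with a single OOV bucket, every unseen
    # token maps to the id right after the vocabulary.
    return vocab.index(token) if token in vocab else len(vocab)


def one_hot(index, num_buckets):
    return [1.0 if i == index else 0.0 for i in range(num_buckets)]


vocab = build_vocab(['Male', 'Female', 'Male'])
num_buckets = len(vocab) + NUM_OOV_BUCKETS

encoded_known = one_hot(to_id('Female', vocab), num_buckets)
encoded_unseen = one_hot(to_id('Unknown', vocab), num_buckets)
```

The extra position at the end is what lets the model accept category values at serving time that never appeared in the training data.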

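Put it all together

We've created all the stuff we need to preprocess our census data, train a model, and prepare it for serving. So far we've just been getting things ready. It's time to start running!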

import tempfile
temp = os.path.join(tempfile.gettempdir(), 'keras')

transform_data(train, test, temp)
results = train_and_evaluate(temp)
pprint.pprint(results)
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_transform/tf_utils.py:266: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
2021-12-04 10:43:07.088016: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:43:07.089022: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
INFO:tensorflow:SavedModel written to: /tmp/tmpwtmrrrxa/tftransform_tmp/3dfb612abc894c0ab0ae6895d85b5084/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpwtmrrrxa/tftransform_tmp/c76371e6c4104068b035f1ba7ac0c160/saved_model.pb
WARNING:root:Make sure that locally built Python SDK docker image has Python 3.7 interpreter.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Assets written to: /tmp/tmpwtmrrrxa/tftransform_tmp/a447c39aff834eaa8b3df63abd6a0d29/assets
INFO:tensorflow:SavedModel written to: /tmp/tmpwtmrrrxa/tftransform_tmp/a447c39aff834eaa8b3df63abd6a0d29/saved_model.pb
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
2021-12-04 10:43:18.716754: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:43:18.716809: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
education (InputLayer)          [(None,)]            0                                            
__________________________________________________________________________________________________
marital-status (InputLayer)     [(None,)]            0                                            
__________________________________________________________________________________________________
native-country (InputLayer)     [(None,)]            0                                            
__________________________________________________________________________________________________
occupation (InputLayer)         [(None,)]            0                                            
__________________________________________________________________________________________________
race (InputLayer)               [(None,)]            0                                            
__________________________________________________________________________________________________
relationship (InputLayer)       [(None,)]            0                                            
__________________________________________________________________________________________________
sex (InputLayer)                [(None,)]            0                                            
__________________________________________________________________________________________________
workclass (InputLayer)          [(None,)]            0                                            
__________________________________________________________________________________________________
age (InputLayer)                [(None,)]            0                                            
__________________________________________________________________________________________________
capital-gain (InputLayer)       [(None,)]            0                                            
__________________________________________________________________________________________________
capital-loss (InputLayer)       [(None,)]            0                                            
__________________________________________________________________________________________________
tf.expand_dims_3 (TFOpLambda)   (None, 1)            0           education[0][0]                  
__________________________________________________________________________________________________
education-num (InputLayer)      [(None,)]            0                                            
__________________________________________________________________________________________________
hours-per-week (InputLayer)     [(None,)]            0                                            
__________________________________________________________________________________________________
tf.expand_dims_6 (TFOpLambda)   (None, 1)            0           marital-status[0][0]             
__________________________________________________________________________________________________
tf.expand_dims_7 (TFOpLambda)   (None, 1)            0           native-country[0][0]             
__________________________________________________________________________________________________
tf.expand_dims_8 (TFOpLambda)   (None, 1)            0           occupation[0][0]                 
__________________________________________________________________________________________________
tf.expand_dims_9 (TFOpLambda)   (None, 1)            0           race[0][0]                       
__________________________________________________________________________________________________
tf.expand_dims_10 (TFOpLambda)  (None, 1)            0           relationship[0][0]               
__________________________________________________________________________________________________
tf.expand_dims_11 (TFOpLambda)  (None, 1)            0           sex[0][0]                        
__________________________________________________________________________________________________
tf.expand_dims_12 (TFOpLambda)  (None, 1)            0           workclass[0][0]                  
__________________________________________________________________________________________________
tf.expand_dims (TFOpLambda)     (None, 1)            0           age[0][0]                        
__________________________________________________________________________________________________
tf.expand_dims_1 (TFOpLambda)   (None, 1)            0           capital-gain[0][0]               
__________________________________________________________________________________________________
tf.expand_dims_2 (TFOpLambda)   (None, 1)            0           capital-loss[0][0]               
__________________________________________________________________________________________________
category_encoding (CategoryEnco (None, 17)           0           tf.expand_dims_3[0][0]           
__________________________________________________________________________________________________
tf.expand_dims_4 (TFOpLambda)   (None, 1)            0           education-num[0][0]              
__________________________________________________________________________________________________
tf.expand_dims_5 (TFOpLambda)   (None, 1)            0           hours-per-week[0][0]             
__________________________________________________________________________________________________
category_encoding_1 (CategoryEn (None, 8)            0           tf.expand_dims_6[0][0]           
__________________________________________________________________________________________________
category_encoding_2 (CategoryEn (None, 43)           0           tf.expand_dims_7[0][0]           
__________________________________________________________________________________________________
category_encoding_3 (CategoryEn (None, 16)           0           tf.expand_dims_8[0][0]           
__________________________________________________________________________________________________
category_encoding_4 (CategoryEn (None, 6)            0           tf.expand_dims_9[0][0]           
__________________________________________________________________________________________________
category_encoding_5 (CategoryEn (None, 7)            0           tf.expand_dims_10[0][0]          
__________________________________________________________________________________________________
category_encoding_6 (CategoryEn (None, 3)            0           tf.expand_dims_11[0][0]          
__________________________________________________________________________________________________
category_encoding_7 (CategoryEn (None, 10)           0           tf.expand_dims_12[0][0]          
__________________________________________________________________________________________________
tf.concat (TFOpLambda)          (None, 115)          0           tf.expand_dims[0][0]             
                                                                 tf.expand_dims_1[0][0]           
                                                                 tf.expand_dims_2[0][0]           
                                                                 category_encoding[0][0]          
                                                                 tf.expand_dims_4[0][0]           
                                                                 tf.expand_dims_5[0][0]           
                                                                 category_encoding_1[0][0]        
                                                                 category_encoding_2[0][0]        
                                                                 category_encoding_3[0][0]        
                                                                 category_encoding_4[0][0]        
                                                                 category_encoding_5[0][0]        
                                                                 category_encoding_6[0][0]        
                                                                 category_encoding_7[0][0]        
__________________________________________________________________________________________________
dense (Dense)                   (None, 100)          11600       tf.concat[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 70)           7070        dense[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 50)           3550        dense_1[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 20)           1020        dense_2[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 2)            42          dense_3[0][0]                    
==================================================================================================
Total params: 23,282
Trainable params: 23,282
Non-trainable params: 0
__________________________________________________________________________________________________
None
Epoch 1/16
255/255 [==============================] - 2s 5ms/step - loss: 0.4575 - accuracy: 0.7892 - val_loss: 0.3393 - val_accuracy: 0.8425
Epoch 2/16
255/255 [==============================] - 1s 3ms/step - loss: 0.3390 - accuracy: 0.8420 - val_loss: 0.3367 - val_accuracy: 0.8442
Epoch 3/16
255/255 [==============================] - 1s 3ms/step - loss: 0.3278 - accuracy: 0.8478 - val_loss: 0.3256 - val_accuracy: 0.8490
Epoch 4/16
255/255 [==============================] - 1s 3ms/step - loss: 0.3182 - accuracy: 0.8494 - val_loss: 0.3246 - val_accuracy: 0.8481
Epoch 5/16
255/255 [==============================] - 1s 3ms/step - loss: 0.3133 - accuracy: 0.8527 - val_loss: 0.3204 - val_accuracy: 0.8484
Epoch 6/16
255/255 [==============================] - 1s 3ms/step - loss: 0.3054 - accuracy: 0.8566 - val_loss: 0.3232 - val_accuracy: 0.8480
Epoch 7/16
255/255 [==============================] - 1s 4ms/step - loss: 0.3024 - accuracy: 0.8568 - val_loss: 0.3248 - val_accuracy: 0.8488
Epoch 8/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2970 - accuracy: 0.8595 - val_loss: 0.3310 - val_accuracy: 0.8470
Epoch 9/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2932 - accuracy: 0.8619 - val_loss: 0.3277 - val_accuracy: 0.8465
Epoch 10/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2946 - accuracy: 0.8617 - val_loss: 0.3292 - val_accuracy: 0.8495
Epoch 11/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2914 - accuracy: 0.8606 - val_loss: 0.3334 - val_accuracy: 0.8511
Epoch 12/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2864 - accuracy: 0.8631 - val_loss: 0.3328 - val_accuracy: 0.8490
Epoch 13/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2811 - accuracy: 0.8671 - val_loss: 0.3386 - val_accuracy: 0.8503
Epoch 14/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2738 - accuracy: 0.8720 - val_loss: 0.3397 - val_accuracy: 0.8483
Epoch 15/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2709 - accuracy: 0.8745 - val_loss: 0.3429 - val_accuracy: 0.8491
Epoch 16/16
255/255 [==============================] - 1s 3ms/step - loss: 0.2705 - accuracy: 0.8724 - val_loss: 0.3467 - val_accuracy: 0.8491
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2021-12-04 10:43:37.584301: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras/exported_model_dir/1/assets
INFO:tensorflow:Assets written to: /tmp/keras/exported_model_dir/1/assets
16281/16281 [==============================] - 21s 1ms/step - loss: 0.3470 - accuracy: 0.8491
{'accuracy': 0.8490878939628601, 'loss': 0.34699547290802}
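
The parameter counts in the model summary above can be checked by hand. The sketch below is plain Python rather than part of the notebook. It reads the layer widths off the summary and assumes each Dense layer has one weight per input-unit pair plus one bias per unit.

```python
# Widths of the one-hot (CategoryEncoding) outputs, read off the summary above.
one_hot_widths = {
    'education': 17, 'marital-status': 8, 'native-country': 43,
    'occupation': 16, 'race': 6, 'relationship': 7, 'sex': 3, 'workclass': 10,
}
# Each numeric input contributes a single column after tf.expand_dims.
numeric_inputs = ['age', 'capital-gain', 'capital-loss',
                  'education-num', 'hours-per-week']

# Width of the tf.concat layer that feeds the first Dense layer.
concat_width = len(numeric_inputs) + sum(one_hot_widths.values())


def dense_params(inputs, units):
    """Weights plus one bias per unit for a fully connected layer."""
    return inputs * units + units


layer_units = [100, 70, 50, 20, 2]  # dense ... dense_4 in the summary
params = []
width = concat_width
for units in layer_units:
    params.append(dense_params(width, units))
    width = units

print(concat_width)  # 115
print(params)        # [11600, 7070, 3550, 1020, 42]
print(sum(params))   # 23282
```

The totals agree with the "Total params: 23,282" line, since the input, expand_dims, and encoding layers have no trainable weights.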

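(Optional) Using our preprocessed data to train a model using tf.estimator

If you would rather use an Estimator model instead of a Keras model, the code in this section shows how to do that.

Create an input function for training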

def _make_training_input_fn(tf_transform_output, transformed_examples,
                            batch_size):
  """Creates an input function reading from transformed data.

  Args:
    tf_transform_output: Wrapper around output of tf.Transform.
    transformed_examples: Base filename of examples.
    batch_size: Batch size.

  Returns:
    The input function for training or eval.
  """
  def input_fn():
    """Input function for training and eval."""
    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=transformed_examples,
        batch_size=batch_size,
        features=tf_transform_output.transformed_feature_spec(),
        reader=tf.data.TFRecordDataset,
        shuffle=True)

    transformed_features = tf.compat.v1.data.make_one_shot_iterator(
        dataset).get_next()

    # Extract features and label from the transformed tensors.
    transformed_labels = tf.where(
        tf.equal(transformed_features.pop(LABEL_KEY), 1))

    return transformed_features, transformed_labels[:,1]

  return input_fn
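
The label handling in `input_fn` is easier to follow on a tiny example. The sketch below is plain Python, not TensorFlow. It assumes the transformed label is a one-hot row per example (the preprocessing that produces it is not shown in this section) and uses a hypothetical helper, `where_equal_one`, to mimic what `tf.where(tf.equal(labels, 1))` returns: the `[row, column]` coordinates of every matching element. Taking column 1 of that result, as `transformed_labels[:,1]` does, yields the class index of each example.

```python
def where_equal_one(matrix):
    """Mimics tf.where(tf.equal(matrix, 1)): [row, col] of each 1, row-major."""
    return [[r, c]
            for r, row in enumerate(matrix)
            for c, value in enumerate(row)
            if value == 1]


# Hypothetical batch of four one-hot labels over two classes.
one_hot_labels = [[1, 0], [0, 1], [0, 1], [1, 0]]

coords = where_equal_one(one_hot_labels)
# Equivalent of transformed_labels[:, 1]: keep only the column (class) index.
class_ids = [col for _, col in coords]

print(coords)     # [[0, 0], [1, 1], [2, 1], [3, 0]]
print(class_ids)  # [0, 1, 1, 0]
```

This only works when every row contains exactly one 1, which is what a one-hot label guarantees.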

Create an input function for serving

Let's create an input function that we could use in production, and prepare our trained model for serving.

def _make_serving_input_fn(tf_transform_output):
  """Creates an input function reading from raw data.

  Args:
    tf_transform_output: Wrapper around output of tf.Transform.

  Returns:
    The serving input function.
  """
  raw_feature_spec = RAW_DATA_FEATURE_SPEC.copy()
  # Remove label since it is not available during serving.
  raw_feature_spec.pop(LABEL_KEY)

  def serving_input_fn():
    """Input function for serving."""
    # Get raw features by generating the basic serving input_fn and calling it.
    # Here we generate an input_fn that expects a parsed Example proto to be fed
    # to the model at serving time.  See also
    # tf.estimator.export.build_raw_serving_input_receiver_fn.
    raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = raw_input_fn()

    # Apply the transform function that was used to generate the materialized
    # data.
    raw_features = serving_input_receiver.features
    transformed_features = tf_transform_output.transform_raw_features(
        raw_features)

    return tf.estimator.export.ServingInputReceiver(
        transformed_features, serving_input_receiver.receiver_tensors)

  return serving_input_fn
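
One detail of `_make_serving_input_fn` worth spelling out is why the raw feature spec is copied before the label is popped: removing the label from the shared spec itself would also remove it for any later code that parses labelled data. The sketch below uses a hypothetical, simplified spec with string placeholders instead of real `tf.io` feature objects.

```python
# Hypothetical stand-ins for RAW_DATA_FEATURE_SPEC and LABEL_KEY.
RAW_SPEC = {'age': 'float32', 'education': 'string', 'label': 'string'}
LABEL = 'label'

# Copy first, then drop the label from the copy only.
serving_spec = RAW_SPEC.copy()
serving_spec.pop(LABEL)

print(sorted(serving_spec))  # ['age', 'education']
print(LABEL in RAW_SPEC)     # True: the shared spec is untouched
```

A shallow copy is enough here because only a top-level key is removed; the values themselves are never modified.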

Wrap our input data in FeatureColumns

Our model will expect our data in TensorFlow FeatureColumns.

def get_feature_columns(tf_transform_output):
  """Returns the FeatureColumns for the model.

  Args:
    tf_transform_output: A `TFTransformOutput` object.

  Returns:
    A list of FeatureColumns.
  """
  # Wrap scalars as real valued columns.
  real_valued_columns = [tf.feature_column.numeric_column(key, shape=())
                         for key in NUMERIC_FEATURE_KEYS]

  # Wrap categorical columns.
  one_hot_columns = [
      tf.feature_column.indicator_column(
          tf.feature_column.categorical_column_with_identity(
              key=key,
              num_buckets=(NUM_OOV_BUCKETS +
                  tf_transform_output.vocabulary_size_by_name(
                      vocab_filename=key))))
      for key in CATEGORICAL_FEATURE_KEYS]

  return real_valued_columns + one_hot_columns
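
To see why `num_buckets` adds `NUM_OOV_BUCKETS` to the vocabulary size, here is a plain-Python sketch of what an indicator column over an identity categorical column produces. The values are assumptions for illustration: a vocabulary of 16 entries and a single out-of-vocabulary (OOV) bucket, with OOV ids placed directly after the vocabulary ids.

```python
def one_hot(index, num_buckets):
    """Returns a one-hot list of length num_buckets with a 1.0 at index."""
    if not 0 <= index < num_buckets:
        raise ValueError('index %d outside [0, %d)' % (index, num_buckets))
    return [1.0 if i == index else 0.0 for i in range(num_buckets)]


vocab_size = 16       # assumed vocabulary size for one categorical feature
num_oov_buckets = 1   # assumed value of NUM_OOV_BUCKETS
num_buckets = num_oov_buckets + vocab_size

in_vocab = one_hot(3, num_buckets)        # an id assigned to a known string
oov = one_hot(vocab_size, num_buckets)    # the id reserved for unseen strings

print(len(in_vocab))  # 17
print(oov[-1])        # 1.0
```

Without the extra bucket, an unseen string mapped to id `vocab_size` would fall outside the range the column accepts.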

Train, Evaluate, and Export our model

def train_and_evaluate(working_dir, num_train_instances=NUM_TRAIN_INSTANCES,
                       num_test_instances=NUM_TEST_INSTANCES):
  """Train the model on training data and evaluate on test data.

  Args:
    working_dir: Directory to read transformed data and metadata from and to
        write exported model to.
    num_train_instances: Number of instances in train set
    num_test_instances: Number of instances in test set

  Returns:
    The results from the estimator's 'evaluate' method
  """
  tf_transform_output = tft.TFTransformOutput(working_dir)

  run_config = tf.estimator.RunConfig()

  estimator = tf.estimator.LinearClassifier(
      feature_columns=get_feature_columns(tf_transform_output),
      config=run_config,
      loss_reduction=tf.losses.Reduction.SUM)

  # Fit the model using the default optimizer.
  train_input_fn = _make_training_input_fn(
      tf_transform_output,
      os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + '*'),
      batch_size=TRAIN_BATCH_SIZE)
  estimator.train(
      input_fn=train_input_fn,
      max_steps=TRAIN_NUM_EPOCHS * num_train_instances / TRAIN_BATCH_SIZE)

  # Build the input function used to evaluate the model on the test dataset.
  eval_input_fn = _make_training_input_fn(
      tf_transform_output,
      os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE + '*'),
      batch_size=1)

  # Export the model.
  serving_input_fn = _make_serving_input_fn(tf_transform_output)
  exported_model_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR)
  estimator.export_saved_model(exported_model_dir, serving_input_fn)

  return estimator.evaluate(input_fn=eval_input_fn, steps=num_test_instances)
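
The step counts passed to `train` and `evaluate` follow from simple arithmetic. The sketch below uses assumed values for the constants, which are defined earlier in the notebook and not shown in this section. They are chosen to be consistent with the 16 epochs, 255 steps per epoch, and 16281 evaluation steps seen in the Keras output.

```python
import math

# Assumed values; the real constants are defined earlier in the notebook.
TRAIN_NUM_EPOCHS = 16
NUM_TRAIN_INSTANCES = 32561
TRAIN_BATCH_SIZE = 128
NUM_TEST_INSTANCES = 16281

# Same expression as the max_steps argument; true division gives a float.
max_steps = TRAIN_NUM_EPOCHS * NUM_TRAIN_INSTANCES / TRAIN_BATCH_SIZE

# Batches needed to see every training example once.
steps_per_epoch = math.ceil(NUM_TRAIN_INSTANCES / TRAIN_BATCH_SIZE)

# Evaluation uses batch_size=1, so one step consumes one test example.
eval_examples_seen = NUM_TEST_INSTANCES * 1

print(max_steps)           # 4070.125
print(steps_per_epoch)     # 255
print(eval_examples_seen)  # 16281
```

Because evaluation runs with a batch size of 1 for `num_test_instances` steps, each test example is consumed exactly once.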

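Put it all together

We've created everything we need to preprocess our census data, train a model, and prepare it for serving. So far we've just been getting things ready. It's time to start running!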

import tempfile
temp = os.path.join(tempfile.gettempdir(), 'estimator')

transform_data(train, test, temp)
results = train_and_evaluate(temp)
pprint.pprint(results)
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
INFO:tensorflow:SavedModel written to: /tmp/tmpi7o66bl8/tftransform_tmp/a7f3726df5bf498ca24bd528eebca9e9/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpi7o66bl8/tftransform_tmp/a7f3726df5bf498ca24bd528eebca9e9/saved_model.pb
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing tft_mapper_use.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'Counter' object has no attribute 'name'
INFO:tensorflow:SavedModel written to: /tmp/tmpi7o66bl8/tftransform_tmp/3466a3517ec243a39102fa6ad6e5fec2/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpi7o66bl8/tftransform_tmp/3466a3517ec243a39102fa6ad6e5fec2/saved_model.pb
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended.
WARNING:tensorflow:Tensorflow version (2.4.4) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 
WARNING:root:Make sure that locally built Python SDK docker image has Python 3.7 interpreter.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2021-12-04 10:44:05.733070: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:44:05.733123: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tmpi7o66bl8/tftransform_tmp/96186aa415404f0884cb3766b270b9b2/assets
INFO:tensorflow:Assets written to: /tmp/tmpi7o66bl8/tftransform_tmp/96186aa415404f0884cb3766b270b9b2/assets
INFO:tensorflow:SavedModel written to: /tmp/tmpi7o66bl8/tftransform_tmp/96186aa415404f0884cb3766b270b9b2/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpi7o66bl8/tftransform_tmp/96186aa415404f0884cb3766b270b9b2/saved_model.pb
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
2021-12-04 10:44:10.983401: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:44:10.983461: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
2021-12-04 10:44:12.469671: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:44:12.469756: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpwufx88ji
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpwufx88ji', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1727: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:134: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
2021-12-04 10:44:15.191355: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:44:15.191419: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpwufx88ji/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 88.72284, step = 0
INFO:tensorflow:global_step/sec: 432.87
INFO:tensorflow:loss = 33.484627, step = 100 (0.233 sec)
INFO:tensorflow:global_step/sec: 764.774
INFO:tensorflow:loss = 42.72283, step = 200 (0.130 sec)
INFO:tensorflow:global_step/sec: 763.549
INFO:tensorflow:loss = 55.91174, step = 300 (0.131 sec)
INFO:tensorflow:global_step/sec: 755.175
INFO:tensorflow:loss = 39.204643, step = 400 (0.133 sec)
INFO:tensorflow:global_step/sec: 792.262
INFO:tensorflow:loss = 41.268295, step = 500 (0.126 sec)
INFO:tensorflow:global_step/sec: 743.725
INFO:tensorflow:loss = 51.267006, step = 600 (0.135 sec)
INFO:tensorflow:global_step/sec: 806.716
INFO:tensorflow:loss = 42.03744, step = 700 (0.124 sec)
INFO:tensorflow:global_step/sec: 763.135
INFO:tensorflow:loss = 42.66994, step = 800 (0.131 sec)
INFO:tensorflow:global_step/sec: 779.496
INFO:tensorflow:loss = 48.643982, step = 900 (0.129 sec)
INFO:tensorflow:global_step/sec: 787.431
INFO:tensorflow:loss = 41.668102, step = 1000 (0.127 sec)
INFO:tensorflow:global_step/sec: 737.697
INFO:tensorflow:loss = 40.340927, step = 1100 (0.135 sec)
INFO:tensorflow:global_step/sec: 755.647
INFO:tensorflow:loss = 31.146494, step = 1200 (0.133 sec)
INFO:tensorflow:global_step/sec: 785.653
INFO:tensorflow:loss = 30.96864, step = 1300 (0.127 sec)
INFO:tensorflow:global_step/sec: 759.461
INFO:tensorflow:loss = 38.621964, step = 1400 (0.132 sec)
INFO:tensorflow:global_step/sec: 777.328
INFO:tensorflow:loss = 44.518555, step = 1500 (0.129 sec)
INFO:tensorflow:global_step/sec: 741.005
INFO:tensorflow:loss = 45.997204, step = 1600 (0.135 sec)
INFO:tensorflow:global_step/sec: 734.846
INFO:tensorflow:loss = 50.39132, step = 1700 (0.136 sec)
INFO:tensorflow:global_step/sec: 752.826
INFO:tensorflow:loss = 45.41472, step = 1800 (0.133 sec)
INFO:tensorflow:global_step/sec: 757.018
INFO:tensorflow:loss = 46.133186, step = 1900 (0.132 sec)
INFO:tensorflow:global_step/sec: 700.757
INFO:tensorflow:loss = 34.684982, step = 2000 (0.143 sec)
INFO:tensorflow:global_step/sec: 741.709
INFO:tensorflow:loss = 39.637863, step = 2100 (0.135 sec)
INFO:tensorflow:global_step/sec: 772.066
INFO:tensorflow:loss = 45.70813, step = 2200 (0.129 sec)
INFO:tensorflow:global_step/sec: 776.263
INFO:tensorflow:loss = 39.104668, step = 2300 (0.129 sec)
INFO:tensorflow:global_step/sec: 768.016
INFO:tensorflow:loss = 36.262817, step = 2400 (0.130 sec)
INFO:tensorflow:global_step/sec: 754.04
INFO:tensorflow:loss = 43.80282, step = 2500 (0.132 sec)
INFO:tensorflow:global_step/sec: 742.917
INFO:tensorflow:loss = 48.113125, step = 2600 (0.135 sec)
INFO:tensorflow:global_step/sec: 753.394
INFO:tensorflow:loss = 43.442005, step = 2700 (0.133 sec)
INFO:tensorflow:global_step/sec: 768.985
INFO:tensorflow:loss = 34.593086, step = 2800 (0.130 sec)
INFO:tensorflow:global_step/sec: 756.393
INFO:tensorflow:loss = 38.085594, step = 2900 (0.132 sec)
INFO:tensorflow:global_step/sec: 792.717
INFO:tensorflow:loss = 42.41484, step = 3000 (0.126 sec)
INFO:tensorflow:global_step/sec: 763.25
INFO:tensorflow:loss = 42.457626, step = 3100 (0.131 sec)
INFO:tensorflow:global_step/sec: 747.998
INFO:tensorflow:loss = 52.64791, step = 3200 (0.134 sec)
INFO:tensorflow:global_step/sec: 733.804
INFO:tensorflow:loss = 36.78949, step = 3300 (0.136 sec)
INFO:tensorflow:global_step/sec: 747.473
INFO:tensorflow:loss = 43.02353, step = 3400 (0.134 sec)
INFO:tensorflow:global_step/sec: 766.967
INFO:tensorflow:loss = 42.971584, step = 3500 (0.131 sec)
INFO:tensorflow:global_step/sec: 759.238
INFO:tensorflow:loss = 31.898714, step = 3600 (0.133 sec)
INFO:tensorflow:global_step/sec: 770.209
INFO:tensorflow:loss = 43.47151, step = 3700 (0.128 sec)
INFO:tensorflow:global_step/sec: 750.127
INFO:tensorflow:loss = 40.073875, step = 3800 (0.133 sec)
INFO:tensorflow:global_step/sec: 731.607
INFO:tensorflow:loss = 33.494003, step = 3900 (0.137 sec)
INFO:tensorflow:global_step/sec: 753.01
INFO:tensorflow:loss = 40.401936, step = 4000 (0.133 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4071...
INFO:tensorflow:Saving checkpoints for 4071 into /tmp/tmpwufx88ji/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4071...
INFO:tensorflow:Loss for final step: 51.911263.
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_3:0\022\tworkclass"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_5:0\022\teducation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_7:0\022\016marital-status"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\013\n\tConst_9:0\022\noccupation"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_11:0\022\014relationship"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_13:0\022\004race"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_15:0\022\003sex"
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\014\n\nConst_17:0\022\016native-country"
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpwufx88ji/model.ckpt-4071
2021-12-04 10:44:22.080737: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:44:22.080796: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/estimator/exported_model_dir/temp-1638614661/assets
INFO:tensorflow:SavedModel written to: /tmp/estimator/exported_model_dir/temp-1638614661/saved_model.pb
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-12-04T10:44:23Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpwufx88ji/model.ckpt-4071
2021-12-04 10:44:23.300547: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory
2021-12-04 10:44:23.300668: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1628/16281]
INFO:tensorflow:Evaluation [3256/16281]
INFO:tensorflow:Evaluation [4884/16281]
INFO:tensorflow:Evaluation [6512/16281]
INFO:tensorflow:Evaluation [8140/16281]
INFO:tensorflow:Evaluation [9768/16281]
INFO:tensorflow:Evaluation [11396/16281]
INFO:tensorflow:Evaluation [13024/16281]
INFO:tensorflow:Evaluation [14652/16281]
INFO:tensorflow:Evaluation [16280/16281]
INFO:tensorflow:Evaluation [16281/16281]
INFO:tensorflow:Inference Time : 12.76048s
INFO:tensorflow:Finished evaluation at 2021-12-04-10:44:35
INFO:tensorflow:Saving dict for global step 4071: accuracy = 0.85123765, accuracy_baseline = 0.76377374, auc = 0.9019859, auc_precision_recall = 0.9672531, average_loss = 0.32398567, global_step = 4071, label/mean = 0.76377374, loss = 0.32398567, precision = 0.8828477, prediction/mean = 0.75662553, recall = 0.9284278
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 4071: /tmp/tmpwufx88ji/model.ckpt-4071
{'accuracy': 0.85123765,
 'accuracy_baseline': 0.76377374,
 'auc': 0.9019859,
 'auc_precision_recall': 0.9672531,
 'average_loss': 0.32398567,
 'global_step': 4071,
 'label/mean': 0.76377374,
 'loss': 0.32398567,
 'precision': 0.8828477,
 'prediction/mean': 0.75662553,
 'recall': 0.9284278}

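What we did

In this example we used tf.Transform to preprocess a dataset of census data, and trained a model on the cleaned and transformed data. We also created an input function that can be used when the trained model is deployed to a production environment to perform inference. Because the same code is used for both training and inference, skew between training and serving data is avoided. Along the way we learned how to create an Apache Beam transform that performs the transformations needed to clean the data. We also saw how to use the transformed data to train a model with either tf.keras or tf.estimator. This is only a small part of what TensorFlow Transform can do. We encourage you to dive into tf.Transform and discover what it can do for you.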
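
The point about reusing one code path for training and inference can be illustrated without any TensorFlow dependency. The following is a minimal plain-Python sketch of the idea, not the tf.Transform API: the function names `analyze` and `transform`, the two example columns, and the `-1` out-of-vocabulary id are all assumptions made for illustration. A full pass over the training rows computes the statistics once, and the same per-row function is then applied both to training rows and to a serving request.

```python
import math
from collections import Counter


def analyze(rows):
    """Full pass over the training rows: collect the statistics the transform needs."""
    ages = [r["age"] for r in rows]
    mean = sum(ages) / len(ages)
    std = math.sqrt(sum((a - mean) ** 2 for a in ages) / len(ages))
    # Vocabulary ordered by descending frequency, ties broken alphabetically.
    counts = Counter(r["education"] for r in rows)
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    vocab = {value: index for index, (value, _) in enumerate(ordered)}
    return {"age_mean": mean, "age_std": std, "education_vocab": vocab}


def transform(row, stats):
    """Per-row transform: the single code path shared by training and serving."""
    return {
        "age": (row["age"] - stats["age_mean"]) / stats["age_std"],
        # Strings not seen during analysis map to -1 (hypothetical convention).
        "education": stats["education_vocab"].get(row["education"], -1),
    }


# Illustrative rows only; they are not taken from the census files.
train_rows = [
    {"age": 30, "education": "Bachelors"},
    {"age": 40, "education": "HS-grad"},
    {"age": 50, "education": "HS-grad"},
]

stats = analyze(train_rows)                                  # analysis happens once
train_features = [transform(r, stats) for r in train_rows]   # training time
serving_features = transform({"age": 40, "education": "Masters"}, stats)  # serving time
```

Since `transform` depends only on the row and the frozen `stats`, a value seen at serving time is scaled and encoded exactly as it would have been during training, which is the property the exported transform graph provides in the real library.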