事前作成された Estimator

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

警告: 新しいコードには Estimators は推奨されません。Estimators は v1.Session スタイルのコードを実行しますが、これは正しく記述するのはより難しく、特に TF 2 コードと組み合わせると予期しない動作をする可能性があります。Estimators は、互換性保証の対象となりますが、セキュリティの脆弱性以外の修正は行われません。詳細については、移行ガイドを参照してください。

このチュートリアルでは、Estimator を使用して、TensorFlow でアヤメの分類問題を解決する方法を示します。Estimator は、レガシー TensorFlow における完全なモデルの高レベルの表現です。詳細については、Estimatorをご覧ください。

注意: TensorFlow 2.0 では、Keras API でも同じタスクを実行でき、より学習しやすい API とされています。はじめて学習する場合は、Keras から着手することをお勧めします。

まず最初に

始めるには、最初に TensorFlow と必要となる多数のライブラリをインポートします。

import tensorflow as tf

import pandas as pd
2022-08-08 21:25:33.844889: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-08 21:25:34.634832: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:25:34.635106: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:25:34.635120: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

データセット

このドキュメントのサンプルプログラムは、アヤメの花を、萼片花弁のサイズに基づいて、3 つの品種に分類するモデルを構築してテストします。

モデルのトレーニングには、Iris データセットを使用します。Iris データセットには 4 つの特徴量と 1 つのラベルが含まれます。4 つの特徴量は、次に示す各アヤメの植物学的特性を識別します。

  • 萼片の長さ
  • 萼片の幅
  • 花弁の長さ
  • 花弁の幅

この情報に基づき、データを解析する上で役立ついくつかの定数を定義できます。

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

次に、Keras と Pandas を使用して、Iris データセットをダウンロードして解析します。トレーニング用とテスト用に別々のデータセットを維持することに注意してください。

train_path = tf.keras.utils.get_file(
    "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
    "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv
2194/2194 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv
573/573 [==============================] - 0s 0us/step

データを検査し、4 つの浮動小数型の特徴量カラムと 1 つの int32 ラベルがあることを確認します。

train.head()

各データセットに対し、モデルが予測するようにトレーニングされるラベルを分割します。

train_y = train.pop('Species')
test_y = test.pop('Species')

# The label column has now been removed from the features.
train.head()

Estimator を使ったプログラミングの概要

データのセットアップが完了したので、TensorFlow Estimator を使ってモデルを定義できます。Estimator は、tf.estimator.Estimator から派生したクラスです。TensorFlow は、一群の tf.estimatorLinearRegressor など)を提供しており、一般的な ML アルゴリズムを実装することができます。このほか、独自のカスタム Estimator を作成することもできますが、使用し始めには、事前作成済みの Estimator を使用することをお勧めします。

事前作成済みの Estimator に基づいて TensorFlow プログラムを記述するには、次のタスクを実行する必要があります。

  • 1 つ以上の入力関数を作成する。
  • モデルの特徴量カラムを定義する。
  • Estimator をインスタンス化する。特徴量カラムとさまざまなハイパーパラメータを指定します。
  • Estimator オブジェクトに 1 つ以上のメソッドを呼び出す。データのソースとして適切な入力関数を渡します。

では、アヤメの分類において、これらのタスクをどのように実装するのか見てみましょう。

入力関数を作成する

トレーニング、評価、および予測を行うためのデータを提供する入力関数を作成する必要があります。

入力関数とは、次の要素タプルを出力する tf.data.Dataset オブジェクトを返す関数です。

  • features - 次のような Python ディクショナリ。
    • 各キーが特徴量の名前である。
    • 各値が、特徴量の値のすべてを含む配列である。
  • label - 各サンプルの label の値を含む配列。

入力関数の書式を示すために、単純な実装を次に示します。

def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth':  np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth':  np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels

入力関数を自分で作成すれば、features ディクショナリと label リストを好みに合わせて生成できるようにすることができますが、あらゆる種類のデータを解析できる TensorFlow の Dataset API を使用することをお勧めします。

Dataset API は、多数の一般的な事例を処理することができます。たとえば、Dataset API を使用すると、大量のファイルのレコードを並列して読み取り、単一のストリームに結合することが簡単に行えます。

この例では事を単純にするために、pandas でデータを読み込み、このメモリ内のデータから入力パイプラインを構築します。

def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()

    return dataset.batch(batch_size)

特徴量カラムを定義する

特徴量カラムは、特徴量ディクショナリの生の入力データを、モデルがどのように使用すべきかを説明するオブジェクトです。Estimator モデルを作成する際に、モデルが使用する各特徴量を説明する特徴量カラムをモデルに渡します。tf.feature_column モジュールには、モデルに対してデータを表現するためのオプションが多数含まれています。

Iris については、4 つの生の特徴量は数値であるため、Estimator に対して、これら 4 つの各特徴量を 32 ビットの浮動小数点数型の値として表現するように命令する特徴量カラムを構築します。したがって、特徴カラムを作成するためのコードは、次のようになります。

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

特徴量カラムは、ここに示すものよりもはるかに高度なものに構築することができます。特徴量カラムの詳細については、こちらのガイドをご覧ください。

モデルが生の特徴量をどのように表現するかに関する記述を準備できたので、Estimator を構築することができます。

Estimator をインスタンス化する

アヤメの問題はよく知られた分類問題です。幸いにも、TensorFlow は、次のような事前作成済みの分類子 Estimator を複数用意しています。

アヤメの問題に関しては、tf.estimator.DNNClassifier が最適な選択肢と言えます。この Estimator をインスタンス化する方法を次に示します。

# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 30 and 10 nodes respectively.
    hidden_units=[30, 10],
    # The model must choose between 3 classes.
    n_classes=3)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpprwk6v82
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpprwk6v82', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

トレーニングして評価して予測する

Estimator オブジェクトを準備したので、次の項目を行うメソッド呼び出すことができます。

  • モデルをトレーニングする。
  • トレーニングされたモデルを評価する。
  • トレーニングされたモデルを使用して、予測を立てる。

モデルをトレーニングする

次のように、Estimator の train メソッドを呼び出して、モデルをトレーニングします。

# Train the Model.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/adagrad.py:86: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
2022-08-08 21:25:40.211869: W tensorflow/core/common_runtime/forward_type_inference.cc:332] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpprwk6v82/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.3833963, step = 0
INFO:tensorflow:global_step/sec: 270.145
INFO:tensorflow:loss = 1.093653, step = 100 (0.371 sec)
INFO:tensorflow:global_step/sec: 331.015
INFO:tensorflow:loss = 1.0144349, step = 200 (0.302 sec)
INFO:tensorflow:global_step/sec: 338.782
INFO:tensorflow:loss = 0.9342374, step = 300 (0.295 sec)
INFO:tensorflow:global_step/sec: 338.423
INFO:tensorflow:loss = 0.8742392, step = 400 (0.296 sec)
INFO:tensorflow:global_step/sec: 337.761
INFO:tensorflow:loss = 0.8391523, step = 500 (0.296 sec)
INFO:tensorflow:global_step/sec: 339.752
INFO:tensorflow:loss = 0.79797184, step = 600 (0.294 sec)
INFO:tensorflow:global_step/sec: 334.24
INFO:tensorflow:loss = 0.765428, step = 700 (0.299 sec)
INFO:tensorflow:global_step/sec: 340.393
INFO:tensorflow:loss = 0.7378235, step = 800 (0.294 sec)
INFO:tensorflow:global_step/sec: 336.25
INFO:tensorflow:loss = 0.712427, step = 900 (0.298 sec)
INFO:tensorflow:global_step/sec: 344.938
INFO:tensorflow:loss = 0.693172, step = 1000 (0.290 sec)
INFO:tensorflow:global_step/sec: 346.002
INFO:tensorflow:loss = 0.66773224, step = 1100 (0.289 sec)
INFO:tensorflow:global_step/sec: 345.179
INFO:tensorflow:loss = 0.65159553, step = 1200 (0.290 sec)
INFO:tensorflow:global_step/sec: 345.04
INFO:tensorflow:loss = 0.6365468, step = 1300 (0.290 sec)
INFO:tensorflow:global_step/sec: 349.557
INFO:tensorflow:loss = 0.61549705, step = 1400 (0.286 sec)
INFO:tensorflow:global_step/sec: 348.741
INFO:tensorflow:loss = 0.60658413, step = 1500 (0.287 sec)
INFO:tensorflow:global_step/sec: 346.716
INFO:tensorflow:loss = 0.5915766, step = 1600 (0.288 sec)
INFO:tensorflow:global_step/sec: 347.163
INFO:tensorflow:loss = 0.57304, step = 1700 (0.288 sec)
INFO:tensorflow:global_step/sec: 343.946
INFO:tensorflow:loss = 0.56998014, step = 1800 (0.291 sec)
INFO:tensorflow:global_step/sec: 341.354
INFO:tensorflow:loss = 0.56178665, step = 1900 (0.293 sec)
INFO:tensorflow:global_step/sec: 340.802
INFO:tensorflow:loss = 0.55901486, step = 2000 (0.293 sec)
INFO:tensorflow:global_step/sec: 333.447
INFO:tensorflow:loss = 0.5424686, step = 2100 (0.300 sec)
INFO:tensorflow:global_step/sec: 333.049
INFO:tensorflow:loss = 0.5326629, step = 2200 (0.300 sec)
INFO:tensorflow:global_step/sec: 333.174
INFO:tensorflow:loss = 0.5287124, step = 2300 (0.300 sec)
INFO:tensorflow:global_step/sec: 337.146
INFO:tensorflow:loss = 0.51603675, step = 2400 (0.296 sec)
INFO:tensorflow:global_step/sec: 337.034
INFO:tensorflow:loss = 0.5209577, step = 2500 (0.297 sec)
INFO:tensorflow:global_step/sec: 334.268
INFO:tensorflow:loss = 0.5081441, step = 2600 (0.299 sec)
INFO:tensorflow:global_step/sec: 331.569
INFO:tensorflow:loss = 0.5103288, step = 2700 (0.302 sec)
INFO:tensorflow:global_step/sec: 331.344
INFO:tensorflow:loss = 0.4944778, step = 2800 (0.302 sec)
INFO:tensorflow:global_step/sec: 336.303
INFO:tensorflow:loss = 0.4908868, step = 2900 (0.297 sec)
INFO:tensorflow:global_step/sec: 322.001
INFO:tensorflow:loss = 0.48573777, step = 3000 (0.311 sec)
INFO:tensorflow:global_step/sec: 330.917
INFO:tensorflow:loss = 0.47572097, step = 3100 (0.302 sec)
INFO:tensorflow:global_step/sec: 326.728
INFO:tensorflow:loss = 0.4840495, step = 3200 (0.306 sec)
INFO:tensorflow:global_step/sec: 326.019
INFO:tensorflow:loss = 0.47777748, step = 3300 (0.307 sec)
INFO:tensorflow:global_step/sec: 330.633
INFO:tensorflow:loss = 0.4695906, step = 3400 (0.302 sec)
INFO:tensorflow:global_step/sec: 329.266
INFO:tensorflow:loss = 0.4644341, step = 3500 (0.304 sec)
INFO:tensorflow:global_step/sec: 322.862
INFO:tensorflow:loss = 0.4584458, step = 3600 (0.310 sec)
INFO:tensorflow:global_step/sec: 326.762
INFO:tensorflow:loss = 0.4631693, step = 3700 (0.306 sec)
INFO:tensorflow:global_step/sec: 332.526
INFO:tensorflow:loss = 0.44423047, step = 3800 (0.301 sec)
INFO:tensorflow:global_step/sec: 324.73
INFO:tensorflow:loss = 0.4475951, step = 3900 (0.308 sec)
INFO:tensorflow:global_step/sec: 338.916
INFO:tensorflow:loss = 0.4578572, step = 4000 (0.295 sec)
INFO:tensorflow:global_step/sec: 348.121
INFO:tensorflow:loss = 0.45477766, step = 4100 (0.287 sec)
INFO:tensorflow:global_step/sec: 346.834
INFO:tensorflow:loss = 0.43525168, step = 4200 (0.288 sec)
INFO:tensorflow:global_step/sec: 337.227
INFO:tensorflow:loss = 0.45133963, step = 4300 (0.297 sec)
INFO:tensorflow:global_step/sec: 334.052
INFO:tensorflow:loss = 0.42795596, step = 4400 (0.299 sec)
INFO:tensorflow:global_step/sec: 350.413
INFO:tensorflow:loss = 0.4303703, step = 4500 (0.285 sec)
INFO:tensorflow:global_step/sec: 349.309
INFO:tensorflow:loss = 0.43121222, step = 4600 (0.286 sec)
INFO:tensorflow:global_step/sec: 352.032
INFO:tensorflow:loss = 0.4288513, step = 4700 (0.284 sec)
INFO:tensorflow:global_step/sec: 334.175
INFO:tensorflow:loss = 0.4216747, step = 4800 (0.299 sec)
INFO:tensorflow:global_step/sec: 338.229
INFO:tensorflow:loss = 0.4167729, step = 4900 (0.296 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Saving checkpoints for 5000 into /tmpfs/tmp/tmpprwk6v82/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Loss for final step: 0.41855305.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7f3d9025d250>

Estimator が期待するとおり、引数を取らない入力関数を指定しながら、input_fn 呼び出しを lambda にラッピングして引数をキャプチャするところに注意してください。steps 引数はメソッドに対して、あるトレーニングステップ数を完了した後にトレーニングを停止するように指定しています。

トレーニングされたモデルを評価する

モデルのトレーニングが完了したので、そのパフォーマンスに関する統計を得ることができます。次のコードブロックは、テストデータに対してトレーニングされたモデルの精度を評価します。

eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.checkpoint_management has been moved to tensorflow.python.checkpoint.checkpoint_management. The old module will be deleted in version 2.9.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-08T21:25:56
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpprwk6v82/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.49679s
INFO:tensorflow:Finished evaluation at 2022-08-08-21:25:56
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.76666665, average_loss = 0.49484098, global_step = 5000, loss = 0.49484098
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpprwk6v82/model.ckpt-5000

Test set accuracy: 0.767

train メソッドへの呼び出しとは異なり、評価するsteps 引数を渡していません。eval の input_fn データの単一のエポックのみを返します。

eval_result ディクショナリには、average_loss(サンプル当たりの平均損失)、loss(ミニバッチ当たりの平均損失)、および Estimator の global_step の値(実行したトレーニングイテレーションの回数)も含まれます。

トレーニングされたモデルから予測(推論)を立てる

良質の評価結果を生み出すトレーニング済みのモデルを準備できました。これから、このトレーニング済みのモデルを使用し、ラベル付けできない測定に基づいてアヤメの品種を予測します。トレーニングと評価と同様に、単一の関数呼び出して予測を行います。

# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

def input_fn(features, batch_size=256):
    """An input function for prediction."""
    # Convert the inputs to a Dataset without labels.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

predictions = classifier.predict(
    input_fn=lambda: input_fn(predict_x))

predict メソッドは Python イテラブルを返し、各サンプルの予測結果のディクショナリを生成します。次のコードを使って、予測とその確率を出力します。

for pred_dict, expec in zip(predictions, expected):
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpprwk6v82/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Prediction is "Setosa" (86.4%), expected "Setosa"
Prediction is "Versicolor" (47.9%), expected "Versicolor"
Prediction is "Virginica" (62.9%), expected "Virginica"