事前作成された Estimator

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

警告: 新しいコードには Estimators は推奨されません。Estimators は v1.Session スタイルのコードを実行しますが、これは正しく記述するのはより難しく、特に TF 2 コードと組み合わせると予期しない動作をする可能性があります。Estimators は、互換性保証の対象となりますが、セキュリティの脆弱性以外の修正は行われません。詳細については、移行ガイドを参照してください。

このチュートリアルでは、Estimator を使用して、TensorFlow でアヤメの分類問題を解決する方法を示します。Estimator は、レガシー TensorFlow における完全なモデルの高レベルの表現です。詳細については、Estimatorをご覧ください。

注意: TensorFlow 2.0 では、Keras API でも同じタスクを実行でき、より学習しやすい API とされています。はじめて学習する場合は、Keras から着手することをお勧めします。

まず最初に

始めるには、最初に TensorFlow と必要となる多数のライブラリをインポートします。

import tensorflow as tf

import pandas as pd
2022-12-15 02:40:34.596324: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:40:34.596441: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:40:34.596452: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

データセット

このドキュメントのサンプルプログラムは、アヤメの花を、萼片花弁のサイズに基づいて、3 つの品種に分類するモデルを構築してテストします。

モデルのトレーニングには、Iris データセットを使用します。Iris データセットには 4 つの特徴量と 1 つのラベルが含まれます。4 つの特徴量は、次に示す各アヤメの植物学的特性を識別します。

  • 萼片の長さ
  • 萼片の幅
  • 花弁の長さ
  • 花弁の幅

この情報に基づき、データを解析する上で役立ついくつかの定数を定義できます。

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

次に、Keras と Pandas を使用して、Iris データセットをダウンロードして解析します。トレーニング用とテスト用に別々のデータセットを維持することに注意してください。

train_path = tf.keras.utils.get_file(
    "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
    "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv
2194/2194 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv
573/573 [==============================] - 0s 0us/step

データを検査し、4 つの浮動小数型の特徴量カラムと 1 つの int32 ラベルがあることを確認します。

train.head()

各データセットに対し、モデルが予測するようにトレーニングされるラベルを分割します。

train_y = train.pop('Species')
test_y = test.pop('Species')

# The label column has now been removed from the features.
train.head()

Estimator を使ったプログラミングの概要

データのセットアップが完了したので、TensorFlow Estimator を使ってモデルを定義できます。Estimator は、tf.estimator.Estimator から派生したクラスです。TensorFlow は、一群の tf.estimatorLinearRegressor など)を提供しており、一般的な ML アルゴリズムを実装することができます。このほか、独自のカスタム Estimator を作成することもできますが、使用し始めには、事前作成済みの Estimator を使用することをお勧めします。

事前作成済みの Estimator に基づいて TensorFlow プログラムを記述するには、次のタスクを実行する必要があります。

  • 1 つ以上の入力関数を作成する。
  • モデルの特徴量カラムを定義する。
  • Estimator をインスタンス化する。特徴量カラムとさまざまなハイパーパラメータを指定します。
  • Estimator オブジェクトに 1 つ以上のメソッドを呼び出す。データのソースとして適切な入力関数を渡します。

では、アヤメの分類において、これらのタスクをどのように実装するのか見てみましょう。

入力関数を作成する

トレーニング、評価、および予測を行うためのデータを提供する入力関数を作成する必要があります。

入力関数とは、次の要素タプルを出力する tf.data.Dataset オブジェクトを返す関数です。

  • features - 次のような Python ディクショナリ。
    • 各キーが特徴量の名前である。
    • 各値が、特徴量の値のすべてを含む配列である。
  • label - 各サンプルの label の値を含む配列。

入力関数の書式を示すために、単純な実装を次に示します。

def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth':  np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth':  np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels

入力関数を自分で作成すれば、features ディクショナリと label リストを好みに合わせて生成できるようにすることができますが、あらゆる種類のデータを解析できる TensorFlow の Dataset API を使用することをお勧めします。

Dataset API は、多数の一般的な事例を処理することができます。たとえば、Dataset API を使用すると、大量のファイルのレコードを並列して読み取り、単一のストリームに結合することが簡単に行えます。

この例では事を単純にするために、pandas でデータを読み込み、このメモリ内のデータから入力パイプラインを構築します。

def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()

    return dataset.batch(batch_size)

特徴量カラムを定義する

特徴量カラムは、特徴量ディクショナリの生の入力データを、モデルがどのように使用すべきかを説明するオブジェクトです。Estimator モデルを作成する際に、モデルが使用する各特徴量を説明する特徴量カラムをモデルに渡します。tf.feature_column モジュールには、モデルに対してデータを表現するためのオプションが多数含まれています。

Iris については、4 つの生の特徴量は数値であるため、Estimator に対して、これら 4 つの各特徴量を 32 ビットの浮動小数点数型の値として表現するように命令する特徴量カラムを構築します。したがって、特徴カラムを作成するためのコードは、次のようになります。

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

特徴量カラムは、ここに示すものよりもはるかに高度なものに構築することができます。特徴量カラムの詳細については、こちらのガイドをご覧ください。

モデルが生の特徴量をどのように表現するかに関する記述を準備できたので、Estimator を構築することができます。

Estimator をインスタンス化する

アヤメの問題はよく知られた分類問題です。幸いにも、TensorFlow は、次のような事前作成済みの分類子 Estimator を複数用意しています。

アヤメの問題に関しては、tf.estimator.DNNClassifier が最適な選択肢と言えます。この Estimator をインスタンス化する方法を次に示します。

# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 30 and 10 nodes respectively.
    hidden_units=[30, 10],
    # The model must choose between 3 classes.
    n_classes=3)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpv8gz9mg7
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpv8gz9mg7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

トレーニングして評価して予測する

Estimator オブジェクトを準備したので、次の項目を行うメソッド呼び出すことができます。

  • モデルをトレーニングする。
  • トレーニングされたモデルを評価する。
  • トレーニングされたモデルを使用して、予測を立てる。

モデルをトレーニングする

次のように、Estimator の train メソッドを呼び出して、モデルをトレーニングします。

# Train the Model.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/adagrad.py:93: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
2022-12-15 02:40:40.279374: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpv8gz9mg7/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.3653078, step = 0
INFO:tensorflow:global_step/sec: 424.758
INFO:tensorflow:loss = 1.0535867, step = 100 (0.237 sec)
INFO:tensorflow:global_step/sec: 551.284
INFO:tensorflow:loss = 0.9845939, step = 200 (0.182 sec)
INFO:tensorflow:global_step/sec: 564.881
INFO:tensorflow:loss = 0.93356884, step = 300 (0.177 sec)
INFO:tensorflow:global_step/sec: 566.209
INFO:tensorflow:loss = 0.8931673, step = 400 (0.177 sec)
INFO:tensorflow:global_step/sec: 566.856
INFO:tensorflow:loss = 0.863848, step = 500 (0.176 sec)
INFO:tensorflow:global_step/sec: 545.981
INFO:tensorflow:loss = 0.83251595, step = 600 (0.183 sec)
INFO:tensorflow:global_step/sec: 556.842
INFO:tensorflow:loss = 0.79153323, step = 700 (0.180 sec)
INFO:tensorflow:global_step/sec: 555.019
INFO:tensorflow:loss = 0.7832278, step = 800 (0.180 sec)
INFO:tensorflow:global_step/sec: 545.374
INFO:tensorflow:loss = 0.7552776, step = 900 (0.183 sec)
INFO:tensorflow:global_step/sec: 532.68
INFO:tensorflow:loss = 0.7393273, step = 1000 (0.188 sec)
INFO:tensorflow:global_step/sec: 548.912
INFO:tensorflow:loss = 0.72828543, step = 1100 (0.182 sec)
INFO:tensorflow:global_step/sec: 540.638
INFO:tensorflow:loss = 0.7053199, step = 1200 (0.185 sec)
INFO:tensorflow:global_step/sec: 538.089
INFO:tensorflow:loss = 0.6969624, step = 1300 (0.186 sec)
INFO:tensorflow:global_step/sec: 543.99
INFO:tensorflow:loss = 0.6601064, step = 1400 (0.184 sec)
INFO:tensorflow:global_step/sec: 545.081
INFO:tensorflow:loss = 0.6538592, step = 1500 (0.183 sec)
INFO:tensorflow:global_step/sec: 583.538
INFO:tensorflow:loss = 0.6431242, step = 1600 (0.172 sec)
INFO:tensorflow:global_step/sec: 582.859
INFO:tensorflow:loss = 0.6337292, step = 1700 (0.171 sec)
INFO:tensorflow:global_step/sec: 581.883
INFO:tensorflow:loss = 0.6249995, step = 1800 (0.172 sec)
INFO:tensorflow:global_step/sec: 599.331
INFO:tensorflow:loss = 0.6074962, step = 1900 (0.167 sec)
INFO:tensorflow:global_step/sec: 585.038
INFO:tensorflow:loss = 0.5954494, step = 2000 (0.171 sec)
INFO:tensorflow:global_step/sec: 582.089
INFO:tensorflow:loss = 0.59253395, step = 2100 (0.172 sec)
INFO:tensorflow:global_step/sec: 602.62
INFO:tensorflow:loss = 0.5689405, step = 2200 (0.166 sec)
INFO:tensorflow:global_step/sec: 588.479
INFO:tensorflow:loss = 0.5602833, step = 2300 (0.170 sec)
INFO:tensorflow:global_step/sec: 588.211
INFO:tensorflow:loss = 0.55631864, step = 2400 (0.170 sec)
INFO:tensorflow:global_step/sec: 592.785
INFO:tensorflow:loss = 0.54900175, step = 2500 (0.169 sec)
INFO:tensorflow:global_step/sec: 580.758
INFO:tensorflow:loss = 0.54516006, step = 2600 (0.172 sec)
INFO:tensorflow:global_step/sec: 578.628
INFO:tensorflow:loss = 0.52530795, step = 2700 (0.173 sec)
INFO:tensorflow:global_step/sec: 587.328
INFO:tensorflow:loss = 0.5299491, step = 2800 (0.170 sec)
INFO:tensorflow:global_step/sec: 579.237
INFO:tensorflow:loss = 0.5184585, step = 2900 (0.173 sec)
INFO:tensorflow:global_step/sec: 579.646
INFO:tensorflow:loss = 0.5071718, step = 3000 (0.173 sec)
INFO:tensorflow:global_step/sec: 588.019
INFO:tensorflow:loss = 0.4960404, step = 3100 (0.170 sec)
INFO:tensorflow:global_step/sec: 577.387
INFO:tensorflow:loss = 0.47985545, step = 3200 (0.173 sec)
INFO:tensorflow:global_step/sec: 595.986
INFO:tensorflow:loss = 0.48654804, step = 3300 (0.168 sec)
INFO:tensorflow:global_step/sec: 573.446
INFO:tensorflow:loss = 0.48582077, step = 3400 (0.174 sec)
INFO:tensorflow:global_step/sec: 578.275
INFO:tensorflow:loss = 0.46541944, step = 3500 (0.173 sec)
INFO:tensorflow:global_step/sec: 566.381
INFO:tensorflow:loss = 0.4748811, step = 3600 (0.177 sec)
INFO:tensorflow:global_step/sec: 571.906
INFO:tensorflow:loss = 0.47083074, step = 3700 (0.175 sec)
INFO:tensorflow:global_step/sec: 574.657
INFO:tensorflow:loss = 0.44040596, step = 3800 (0.174 sec)
INFO:tensorflow:global_step/sec: 574.819
INFO:tensorflow:loss = 0.45463592, step = 3900 (0.174 sec)
INFO:tensorflow:global_step/sec: 588.25
INFO:tensorflow:loss = 0.44358343, step = 4000 (0.170 sec)
INFO:tensorflow:global_step/sec: 585.435
INFO:tensorflow:loss = 0.4394082, step = 4100 (0.171 sec)
INFO:tensorflow:global_step/sec: 570.47
INFO:tensorflow:loss = 0.44057947, step = 4200 (0.175 sec)
INFO:tensorflow:global_step/sec: 570.751
INFO:tensorflow:loss = 0.43266475, step = 4300 (0.175 sec)
INFO:tensorflow:global_step/sec: 573.692
INFO:tensorflow:loss = 0.4180813, step = 4400 (0.174 sec)
INFO:tensorflow:global_step/sec: 557.131
INFO:tensorflow:loss = 0.42264342, step = 4500 (0.180 sec)
INFO:tensorflow:global_step/sec: 573.329
INFO:tensorflow:loss = 0.42919323, step = 4600 (0.174 sec)
INFO:tensorflow:global_step/sec: 588.864
INFO:tensorflow:loss = 0.41696268, step = 4700 (0.170 sec)
INFO:tensorflow:global_step/sec: 579.002
INFO:tensorflow:loss = 0.40943825, step = 4800 (0.173 sec)
INFO:tensorflow:global_step/sec: 566.232
INFO:tensorflow:loss = 0.410751, step = 4900 (0.177 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Saving checkpoints for 5000 into /tmpfs/tmp/tmpv8gz9mg7/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Loss for final step: 0.4027862.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7f003e6a3dc0>

Estimator が期待するとおり、引数を取らない入力関数を指定しながら、input_fn 呼び出しを lambda にラッピングして引数をキャプチャするところに注意してください。steps 引数はメソッドに対して、あるトレーニングステップ数を完了した後にトレーニングを停止するように指定しています。

トレーニングされたモデルを評価する

モデルのトレーニングが完了したので、そのパフォーマンスに関する統計を得ることができます。次のコードブロックは、テストデータに対してトレーニングされたモデルの精度を評価します。

eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-15T02:40:50
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpv8gz9mg7/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.52079s
INFO:tensorflow:Finished evaluation at 2022-12-15-02:40:50
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.9, average_loss = 0.4802318, global_step = 5000, loss = 0.4802318
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpv8gz9mg7/model.ckpt-5000

Test set accuracy: 0.900

train メソッドへの呼び出しとは異なり、評価するsteps 引数を渡していません。eval の input_fn データの単一のエポックのみを返します。

eval_result ディクショナリには、average_loss(サンプル当たりの平均損失)、loss(ミニバッチ当たりの平均損失)、および Estimator の global_step の値(実行したトレーニングイテレーションの回数)も含まれます。

トレーニングされたモデルから予測(推論)を立てる

良質の評価結果を生み出すトレーニング済みのモデルを準備できました。これから、このトレーニング済みのモデルを使用し、ラベル付けできない測定に基づいてアヤメの品種を予測します。トレーニングと評価と同様に、単一の関数呼び出して予測を行います。

# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

def input_fn(features, batch_size=256):
    """An input function for prediction."""
    # Convert the inputs to a Dataset without labels.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

predictions = classifier.predict(
    input_fn=lambda: input_fn(predict_x))

predict メソッドは Python イテラブルを返し、各サンプルの予測結果のディクショナリを生成します。次のコードを使って、予測とその確率を出力します。

for pred_dict, expec in zip(predictions, expected):
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpv8gz9mg7/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Prediction is "Setosa" (85.5%), expected "Setosa"
Prediction is "Versicolor" (47.2%), expected "Versicolor"
Prediction is "Virginica" (60.3%), expected "Virginica"