
Migration examples: Canned Estimators


Canned (or premade) Estimators have traditionally been used in TensorFlow 1 as quick and easy ways to train models for a variety of typical use cases. TensorFlow 2 provides straightforward approximate substitutes for a number of them by way of Keras models. For those canned estimators that do not have built-in TensorFlow 2 substitutes, you can still build your own replacement fairly easily.

This guide includes a few examples of direct equivalents and custom substitutions that demonstrate how TensorFlow 1's tf.estimator-derived models can be migrated to TF2 with Keras.

Namely, this guide includes examples of migrating from LinearEstimator, DNNEstimator, DNNLinearCombinedEstimator, and BoostedTreesEstimator.

A common precursor to training a model is feature preprocessing, which is done for TensorFlow 1 Estimator models with tf.feature_column. For more information on feature preprocessing in TensorFlow 2, see this guide on migrating feature columns.
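As a minimal illustration of the TF2 approach (a sketch using the standard tf.keras.layers.Normalization layer on made-up data, not tied to the Titanic setup below), a numeric feature can be standardized with a Keras preprocessing layer instead of a feature column:

```python
import numpy as np
import tensorflow as tf

# Made-up toy values standing in for a numeric feature such as 'age'.
ages = np.array([[22.0], [38.0], [26.0], [35.0]], dtype=np.float32)

# A Keras preprocessing layer replaces tf.feature_column.numeric_column
# plus manual scaling: adapt() computes the mean/variance once up front.
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(ages)

# The layer can then be used standalone or as the first layer of a model.
normalized = normalizer(ages)
print(normalized.shape)  # (4, 1)
```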

Setup

Start with a couple of necessary TensorFlow imports.

pip install tensorflow_decision_forests
import keras
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_decision_forests as tfdf
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible.
WARNING:root:TF Parameter Server distributed training not available.

Prepare some simple data for demonstration from the standard Titanic dataset:

x_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
x_train['sex'].replace(('male', 'female'), (0, 1), inplace=True)
x_eval['sex'].replace(('male', 'female'), (0, 1), inplace=True)

x_train['alone'].replace(('n', 'y'), (0, 1), inplace=True)
x_eval['alone'].replace(('n', 'y'), (0, 1), inplace=True)

x_train['class'].replace(('First', 'Second', 'Third'), (1, 2, 3), inplace=True)
x_eval['class'].replace(('First', 'Second', 'Third'), (1, 2, 3), inplace=True)

x_train.drop(['embark_town', 'deck'], axis=1, inplace=True)
x_eval.drop(['embark_town', 'deck'], axis=1, inplace=True)

y_train = x_train.pop('survived')
y_eval = x_eval.pop('survived')
# Data setup for TensorFlow 1 with `tf.estimator`
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32)


def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32)


FEATURE_NAMES = [
    'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone'
]

feature_columns = []
for fn in FEATURE_NAMES:
  feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32)
  feature_columns.append(feat_col)

Create a method that instantiates a simple sample optimizer to use with the various TensorFlow 1 Estimator and TensorFlow 2 Keras models below.

def create_sample_optimizer(tf_version):
  if tf_version == 'tf1':
    optimizer = lambda: tf.keras.optimizers.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf1.train.exponential_decay(
            learning_rate=0.1,
            global_step=tf1.train.get_global_step(),
            decay_steps=10000,
            decay_rate=0.9))
  elif tf_version == 'tf2':
    optimizer = tf.keras.optimizers.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9))
  return optimizer

Example 1: Migrating from LinearEstimator

TF1: Using LinearEstimator

In TensorFlow 1, you can use tf.estimator.LinearEstimator to create a baseline linear model for regression and classification problems.

linear_estimator = tf.estimator.LinearEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    optimizer=create_sample_optimizer('tf1'))
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp06pccumj
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp06pccumj', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
linear_estimator.train(input_fn=_input_fn, steps=100)
linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:401: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1478: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  getter=tf.compat.v1.get_variable)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:149: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp06pccumj/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...
INFO:tensorflow:Saving checkpoints for 20 into /tmp/tmp06pccumj/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...
INFO:tensorflow:Loss for final step: 0.55268794.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-11-09T02:24:46
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp06pccumj/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Inference Time : 0.58634s
INFO:tensorflow:Finished evaluation at 2021-11-09-02:24:47
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70075756, accuracy_baseline = 0.625, auc = 0.75472915, auc_precision_recall = 0.65362054, average_loss = 0.5759378, global_step = 20, label/mean = 0.375, loss = 0.5704812, precision = 0.6388889, prediction/mean = 0.41331062, recall = 0.46464646
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmp/tmp06pccumj/model.ckpt-20
{'accuracy': 0.70075756,
 'accuracy_baseline': 0.625,
 'auc': 0.75472915,
 'auc_precision_recall': 0.65362054,
 'average_loss': 0.5759378,
 'label/mean': 0.375,
 'loss': 0.5704812,
 'precision': 0.6388889,
 'prediction/mean': 0.41331062,
 'recall': 0.46464646,
 'global_step': 20}

TF2: Using Keras LinearModel

In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.LinearModel, which is the substitute for tf.estimator.LinearEstimator. The tf.compat.v1.keras path is used to signify that the pre-made model exists for compatibility.

linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10)
linear_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10
20/20 [==============================] - 0s 2ms/step - loss: 3.3817 - accuracy: 0.6252
Epoch 2/10
20/20 [==============================] - 0s 2ms/step - loss: 0.4947 - accuracy: 0.6571
Epoch 3/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2136 - accuracy: 0.6794
Epoch 4/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2121 - accuracy: 0.7002
Epoch 5/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2043 - accuracy: 0.7129
Epoch 6/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2099 - accuracy: 0.7592
Epoch 7/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1638 - accuracy: 0.8006
Epoch 8/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1764 - accuracy: 0.7943
Epoch 9/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1759 - accuracy: 0.7783
Epoch 10/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1635 - accuracy: 0.7974
9/9 [==============================] - 0s 2ms/step - loss: 0.1814 - accuracy: 0.7386
{'loss': 0.1814115047454834, 'accuracy': 0.7386363744735718}
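Once trained, a Keras model also serves predictions directly through predict(). A standalone sketch on made-up data, using a single Dense unit (mathematically the same linear model) so the snippet runs on its own:

```python
import numpy as np
import tensorflow as tf

# Made-up data with the same shape as the Titanic setup: 7 features.
x = np.random.rand(32, 7).astype(np.float32)
y = np.random.randint(0, 2, size=(32, 1)).astype(np.float32)

# A single Dense unit is an equivalent linear model.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(7,))])
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, verbose=0)

# predict() returns one score per input row.
preds = model.predict(x[:3], verbose=0)
print(preds.shape)  # (3, 1)
```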

Example 2: Migrating from DNNEstimator

TF1: Using DNNEstimator

In TensorFlow 1, you can use tf.estimator.DNNEstimator to create a baseline deep neural network (DNN) model for regression and classification problems.

dnn_estimator = tf.estimator.DNNEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    hidden_units=[128],
    activation_fn=tf.nn.relu,
    optimizer=create_sample_optimizer('tf1'))
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpqp1qe1uu
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpqp1qe1uu', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
dnn_estimator.train(input_fn=_input_fn, steps=100)
dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpqp1qe1uu/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.8311484, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...
INFO:tensorflow:Saving checkpoints for 20 into /tmp/tmpqp1qe1uu/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...
INFO:tensorflow:Loss for final step: 0.58950394.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-11-09T02:24:50
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpqp1qe1uu/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Inference Time : 0.50217s
INFO:tensorflow:Finished evaluation at 2021-11-09-02:24:51
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70454544, accuracy_baseline = 0.625, auc = 0.69406193, auc_precision_recall = 0.60405815, average_loss = 0.6038741, global_step = 20, label/mean = 0.375, loss = 0.59827024, precision = 0.6363636, prediction/mean = 0.40024805, recall = 0.4949495
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmp/tmpqp1qe1uu/model.ckpt-20
{'accuracy': 0.70454544,
 'accuracy_baseline': 0.625,
 'auc': 0.69406193,
 'auc_precision_recall': 0.60405815,
 'average_loss': 0.6038741,
 'label/mean': 0.375,
 'loss': 0.59827024,
 'precision': 0.6363636,
 'prediction/mean': 0.40024805,
 'recall': 0.4949495,
 'global_step': 20}

TF2: Creating a custom DNN model using Keras

In TensorFlow 2, you can create a custom DNN model to substitute for one generated by tf.estimator.DNNEstimator, with a similar level of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).

A similar workflow can be used to replace tf.estimator.experimental.RNNEstimator with a Keras recurrent neural network (RNN) model. Keras provides a number of built-in, customizable choices by way of tf.keras.layers.RNN, tf.keras.layers.LSTM, and tf.keras.layers.GRU. To learn more, refer to the Keras RNN layers documentation.
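For instance, a minimal Keras RNN classifier might look like the following (a sketch with made-up sequence shapes, not tied to the Titanic data):

```python
import tensorflow as tf

# A small sequence classifier standing in for RNNEstimator:
# inputs are batches of 20-step sequences with 8 features per step.
rnn_model = tf.keras.Sequential([
    tf.keras.layers.LSTM(16, input_shape=(20, 8)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
rnn_model.compile(loss='binary_crossentropy', optimizer='adam')
print(rnn_model.output_shape)  # (None, 1)
```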

dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])

dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
dnn_model.fit(x_train, y_train, epochs=10)
dnn_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10
20/20 [==============================] - 0s 2ms/step - loss: 1654.4760 - accuracy: 0.5821
Epoch 2/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2467 - accuracy: 0.6683
Epoch 3/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2262 - accuracy: 0.6730
Epoch 4/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1946 - accuracy: 0.7193
Epoch 5/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1925 - accuracy: 0.7544
Epoch 6/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1692 - accuracy: 0.7671
Epoch 7/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1607 - accuracy: 0.7927
Epoch 8/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1571 - accuracy: 0.7927
Epoch 9/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1604 - accuracy: 0.7895
Epoch 10/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1521 - accuracy: 0.7974
9/9 [==============================] - 0s 2ms/step - loss: 0.1851 - accuracy: 0.7197
{'loss': 0.18513108789920807, 'accuracy': 0.7196969985961914}

Example 3: Migrating from DNNLinearCombinedEstimator

TF1: Using DNNLinearCombinedEstimator

In TensorFlow 1, you can use tf.estimator.DNNLinearCombinedEstimator to create a baseline combined model for regression and classification problems, with customization capacity for both its linear and DNN components.

optimizer = create_sample_optimizer('tf1')

combined_estimator = tf.estimator.DNNLinearCombinedEstimator(
    head=tf.estimator.BinaryClassHead(),
    # Wide settings
    linear_feature_columns=feature_columns,
    linear_optimizer=optimizer,
    # Deep settings
    dnn_feature_columns=feature_columns,
    dnn_hidden_units=[128],
    dnn_optimizer=optimizer)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp89s7nmxt
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp89s7nmxt', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
combined_estimator.train(input_fn=_input_fn, steps=100)
combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1478: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  getter=tf.compat.v1.get_variable)
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp89s7nmxt/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.2762358, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...
INFO:tensorflow:Saving checkpoints for 20 into /tmp/tmp89s7nmxt/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...
INFO:tensorflow:Loss for final step: 0.54353213.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-11-09T02:24:55
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp89s7nmxt/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Inference Time : 0.63745s
INFO:tensorflow:Finished evaluation at 2021-11-09-02:24:56
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.71590906, accuracy_baseline = 0.625, auc = 0.75469846, auc_precision_recall = 0.6428577, average_loss = 0.5944004, global_step = 20, label/mean = 0.375, loss = 0.582566, precision = 0.65789473, prediction/mean = 0.40653682, recall = 0.5050505
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmp/tmp89s7nmxt/model.ckpt-20
{'accuracy': 0.71590906,
 'accuracy_baseline': 0.625,
 'auc': 0.75469846,
 'auc_precision_recall': 0.6428577,
 'average_loss': 0.5944004,
 'label/mean': 0.375,
 'loss': 0.582566,
 'precision': 0.65789473,
 'prediction/mean': 0.40653682,
 'recall': 0.5050505,
 'global_step': 20}

TF2: Using Keras WideDeepModel

In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.WideDeepModel to substitute for one generated by tf.estimator.DNNLinearCombinedEstimator, with a similar level of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).

This WideDeepModel is constructed on the basis of a constituent LinearModel and a custom DNN model, both of which were discussed in the preceding two examples. A custom linear model can also be used in place of the built-in Keras LinearModel if desired.

If you would like to build your own model instead of using a canned estimator, check out how to build a keras.Sequential model. For more information on custom training and optimizers, you can also check out this guide.
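As a rough sketch of what a custom training step looks like (the standard tf.GradientTape pattern, with toy data made up for illustration):

```python
import numpy as np
import tensorflow as tf

# A tiny model and made-up regression data for one manual training step.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

x = np.random.rand(16, 4).astype(np.float32)
y = np.random.rand(16, 1).astype(np.float32)

# One optimization step: forward pass, loss, gradients, parameter update.
with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```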

# Create LinearModel and DNN Model as in Examples 1 and 2
optimizer = create_sample_optimizer('tf2')

linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10, verbose=0)

dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
combined_model = tf.compat.v1.keras.experimental.WideDeepModel(linear_model,
                                                               dnn_model)
combined_model.compile(
    optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy'])
combined_model.fit([x_train, x_train], y_train, epochs=10)
combined_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10
20/20 [==============================] - 0s 3ms/step - loss: 838.9148 - accuracy: 0.6667
Epoch 2/10
20/20 [==============================] - 0s 3ms/step - loss: 0.2411 - accuracy: 0.7544
Epoch 3/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1947 - accuracy: 0.7911
Epoch 4/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1755 - accuracy: 0.7974
Epoch 5/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1742 - accuracy: 0.7767
Epoch 6/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1595 - accuracy: 0.8070
Epoch 7/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1558 - accuracy: 0.8086
Epoch 8/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1581 - accuracy: 0.8022
Epoch 9/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1509 - accuracy: 0.8006
Epoch 10/10
20/20 [==============================] - 0s 3ms/step - loss: 0.1527 - accuracy: 0.8070
9/9 [==============================] - 0s 2ms/step - loss: 0.1783 - accuracy: 0.7538
{'loss': 0.17828215658664703, 'accuracy': 0.7537878751754761}

Example 4: Migrating from BoostedTreesEstimator

TF1: Using BoostedTreesEstimator

In TensorFlow 1, you could use tf.estimator.BoostedTreesEstimator to create a baseline gradient boosting model using an ensemble of decision trees for regression and classification problems.

bt_estimator = tf1.estimator.BoostedTreesEstimator(
    head=tf.estimator.BinaryClassHead(),
    n_batches_per_layer=1,
    max_depth=10,
    n_trees=1000,
    feature_columns=feature_columns)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp7uv3o9pn
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp7uv3o9pn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
bt_estimator.train(input_fn=_input_fn, steps=1000)
bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp7uv3o9pn/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp7uv3o9pn/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 19...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 19...
INFO:tensorflow:Saving checkpoints for 19 into /tmp/tmp7uv3o9pn/model.ckpt.
INFO:tensorflow:Saving checkpoints for 19 into /tmp/tmp7uv3o9pn/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 19...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 19...
INFO:tensorflow:Loss for final step: 0.3596191.
INFO:tensorflow:Loss for final step: 0.3596191.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-11-09T02:25:01
INFO:tensorflow:Starting evaluation at 2021-11-09T02:25:01
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp7uv3o9pn/model.ckpt-19
INFO:tensorflow:Restoring parameters from /tmp/tmp7uv3o9pn/model.ckpt-19
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.54183s
INFO:tensorflow:Inference Time : 0.54183s
INFO:tensorflow:Finished evaluation at 2021-11-09-02:25:01
INFO:tensorflow:Finished evaluation at 2021-11-09-02:25:01
INFO:tensorflow:Saving dict for global step 19: accuracy = 0.77272725, accuracy_baseline = 0.625, auc = 0.80756044, auc_precision_recall = 0.7711308, average_loss = 0.5074444, global_step = 19, label/mean = 0.375, loss = 0.49429518, precision = 0.71910113, prediction/mean = 0.4125744, recall = 0.64646465
INFO:tensorflow:Saving dict for global step 19: accuracy = 0.77272725, accuracy_baseline = 0.625, auc = 0.80756044, auc_precision_recall = 0.7711308, average_loss = 0.5074444, global_step = 19, label/mean = 0.375, loss = 0.49429518, precision = 0.71910113, prediction/mean = 0.4125744, recall = 0.64646465
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 19: /tmp/tmp7uv3o9pn/model.ckpt-19
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 19: /tmp/tmp7uv3o9pn/model.ckpt-19
{'accuracy': 0.77272725,
 'accuracy_baseline': 0.625,
 'auc': 0.80756044,
 'auc_precision_recall': 0.7711308,
 'average_loss': 0.5074444,
 'label/mean': 0.375,
 'loss': 0.49429518,
 'precision': 0.71910113,
 'prediction/mean': 0.4125744,
 'recall': 0.64646465,
 'global_step': 19}

TF2: Using TensorFlow Decision Forests

In TensorFlow 2, the closest pre-packaged substitute for a model generated with tf.estimator.BoostedTreesEstimator is tfdf.keras.GradientBoostedTreesModel, which sequentially trains a series of shallow decision trees, each designed to "learn" from the errors made by its predecessors in the series.
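The core idea — each new tree fits the residual errors left by the ensemble built so far — can be sketched in a few lines of plain Python. This is a toy illustration using one-feature regression stumps, not how TF-DF actually implements gradient boosted trees:

```python
def fit_stump(xs, residuals):
    """Return the depth-1 regressor (single threshold split) that best fits the residuals."""
    best = None
    for t in sorted(set(xs))[:-1]:  # candidate thresholds; both sides stay non-empty
        lmean = sum(r for x, r in zip(xs, residuals) if x <= t) / sum(x <= t for x in xs)
        rmean = sum(r for x, r in zip(xs, residuals) if x > t) / sum(x > t for x in xs)
        err = sum((r - (lmean if x <= t else rmean)) ** 2
                  for x, r in zip(xs, residuals))
        if best is None or err < best[0]:
            best = (err, t, lmean, rmean)
    _, t, lmean, rmean = best
    return lambda x: lmean if x <= t else rmean

def gradient_boost(xs, ys, num_trees=30, shrinkage=0.3):
    """Sequentially fit stumps to the current residuals (squared-error boosting)."""
    pred = [0.0] * len(xs)
    for _ in range(num_trees):
        residuals = [y - p for y, p in zip(ys, pred)]
        stump = fit_stump(xs, residuals)
        # Add a shrunken copy of the new stump to the ensemble's prediction.
        pred = [p + shrinkage * stump(x) for x, p in zip(xs, pred)]
    return pred

xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.2, 0.9, 1.1, 3.0, 3.2, 2.9, 3.1]
pred = gradient_boost(xs, ys)
mse = sum((y - p) ** 2 for y, p in zip(ys, pred)) / len(ys)
print(mse)  # the ensemble's squared error shrinks as stumps are added
```

Each iteration reduces the remaining squared error, which is why the real learner pairs a small shrinkage (learning rate) with many trees and validation-based early stopping.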

GradientBoostedTreesModel provides many options for customization, allowing you to specify everything from basic depth constraints to early stopping conditions. Refer to the API documentation for more details about GradientBoostedTreesModel attributes.

gbt_model = tfdf.keras.GradientBoostedTreesModel(
    task=tfdf.keras.Task.CLASSIFICATION)
gbt_model.compile(metrics=['mse', 'accuracy'])
train_df, eval_df = x_train.copy(), x_eval.copy()
train_df['survived'], eval_df['survived'] = y_train, y_eval

train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label='survived')
eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_df, label='survived')

gbt_model.fit(train_dataset)
gbt_model.evaluate(eval_dataset, return_dict=True)
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  features_dataframe = dataframe.drop(label, 1)
10/10 [==============================] - 4s 15ms/step
[INFO kernel.cc:736] Start Yggdrasil model training
[INFO kernel.cc:737] Collect training examples
[INFO kernel.cc:392] Number of batches: 10
[INFO kernel.cc:393] Number of examples: 627
[INFO kernel.cc:759] Dataset:
Number of records: 627
Number of columns: 8

Number of columns by type:
    NUMERICAL: 7 (87.5%)
    CATEGORICAL: 1 (12.5%)

Columns:

NUMERICAL: 7 (87.5%)
    0: "age" NUMERICAL mean:29.6313 min:0.75 max:80 sd:12.5018
    1: "alone" NUMERICAL mean:0.593301 min:0 max:1 sd:0.491218
    2: "class" NUMERICAL mean:2.29027 min:1 max:3 sd:0.844506
    3: "fare" NUMERICAL mean:34.3854 min:0 max:512.329 sd:54.5542
    4: "n_siblings_spouses" NUMERICAL mean:0.545455 min:0 max:8 sd:1.15017
    5: "parch" NUMERICAL mean:0.379585 min:0 max:5 sd:0.792367
    6: "sex" NUMERICAL mean:0.346093 min:0 max:1 sd:0.475723

CATEGORICAL: 1 (12.5%)
    7: "__LABEL" CATEGORICAL integerized vocab-size:3 no-ood-item

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO kernel.cc:762] Configure learner
[WARNING gradient_boosted_trees.cc:1643] Subsample hyperparameter given but sampling method does not match.
[WARNING gradient_boosted_trees.cc:1656] GOSS alpha hyperparameter given but GOSS is disabled.
[WARNING gradient_boosted_trees.cc:1665] GOSS beta hyperparameter given but GOSS is disabled.
[WARNING gradient_boosted_trees.cc:1677] SelGB ratio hyperparameter given but SelGB is disabled.
[INFO kernel.cc:787] Training config:
learner: "GRADIENT_BOOSTED_TREES"
features: "age"
features: "alone"
features: "class"
features: "fare"
features: "n_siblings_spouses"
features: "parch"
features: "sex"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
  num_trees: 300
  decision_tree {
    max_depth: 6
    min_examples: 5
    in_split_min_examples_check: true
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    num_candidate_attributes_ratio: -1
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
  }
  shrinkage: 0.1
  validation_set_ratio: 0.1
  early_stopping: VALIDATION_LOSS_INCREASE
  early_stopping_num_trees_look_ahead: 30
  l2_regularization: 0
  lambda_loss: 1
  mart {
  }
  adapt_subsample_for_maximum_training_duration: false
  l1_regularization: 0
  use_hessian_gain: false
  l2_regularization_categorical: 1
  apply_link_function: true
  compute_permutation_variable_importance: false
}

[INFO kernel.cc:790] Deployment config:
num_threads: 6

[INFO kernel.cc:817] Train model
[INFO gradient_boosted_trees.cc:404] Default loss set to BINOMIAL_LOG_LIKELIHOOD
[INFO gradient_boosted_trees.cc:1001] Training gradient boosted tree on 627 example(s) and 7 feature(s).
[INFO gradient_boosted_trees.cc:1044] 569 examples used for training and 58 examples used for validation
[INFO gradient_boosted_trees.cc:1426]     num-trees:1 train-loss:1.234815 train-accuracy:0.615114 valid-loss:1.274158 valid-accuracy:0.586207
[INFO gradient_boosted_trees.cc:1428]     num-trees:2 train-loss:1.149731 train-accuracy:0.659051 valid-loss:1.218309 valid-accuracy:0.655172
[INFO gradient_boosted_trees.cc:2740] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.949917
[INFO gradient_boosted_trees.cc:229] Truncates the model to 15 tree(s) i.e. 15  iteration(s).
[INFO gradient_boosted_trees.cc:263] Final model num-trees:15 valid-loss:0.949917 valid-accuracy:0.810345
[INFO kernel.cc:828] Export model in log directory: /tmp/tmpocaf8kqm
[INFO kernel.cc:836] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO abstract_model.cc:993] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO kernel.cc:848] Use fast generic engine
5/5 [==============================] - 0s 3ms/step - loss: 0.0000e+00 - mse: 0.1319 - accuracy: 0.8144
{'loss': 0.0, 'mse': 0.13190586864948273, 'accuracy': 0.814393937587738}

TensorFlow 2 also has another TF-DF alternative to a model generated with tf.estimator.BoostedTreesEstimator: tfdf.keras.RandomForestModel. A RandomForestModel creates a robust, overfitting-resistant learner consisting of a voting population of deep decision trees, each trained on a random subset of the input training dataset.

RandomForestModel and GradientBoostedTreesModel provide similarly extensive levels of customization. Choosing between them is problem-specific and depends on your task or application.

Check the API documentation for more details about RandomForestModel and GradientBoostedTreesModel attributes.
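The bagging-and-voting idea behind a random forest can likewise be sketched in plain Python. This is a toy illustration with one-feature classification stumps trained on bootstrap resamples, not TF-DF's actual implementation:

```python
import random

def fit_stump(sample):
    """Return the threshold classifier that best separates the labels in `sample`."""
    best = None
    for t, _ in sample:
        left = [y for x, y in sample if x <= t]
        right = [y for x, y in sample if x > t]
        if not left or not right:
            continue
        lmaj = max(set(left), key=left.count)    # majority label on each side
        rmaj = max(set(right), key=right.count)
        acc = (left.count(lmaj) + right.count(rmaj)) / len(sample)
        if best is None or acc > best[0]:
            best = (acc, t, lmaj, rmaj)
    if best is None:  # degenerate sample (a single x value): predict a constant
        labels = [y for _, y in sample]
        maj = max(set(labels), key=labels.count)
        return lambda x: maj
    _, t, lmaj, rmaj = best
    return lambda x: lmaj if x <= t else rmaj

def random_forest(data, num_trees=25, seed=0):
    """Train each tree on a bootstrap resample; predict by majority vote."""
    rng = random.Random(seed)
    trees = [fit_stump([rng.choice(data) for _ in data]) for _ in range(num_trees)]
    def predict(x):
        votes = [tree(x) for tree in trees]
        return max(set(votes), key=votes.count)
    return predict

data = [(1, 0), (2, 0), (3, 0), (4, 0), (6, 1), (7, 1), (8, 1), (9, 1)]
predict = random_forest(data)
accuracy = sum(predict(x) == y for x, y in data) / len(data)
```

Because each tree only sees a random resample, individual trees disagree near the class boundary, and the vote averages out those errors — the same reason the real model resists overfitting.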

rf_model = tfdf.keras.RandomForestModel(
    task=tfdf.keras.Task.CLASSIFICATION)
rf_model.compile(metrics=['mse', 'accuracy'])
rf_model.fit(train_dataset)
rf_model.evaluate(eval_dataset, return_dict=True)
10/10 [==============================] - 0s 22ms/step
[INFO kernel.cc:736] Start Yggdrasil model training
[INFO kernel.cc:737] Collect training examples
[INFO kernel.cc:392] Number of batches: 10
[INFO kernel.cc:393] Number of examples: 627
[INFO kernel.cc:759] Dataset:
Number of records: 627
Number of columns: 8

Number of columns by type:
    NUMERICAL: 7 (87.5%)
    CATEGORICAL: 1 (12.5%)

Columns:

NUMERICAL: 7 (87.5%)
    0: "age" NUMERICAL mean:29.6313 min:0.75 max:80 sd:12.5018
    1: "alone" NUMERICAL mean:0.593301 min:0 max:1 sd:0.491218
    2: "class" NUMERICAL mean:2.29027 min:1 max:3 sd:0.844506
    3: "fare" NUMERICAL mean:34.3854 min:0 max:512.329 sd:54.5542
    4: "n_siblings_spouses" NUMERICAL mean:0.545455 min:0 max:8 sd:1.15017
    5: "parch" NUMERICAL mean:0.379585 min:0 max:5 sd:0.792367
    6: "sex" NUMERICAL mean:0.346093 min:0 max:1 sd:0.475723

CATEGORICAL: 1 (12.5%)
    7: "__LABEL" CATEGORICAL integerized vocab-size:3 no-ood-item

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO kernel.cc:762] Configure learner
[INFO kernel.cc:787] Training config:
learner: "RANDOM_FOREST"
features: "age"
features: "alone"
features: "class"
features: "fare"
features: "n_siblings_spouses"
features: "parch"
features: "sex"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    num_candidate_attributes_ratio: -1
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: false
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
}

[INFO kernel.cc:790] Deployment config:
num_threads: 6

[INFO kernel.cc:817] Train model
[INFO random_forest.cc:315] Training random forest on 627 example(s) and 7 feature(s).
[INFO random_forest.cc:628] Training of tree  1/300 (tree index:2) done accuracy:0.730435 logloss:9.71611
[INFO random_forest.cc:628] Training of tree  11/300 (tree index:10) done accuracy:0.8016 logloss:3.93822
[INFO random_forest.cc:628] Training of tree  21/300 (tree index:18) done accuracy:0.786284 logloss:2.61817
[INFO random_forest.cc:628] Training of tree  31/300 (tree index:29) done accuracy:0.797448 logloss:2.29544
[INFO random_forest.cc:628] Training of tree  41/300 (tree index:39) done accuracy:0.795853 logloss:2.24153
[INFO random_forest.cc:628] Training of tree  51/300 (tree index:51) done accuracy:0.803828 logloss:2.13208
[INFO random_forest.cc:628] Training of tree  61/300 (tree index:62) done accuracy:0.807018 logloss:1.97651
[INFO random_forest.cc:628] Training of tree  71/300 (tree index:69) done accuracy:0.799043 logloss:1.76892
[INFO random_forest.cc:628] Training of tree  81/300 (tree index:79) done accuracy:0.800638 logloss:1.40511
[INFO random_forest.cc:628] Training of tree  91/300 (tree index:90) done accuracy:0.802233 logloss:1.29943
[INFO random_forest.cc:628] Training of tree  101/300 (tree index:100) done accuracy:0.795853 logloss:1.3013
[INFO random_forest.cc:628] Training of tree  111/300 (tree index:108) done accuracy:0.799043 logloss:1.29775
[INFO random_forest.cc:628] Training of tree  121/300 (tree index:119) done accuracy:0.795853 logloss:1.29933
[INFO random_forest.cc:628] Training of tree  131/300 (tree index:132) done accuracy:0.797448 logloss:1.2497
[INFO random_forest.cc:628] Training of tree  141/300 (tree index:141) done accuracy:0.800638 logloss:1.25267
[INFO random_forest.cc:628] Training of tree  151/300 (tree index:149) done accuracy:0.800638 logloss:1.25276
[INFO random_forest.cc:628] Training of tree  161/300 (tree index:162) done accuracy:0.800638 logloss:1.20288
[INFO random_forest.cc:628] Training of tree  171/300 (tree index:171) done accuracy:0.802233 logloss:1.15226
[INFO random_forest.cc:628] Training of tree  181/300 (tree index:182) done accuracy:0.803828 logloss:1.15123
[INFO random_forest.cc:628] Training of tree  192/300 (tree index:191) done accuracy:0.807018 logloss:1.14935
[INFO random_forest.cc:628] Training of tree  202/300 (tree index:200) done accuracy:0.803828 logloss:1.14886
[INFO random_forest.cc:628] Training of tree  212/300 (tree index:213) done accuracy:0.807018 logloss:1.14736
[INFO random_forest.cc:628] Training of tree  222/300 (tree index:222) done accuracy:0.807018 logloss:1.0956
[INFO random_forest.cc:628] Training of tree  232/300 (tree index:230) done accuracy:0.803828 logloss:0.995656
[INFO random_forest.cc:628] Training of tree  242/300 (tree index:240) done accuracy:0.805423 logloss:0.996381
[INFO random_forest.cc:628] Training of tree  252/300 (tree index:253) done accuracy:0.807018 logloss:0.994959
[INFO random_forest.cc:628] Training of tree  262/300 (tree index:263) done accuracy:0.810207 logloss:0.994315
[INFO random_forest.cc:628] Training of tree  272/300 (tree index:269) done accuracy:0.808612 logloss:0.995073
[INFO random_forest.cc:628] Training of tree  282/300 (tree index:280) done accuracy:0.808612 logloss:0.943966
[INFO random_forest.cc:628] Training of tree  292/300 (tree index:293) done accuracy:0.807018 logloss:0.943486
[INFO random_forest.cc:628] Training of tree  300/300 (tree index:298) done accuracy:0.805423 logloss:0.944073
[INFO random_forest.cc:696] Final OOB metrics: accuracy:0.805423 logloss:0.944073
[INFO kernel.cc:828] Export model in log directory: /tmp/tmp8vy0bslr
[INFO kernel.cc:836] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 300 root(s), 34374 node(s), and 7 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
5/5 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - mse: 0.1270 - accuracy: 0.8636
{'loss': 0.0, 'mse': 0.12698587775230408, 'accuracy': 0.8636363744735718}