Migration examples: Canned Estimators

Canned (or Premade) Estimators have traditionally been used in TensorFlow 1 as quick and easy ways to train models for a variety of typical use cases. TensorFlow 2 provides straightforward approximate substitutes for a number of them by way of Keras models. For those canned estimators that do not have built-in TensorFlow 2 substitutes, you can still build your own replacement fairly easily.

This guide will walk you through a few examples of direct equivalents and custom substitutions to demonstrate how TensorFlow 1's tf.estimator-derived models can be migrated to TensorFlow 2 with Keras.

Namely, this guide includes examples for migrating:

  • From tf.estimator's LinearEstimator, Classifier or Regressor in TensorFlow 1 to Keras tf.compat.v1.keras.models.LinearModel in TensorFlow 2
  • From tf.estimator's DNNEstimator, Classifier or Regressor in TensorFlow 1 to a custom Keras DNN model in TensorFlow 2
  • From tf.estimator's DNNLinearCombinedEstimator, Classifier or Regressor in TensorFlow 1 to tf.compat.v1.keras.models.WideDeepModel in TensorFlow 2
  • From tf.estimator's BoostedTreesEstimator, Classifier or Regressor in TensorFlow 1 to tfdf.keras.GradientBoostedTreesModel in TensorFlow 2

A common precursor to the training of a model is feature preprocessing, which is done for TensorFlow 1 Estimator models with tf.feature_column. For more information on feature preprocessing in TensorFlow 2, see this guide on migrating from feature columns to the Keras preprocessing layers API.

Setup

Start with a couple of necessary TensorFlow imports,

pip install tensorflow_decision_forests
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_decision_forests as tfdf
from tensorflow import keras

prepare some simple data for demonstration from the standard Titanic dataset,

x_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
# Assign each result back to its column. Calling replace(..., inplace=True) on
# a column selection may not update the DataFrame under pandas copy-on-write.
x_train['sex'] = x_train['sex'].replace(('male', 'female'), (0, 1))
x_eval['sex'] = x_eval['sex'].replace(('male', 'female'), (0, 1))

x_train['alone'] = x_train['alone'].replace(('n', 'y'), (0, 1))
x_eval['alone'] = x_eval['alone'].replace(('n', 'y'), (0, 1))

x_train['class'] = x_train['class'].replace(('First', 'Second', 'Third'), (1, 2, 3))
x_eval['class'] = x_eval['class'].replace(('First', 'Second', 'Third'), (1, 2, 3))

x_train.drop(['embark_town', 'deck'], axis=1, inplace=True)
x_eval.drop(['embark_town', 'deck'], axis=1, inplace=True)

y_train = x_train.pop('survived')
y_eval = x_eval.pop('survived')
# Data setup for TensorFlow 1 with `tf.estimator`
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32)


def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32)


FEATURE_NAMES = [
    'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone'
]

feature_columns = []
for fn in FEATURE_NAMES:
  feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32)
  feature_columns.append(feat_col)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/2801132002.py:16: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.

and create a function that instantiates a simple sample optimizer for use with the TensorFlow 1 Estimator and TensorFlow 2 Keras models.

def create_sample_optimizer(tf_version):
  if tf_version == 'tf1':
    optimizer = lambda: tf.keras.optimizers.legacy.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf1.train.exponential_decay(
            learning_rate=0.1,
            global_step=tf1.train.get_global_step(),
            decay_steps=10000,
            decay_rate=0.9))
  elif tf_version == 'tf2':
    optimizer = tf.keras.optimizers.legacy.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9))
  return optimizer
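
To make the learning-rate schedule concrete, the following minimal sketch reproduces the decay arithmetic in plain Python. It assumes the continuous (non-staircase) form, initial_learning_rate * decay_rate ** (step / decay_steps), which is the default behavior of both schedules used above. The helper name decayed_learning_rate is hypothetical and not part of TensorFlow.

```python
def decayed_learning_rate(step, initial_learning_rate=0.1,
                          decay_steps=10000, decay_rate=0.9):
  """Returns the learning rate at `step` under continuous exponential decay."""
  return initial_learning_rate * decay_rate ** (step / decay_steps)


# Print the rate at a few illustrative steps.
for step in (0, 20, 200, 10000):
  print(step, decayed_learning_rate(step))
```

Over the few hundred steps run in this guide, the rate stays very close to its initial value of 0.1, so the schedule has little practical effect on these short demonstrations.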

Example 1: Migrating from LinearEstimator

TensorFlow 1: Using LinearEstimator

In TensorFlow 1, you can use tf.estimator.LinearEstimator to create a baseline linear model for regression and classification problems.

linear_estimator = tf.estimator.LinearEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    optimizer=create_sample_optimizer('tf1'))
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/2944250643.py:2: BinaryClassHead.__init__ (from tensorflow_estimator.python.estimator.head.binary_class_head) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/2944250643.py:1: LinearEstimatorV2.__init__ (from tensorflow_estimator.python.estimator.canned.linear) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1124: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpcvrw6s1d
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpcvrw6s1d', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
linear_estimator.train(input_fn=_input_fn, steps=100)
linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/legacy/ftrl.py:173: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpcvrw6s1d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...
INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpcvrw6s1d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...
INFO:tensorflow:Loss for final step: 0.55268794.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-09-16T01:21:55
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpcvrw6s1d/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Inference Time : 0.58102s
INFO:tensorflow:Finished evaluation at 2023-09-16-01:21:55
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70075756, accuracy_baseline = 0.625, auc = 0.75472915, auc_precision_recall = 0.65362054, average_loss = 0.5759378, global_step = 20, label/mean = 0.375, loss = 0.5704812, precision = 0.6388889, prediction/mean = 0.41331062, recall = 0.46464646
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpcvrw6s1d/model.ckpt-20
{'accuracy': 0.70075756,
 'accuracy_baseline': 0.625,
 'auc': 0.75472915,
 'auc_precision_recall': 0.65362054,
 'average_loss': 0.5759378,
 'label/mean': 0.375,
 'loss': 0.5704812,
 'precision': 0.6388889,
 'prediction/mean': 0.41331062,
 'recall': 0.46464646,
 'global_step': 20}
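
The logs above stop at global step 20 even though steps=100 was requested. A likely explanation, sketched below, is that _input_fn does not repeat the dataset, so training ends once the 627 training examples (the count reported later in this guide by TensorFlow Decision Forests) have been consumed in batches of 32. The helper name batches_per_pass is hypothetical.

```python
import math


def batches_per_pass(num_examples, batch_size):
  """Number of batches in one pass, counting a final partial batch."""
  return math.ceil(num_examples / batch_size)


# 627 training examples in batches of 32.
print(batches_per_pass(627, 32))
```

Appending .repeat() to the dataset returned by _input_fn would let training continue for the full number of requested steps.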

TensorFlow 2: Using Keras LinearModel

In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.LinearModel, which substitutes for tf.estimator.LinearEstimator. The tf.compat.v1.keras path signals that the premade model is provided for compatibility.

linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10)
linear_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10
20/20 [==============================] - 0s 2ms/step - loss: 3.6712 - accuracy: 0.6077
Epoch 2/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2146 - accuracy: 0.6715
Epoch 3/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1980 - accuracy: 0.6874
Epoch 4/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1880 - accuracy: 0.7129
Epoch 5/10
20/20 [==============================] - 0s 9ms/step - loss: 0.1805 - accuracy: 0.7337
Epoch 6/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1807 - accuracy: 0.7624
Epoch 7/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1692 - accuracy: 0.7783
Epoch 8/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1691 - accuracy: 0.7927
Epoch 9/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1717 - accuracy: 0.7911
Epoch 10/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1596 - accuracy: 0.7990
9/9 [==============================] - 0s 2ms/step - loss: 0.1853 - accuracy: 0.7348
{'loss': 0.185323566198349, 'accuracy': 0.7348484992980957}
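
The loss and accuracy values reported above can be read as follows. This is a plain-Python sketch, assuming that Keras resolves metrics=['accuracy'] to binary accuracy with a 0.5 threshold for a single-output model. The function names and sample values are illustrative and are not taken from the Titanic data.

```python
def mean_squared_error(y_true, y_pred):
  """Average squared difference between labels and raw model outputs."""
  return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def binary_accuracy(y_true, y_pred, threshold=0.5):
  """Fraction of outputs that land on the label's side of `threshold`."""
  hits = sum(int(p > threshold) == t for t, p in zip(y_true, y_pred))
  return hits / len(y_true)


labels = [1, 0, 0, 1]
outputs = [0.8, 0.3, 0.6, 0.4]  # Made-up linear model outputs.
print(mean_squared_error(labels, outputs))
print(binary_accuracy(labels, outputs))
```

Because the model is trained with 'mse' on 0/1 labels, its raw outputs are not probabilities, which is one reason these numbers are not directly comparable to the Estimator metrics shown earlier.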

Example 2: Migrating from DNNEstimator

TensorFlow 1: Using DNNEstimator

In TensorFlow 1, you can use tf.estimator.DNNEstimator to create a baseline deep neural network (DNN) model for regression and classification problems.

dnn_estimator = tf.estimator.DNNEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    hidden_units=[128],
    activation_fn=tf.nn.relu,
    optimizer=create_sample_optimizer('tf1'))
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/1828606501.py:1: DNNEstimatorV2.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpyih539cq
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpyih539cq', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
dnn_estimator.train(input_fn=_input_fn, steps=100)
dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
2023-09-16 01:21:57.950353: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    for Tuple type infernce function 0
    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpyih539cq/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.9991276, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...
INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpyih539cq/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...
INFO:tensorflow:Loss for final step: 0.5818331.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-09-16T01:21:59
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpyih539cq/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Inference Time : 0.52606s
INFO:tensorflow:Finished evaluation at 2023-09-16-01:21:59
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70454544, accuracy_baseline = 0.625, auc = 0.6964494, auc_precision_recall = 0.60180384, average_loss = 0.5988959, global_step = 20, label/mean = 0.375, loss = 0.59320897, precision = 0.6363636, prediction/mean = 0.38547936, recall = 0.4949495
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpyih539cq/model.ckpt-20
{'accuracy': 0.70454544,
 'accuracy_baseline': 0.625,
 'auc': 0.6964494,
 'auc_precision_recall': 0.60180384,
 'average_loss': 0.5988959,
 'label/mean': 0.375,
 'loss': 0.59320897,
 'precision': 0.6363636,
 'prediction/mean': 0.38547936,
 'recall': 0.4949495,
 'global_step': 20}

TensorFlow 2: Using Keras to create a custom DNN model

In TensorFlow 2, you can create a custom DNN model to substitute for one generated by tf.estimator.DNNEstimator, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).

A similar workflow can be used to replace tf.estimator.experimental.RNNEstimator with a Keras recurrent neural network (RNN) model. Keras provides a number of built-in, customizable choices by way of tf.keras.layers.RNN, tf.keras.layers.LSTM, and tf.keras.layers.GRU. To learn more, check out the Built-in RNN layers: a simple example section of the RNN with Keras guide.

dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])

dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
dnn_model.fit(x_train, y_train, epochs=10)
dnn_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10
20/20 [==============================] - 0s 2ms/step - loss: 581.7921 - accuracy: 0.4338
Epoch 2/10
20/20 [==============================] - 0s 2ms/step - loss: 1.2502 - accuracy: 0.4450
Epoch 3/10
20/20 [==============================] - 0s 2ms/step - loss: 0.7286 - accuracy: 0.5183
Epoch 4/10
20/20 [==============================] - 0s 2ms/step - loss: 0.5274 - accuracy: 0.5486
Epoch 5/10
20/20 [==============================] - 0s 2ms/step - loss: 0.4200 - accuracy: 0.5550
Epoch 6/10
20/20 [==============================] - 0s 2ms/step - loss: 0.3405 - accuracy: 0.5821
Epoch 7/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2975 - accuracy: 0.6124
Epoch 8/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2628 - accuracy: 0.6507
Epoch 9/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2382 - accuracy: 0.7002
Epoch 10/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2218 - accuracy: 0.7273
9/9 [==============================] - 0s 2ms/step - loss: 0.2414 - accuracy: 0.6894
{'loss': 0.2413530945777893, 'accuracy': 0.689393937587738}
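
As a rough picture of what the two-layer Sequential model computes, here is a plain-Python forward pass for one example: a dense layer with ReLU followed by a single linear output unit. The hidden layer is shrunk to two units, and the weights, helper names, and input values are made up for illustration.

```python
def dense(inputs, weights, biases):
  """Computes inputs @ weights + biases for a single example."""
  return [
      sum(x * w for x, w in zip(inputs, column)) + b
      for column, b in zip(zip(*weights), biases)
  ]


def relu(values):
  """Clamps negative activations to zero."""
  return [max(0.0, v) for v in values]


# Hypothetical weights: 3 input features -> 2 hidden units -> 1 output.
hidden_weights = [[0.5, -1.0], [0.25, 0.5], [-0.5, -1.0]]
hidden_biases = [0.0, 0.1]
output_weights = [[1.0], [-2.0]]
output_biases = [0.5]

features = [2.0, 4.0, 1.0]
hidden = relu(dense(features, hidden_weights, hidden_biases))
prediction = dense(hidden, output_weights, output_biases)
print(hidden, prediction)
```

The second hidden unit receives a negative pre-activation here, so ReLU zeroes it and only the first unit contributes to the output.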

Example 3: Migrating from DNNLinearCombinedEstimator

TensorFlow 1: Using DNNLinearCombinedEstimator

In TensorFlow 1, you can use tf.estimator.DNNLinearCombinedEstimator to create a baseline combined model for regression and classification problems with customization capacity for both its linear and DNN components.

optimizer = create_sample_optimizer('tf1')

combined_estimator = tf.estimator.DNNLinearCombinedEstimator(
    head=tf.estimator.BinaryClassHead(),
    # Wide settings
    linear_feature_columns=feature_columns,
    linear_optimizer=optimizer,
    # Deep settings
    dnn_feature_columns=feature_columns,
    dnn_hidden_units=[128],
    dnn_optimizer=optimizer)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/1505653152.py:3: DNNLinearCombinedEstimatorV2.__init__ (from tensorflow_estimator.python.estimator.canned.dnn_linear_combined) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpun15otq5
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpun15otq5', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
combined_estimator.train(input_fn=_input_fn, steps=100)
combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpun15otq5/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.244113, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...
INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpun15otq5/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...
INFO:tensorflow:Loss for final step: 0.5406369.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-09-16T01:22:04
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpun15otq5/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Inference Time : 0.58801s
INFO:tensorflow:Finished evaluation at 2023-09-16-01:22:04
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.71590906, accuracy_baseline = 0.625, auc = 0.7440466, auc_precision_recall = 0.6447197, average_loss = 0.5923795, global_step = 20, label/mean = 0.375, loss = 0.5745624, precision = 0.65384614, prediction/mean = 0.3921669, recall = 0.5151515
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpun15otq5/model.ckpt-20
{'accuracy': 0.71590906,
 'accuracy_baseline': 0.625,
 'auc': 0.7440466,
 'auc_precision_recall': 0.6447197,
 'average_loss': 0.5923795,
 'label/mean': 0.375,
 'loss': 0.5745624,
 'precision': 0.65384614,
 'prediction/mean': 0.3921669,
 'recall': 0.5151515,
 'global_step': 20}

TensorFlow 2: Using Keras WideDeepModel

In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.WideDeepModel to substitute for one generated by tf.estimator.DNNLinearCombinedEstimator, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).

This WideDeepModel is constructed on the basis of a constituent LinearModel and a custom DNN Model, both of which are discussed in the preceding two examples. A custom linear model can also be used in place of the built-in Keras LinearModel if desired.

If you would like to build your own model instead of using a canned estimator, check out the Keras Sequential model guide. For more information on custom training and optimizers, check out the Custom training: walkthrough guide.

# Create LinearModel and DNN Model as in Examples 1 and 2
optimizer = create_sample_optimizer('tf2')

linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10, verbose=0)

dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
combined_model = tf.compat.v1.keras.experimental.WideDeepModel(linear_model,
                                                               dnn_model)
combined_model.compile(
    optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy'])
combined_model.fit([x_train, x_train], y_train, epochs=10)
combined_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10
20/20 [==============================] - 1s 3ms/step - loss: 682.2249 - accuracy: 0.6858
Epoch 2/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2909 - accuracy: 0.7193
Epoch 3/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2140 - accuracy: 0.7400
Epoch 4/10
20/20 [==============================] - 0s 2ms/step - loss: 0.2084 - accuracy: 0.7671
Epoch 5/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1809 - accuracy: 0.7719
Epoch 6/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1699 - accuracy: 0.7911
Epoch 7/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1745 - accuracy: 0.7687
Epoch 8/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1665 - accuracy: 0.7974
Epoch 9/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1693 - accuracy: 0.7911
Epoch 10/10
20/20 [==============================] - 0s 2ms/step - loss: 0.1748 - accuracy: 0.7847
9/9 [==============================] - 0s 2ms/step - loss: 0.2050 - accuracy: 0.7159
{'loss': 0.2049505114555359, 'accuracy': 0.7159090638160706}
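
Conceptually, a wide-and-deep model adds the linear (wide) output to the DNN (deep) output before any final activation. The sketch below illustrates that sum in plain Python. The functions wide_output and deep_output and their coefficients are hypothetical stand-ins for the trained linear_model and dnn_model, not the Keras implementation.

```python
def wide_output(features):
  """Stand-in for a trained linear model: weighted sum plus bias."""
  weights, bias = [0.5, -0.25], 0.1
  return sum(w * x for w, x in zip(weights, features)) + bias


def deep_output(features):
  """Stand-in for a trained DNN: one ReLU hidden unit, one output."""
  hidden = max(0.0, 1.0 * features[0] + 1.0 * features[1] - 1.0)
  return 0.5 * hidden


def wide_deep_output(features):
  """Sums the wide and deep components into a single prediction."""
  return wide_output(features) + deep_output(features)


print(wide_deep_output([2.0, 1.0]))
```

Both components see the same input in this sketch, which mirrors passing [x_train, x_train] to combined_model.fit above.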

Example 4: Migrating from BoostedTreesEstimator

TensorFlow 1: Using BoostedTreesEstimator

In TensorFlow 1, you could use tf.estimator.BoostedTreesEstimator to create a baseline Gradient Boosting model using an ensemble of decision trees for regression and classification problems. This functionality is no longer included in TensorFlow 2.

bt_estimator = tf1.estimator.BoostedTreesEstimator(
    head=tf.estimator.BinaryClassHead(),
    n_batches_per_layer=1,
    max_depth=10,
    n_trees=1000,
    feature_columns=feature_columns)
bt_estimator.train(input_fn=_input_fn, steps=1000)
bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)

TensorFlow 2: Using TensorFlow Decision Forests

In TensorFlow 2, tf.estimator.BoostedTreesEstimator is replaced by tfdf.keras.GradientBoostedTreesModel from the TensorFlow Decision Forests package.

TensorFlow Decision Forests provides various advantages over the tf.estimator.BoostedTreesEstimator, notably regarding quality, speed, ease of use and flexibility. To learn about TensorFlow Decision Forests, start with the beginner colab.

The following example shows how to train a Gradient Boosted Trees model using TensorFlow 2:

Install TensorFlow Decision Forests.

pip install tensorflow_decision_forests

Create a TensorFlow dataset. Note that Decision Forests natively support many types of features and do not need pre-processing.

train_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
eval_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')

# Convert the Pandas Dataframes into TensorFlow datasets.
train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_dataframe, label="survived")
eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_dataframe, label="survived")

Train the model on the train_dataset dataset.

# Use the default hyper-parameters of the model.
gbt_model = tfdf.keras.GradientBoostedTreesModel()
gbt_model.fit(train_dataset)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpinxrd9bl as temporary training directory
Reading training dataset...
[WARNING 23-09-16 01:22:09.3074 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-09-16 01:22:09.3074 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-09-16 01:22:09.3074 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Training dataset read in 0:00:03.607059. Found 627 examples.
Training model...
Model trained in 0:00:00.226305
Compiling model...
[INFO 23-09-16 01:22:13.1488 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpinxrd9bl/model/ with prefix b67347ee49794d62
[INFO 23-09-16 01:22:13.1525 UTC abstract_model.cc:1311] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO 23-09-16 01:22:13.1525 UTC kernel.cc:1075] Use fast generic engine
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f7376b33700> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.
<keras.src.callbacks.History at 0x7f72a4271430>

Evaluate the quality of the model on the eval_dataset dataset.

gbt_model.compile(metrics=['accuracy'])
gbt_evaluation = gbt_model.evaluate(eval_dataset, return_dict=True)
print(gbt_evaluation)
1/1 [==============================] - 0s 296ms/step - loss: 0.0000e+00 - accuracy: 0.8295
{'loss': 0.0, 'accuracy': 0.8295454382896423}
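
The `accuracy` value reported above is the fraction of evaluation examples whose predicted class matches the label. As a rough sketch of that computation, assuming the model outputs one survival probability per example and that a 0.5 decision threshold is applied, it could be written as follows. The helper `binary_accuracy` and the sample values are hypothetical and are not part of TensorFlow Decision Forests:

```python
def binary_accuracy(probabilities, labels, threshold=0.5):
    """Fraction of examples whose thresholded probability equals the label."""
    if len(probabilities) != len(labels):
        raise ValueError("probabilities and labels must have the same length")
    # Turn each probability into a hard 0/1 prediction.
    predicted = [1 if p > threshold else 0 for p in probabilities]
    # Count the predictions that agree with the labels.
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return correct / len(labels)


# Hypothetical probabilities and labels for four passengers.
example_probabilities = [0.9, 0.2, 0.6, 0.4]
example_labels = [1, 0, 0, 0]
print(binary_accuracy(example_probabilities, example_labels))  # 0.75
```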

Gradient Boosted Trees is just one of the many decision forest algorithms available in TensorFlow Decision Forests. For example, Random Forests (available as tfdf.keras.RandomForestModel) is very resistant to overfitting, while CART (available as tfdf.keras.CartModel) is great for model interpretation.

In the next example, train and evaluate a Random Forest model.

# Train a Random Forest model
rf_model = tfdf.keras.RandomForestModel()
rf_model.fit(train_dataset)

# Evaluate the Random Forest model
rf_model.compile(metrics=['accuracy'])
rf_evaluation = rf_model.evaluate(eval_dataset, return_dict=True)
print(rf_evaluation)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpdvqhwuwo as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.187950. Found 627 examples.
Training model...
Model trained in 0:00:00.191396
Compiling model...
[INFO 23-09-16 01:22:15.4265 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpdvqhwuwo/model/ with prefix 39c681f57c12496f
[INFO 23-09-16 01:22:15.5264 UTC decision_forest.cc:660] Model loaded with 300 root(s), 34556 node(s), and 9 input feature(s).
[INFO 23-09-16 01:22:15.5265 UTC kernel.cc:1075] Use fast generic engine
Model compiled.
1/1 [==============================] - 0s 141ms/step - loss: 0.0000e+00 - accuracy: 0.8333
{'loss': 0.0, 'accuracy': 0.8333333134651184}

In the final example, train and plot a CART model.

# Train a CART model
cart_model = tfdf.keras.CartModel()
cart_model.fit(train_dataset)

# Plot the CART model
tfdf.model_plotter.plot_model_in_colab(cart_model, max_depth=2)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpmgmlcvs6 as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.187860. Found 627 examples.
Training model...
Model trained in 0:00:00.018191
Compiling model...
Model compiled.
[INFO 23-09-16 01:22:16.0918 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpmgmlcvs6/model/ with prefix efef116800e041d8
[INFO 23-09-16 01:22:16.0921 UTC decision_forest.cc:660] Model loaded with 1 root(s), 21 node(s), and 5 input feature(s).
[INFO 23-09-16 01:22:16.0922 UTC kernel.cc:1075] Use fast generic engine