Canned (or Premade) Estimators have traditionally been used in TensorFlow 1 as quick and easy ways to train models for a variety of typical use cases. TensorFlow 2 provides straightforward approximate substitutes for a number of them by way of Keras models. For those canned estimators that do not have built-in TensorFlow 2 substitutes, you can still build your own replacement fairly easily.
This guide will walk you through a few examples of direct equivalents and custom substitutions to demonstrate how TensorFlow 1's tf.estimator-derived models can be migrated to TensorFlow 2 with Keras.
Namely, this guide includes examples for migrating:
- From tf.estimator's LinearEstimator, Classifier or Regressor in TensorFlow 1 to Keras tf.compat.v1.keras.models.LinearModel in TensorFlow 2
- From tf.estimator's DNNEstimator, Classifier or Regressor in TensorFlow 1 to a custom Keras DNN model in TensorFlow 2
- From tf.estimator's DNNLinearCombinedEstimator, Classifier or Regressor in TensorFlow 1 to tf.compat.v1.keras.models.WideDeepModel in TensorFlow 2
- From tf.estimator's BoostedTreesEstimator, Classifier or Regressor in TensorFlow 1 to tfdf.keras.GradientBoostedTreesModel in TensorFlow 2
A common precursor to the training of a model is feature preprocessing, which is done for TensorFlow 1 Estimator models with tf.feature_column. For more information on feature preprocessing in TensorFlow 2, see this guide on migrating from feature columns to the Keras preprocessing layers API.
Setup
Start with a couple of necessary TensorFlow imports,
pip install tensorflow_decision_forests
import keras
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_decision_forests as tfdf
prepare some simple data for demonstration from the standard Titanic dataset,
x_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
# Encode the categorical columns as integers. Assigning the mapped column back
# avoids relying on in-place edits of a column selection, which newer pandas
# versions may not propagate to the DataFrame.
x_train['sex'] = x_train['sex'].map({'male': 0, 'female': 1})
x_eval['sex'] = x_eval['sex'].map({'male': 0, 'female': 1})
x_train['alone'] = x_train['alone'].map({'n': 0, 'y': 1})
x_eval['alone'] = x_eval['alone'].map({'n': 0, 'y': 1})
x_train['class'] = x_train['class'].map({'First': 1, 'Second': 2, 'Third': 3})
x_eval['class'] = x_eval['class'].map({'First': 1, 'Second': 2, 'Third': 3})
x_train.drop(['embark_town', 'deck'], axis=1, inplace=True)
x_eval.drop(['embark_town', 'deck'], axis=1, inplace=True)
y_train = x_train.pop('survived')
y_eval = x_eval.pop('survived')
# Data setup for TensorFlow 1 with `tf.estimator`
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32)
def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32)
FEATURE_NAMES = [
    'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone'
]
feature_columns = []
for fn in FEATURE_NAMES:
  feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32)
  feature_columns.append(feat_col)
and create a method to instantiate a simplistic sample optimizer to use with various TensorFlow 1 Estimator and TensorFlow 2 Keras models.
def create_sample_optimizer(tf_version):
  if tf_version == 'tf1':
    optimizer = lambda: tf.keras.optimizers.legacy.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf1.train.exponential_decay(
            learning_rate=0.1,
            global_step=tf1.train.get_global_step(),
            decay_steps=10000,
            decay_rate=0.9))
  elif tf_version == 'tf2':
    optimizer = tf.keras.optimizers.legacy.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9))
  return optimizer
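Both branches of create_sample_optimizer configure the same exponentially decaying learning rate. As a rough illustration of what that schedule computes, the sketch below reimplements the continuous (non-staircase) decay formula in plain Python. The helper name exponential_decay is hypothetical and is not part of TensorFlow; it only mirrors the arithmetic.

```python
def exponential_decay(initial_learning_rate, decay_rate, decay_steps, step):
    """Continuous (non-staircase) exponential decay of a learning rate."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)


# Same hyperparameters as in create_sample_optimizer.
for step in (0, 5000, 10000, 20000):
    lr = exponential_decay(0.1, 0.9, 10000, step)
    print(f"step {step:>6}: learning rate {lr:.5f}")
```

With these settings the rate starts at 0.1 and is multiplied by 0.9 for every 10,000 training steps, so it changes very little over the short training runs used in this guide.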
Example 1: Migrating from LinearEstimator
TensorFlow 1: Using LinearEstimator
In TensorFlow 1, you can use tf.estimator.LinearEstimator to create a baseline linear model for regression and classification problems.
linear_estimator = tf.estimator.LinearEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    optimizer=create_sample_optimizer('tf1'))
linear_estimator.train(input_fn=_input_fn, steps=100)
linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
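The steps argument counts batches, not rows or epochs. The input functions from the setup return a finite dataset in batches of 32 without repeating it, so training can end before the requested number of steps if the data runs out first. The plain-Python sketch below mimics the shape of those batches on hypothetical toy data; batched_slices is an illustrative helper, not a TensorFlow API.

```python
def batched_slices(features, labels, batch_size):
    """Yield (feature_dict, label_list) pairs, one per batch.

    Mimics, in plain Python, the structure produced by slicing a
    (features, labels) pair row by row and grouping rows into batches.
    """
    num_rows = len(labels)
    for start in range(0, num_rows, batch_size):
        stop = min(start + batch_size, num_rows)
        batch_features = {
            name: column[start:stop] for name, column in features.items()
        }
        yield batch_features, labels[start:stop]


# Hypothetical toy data: 5 rows, 2 feature columns.
toy_features = {
    "age": [22, 38, 26, 35, 54],
    "fare": [7.25, 71.28, 7.92, 53.1, 51.86],
}
toy_labels = [0, 1, 1, 1, 0]

batches = list(batched_slices(toy_features, toy_labels, batch_size=2))
for batch_features, batch_labels in batches:
    print(batch_features, batch_labels)
```

Five rows with a batch size of 2 give three batches, the last one holding a single row. One step of train or evaluate consumes one such batch.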
TensorFlow 2: Using Keras LinearModel
In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.LinearModel, which is the substitute for tf.estimator.LinearEstimator. The tf.compat.v1.keras path is used to signify that the pre-made model exists for compatibility.
linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10)
linear_model.evaluate(x_eval, y_eval, return_dict=True)
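One difference to keep in mind when comparing results: the Keras examples in this guide compile with loss='mse', whereas tf.estimator.BinaryClassHead trains with sigmoid cross-entropy computed on logits. The NumPy sketch below compares the two losses on hypothetical model outputs. It is only an illustration of the arithmetic, and the function names are made up for this example. In Keras, the cross-entropy variant would correspond to tf.keras.losses.BinaryCrossentropy(from_logits=True).

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Numerically stable per-example sigmoid cross-entropy."""
    return (np.maximum(logits, 0.0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))


def mean_squared_error(labels, predictions):
    """Mean of the squared differences between labels and predictions."""
    return np.mean((labels - predictions) ** 2)


labels = np.array([0.0, 1.0, 1.0, 0.0])
raw_outputs = np.array([-2.0, 3.0, -0.5, 0.2])  # hypothetical model outputs

ce_loss = sigmoid_cross_entropy_with_logits(labels, raw_outputs).mean()
mse_loss = mean_squared_error(labels, raw_outputs)
print(f"sigmoid cross-entropy: {ce_loss:.4f}")
print(f"mean squared error:    {mse_loss:.4f}")
```

The two losses penalize the same outputs differently, so metrics from the Estimator and Keras versions of a model are approximate counterparts rather than identical numbers.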
Example 2: Migrating from DNNEstimator
TensorFlow 1: Using DNNEstimator
In TensorFlow 1, you can use tf.estimator.DNNEstimator to create a baseline deep neural network (DNN) model for regression and classification problems.
dnn_estimator = tf.estimator.DNNEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    hidden_units=[128],
    activation_fn=tf.nn.relu,
    optimizer=create_sample_optimizer('tf1'))
dnn_estimator.train(input_fn=_input_fn, steps=100)
dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
TensorFlow 2: Using Keras to create a custom DNN model
In TensorFlow 2, you can create a custom DNN model to substitute for one generated by tf.estimator.DNNEstimator, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).
A similar workflow can be used to replace tf.estimator.experimental.RNNEstimator with a Keras recurrent neural network (RNN) model. Keras provides a number of built-in, customizable choices by way of tf.keras.layers.RNN, tf.keras.layers.LSTM, and tf.keras.layers.GRU. To learn more, check out the Built-in RNN layers: a simple example section of the RNN with Keras guide.
dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
dnn_model.fit(x_train, y_train, epochs=10)
dnn_model.evaluate(x_eval, y_eval, return_dict=True)
Example 3: Migrating from DNNLinearCombinedEstimator
TensorFlow 1: Using DNNLinearCombinedEstimator
In TensorFlow 1, you can use tf.estimator.DNNLinearCombinedEstimator to create a baseline combined model for regression and classification problems with customization capacity for both its linear and DNN components.
optimizer = create_sample_optimizer('tf1')
combined_estimator = tf.estimator.DNNLinearCombinedEstimator(
    head=tf.estimator.BinaryClassHead(),
    # Wide settings
    linear_feature_columns=feature_columns,
    linear_optimizer=optimizer,
    # Deep settings
    dnn_feature_columns=feature_columns,
    dnn_hidden_units=[128],
    dnn_optimizer=optimizer)
combined_estimator.train(input_fn=_input_fn, steps=100)
combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
TensorFlow 2: Using Keras WideDeepModel
In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.WideDeepModel to substitute for one generated by tf.estimator.DNNLinearCombinedEstimator, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).
This WideDeepModel is constructed on the basis of a constituent LinearModel and a custom DNN Model, both of which are discussed in the preceding two examples. A custom linear model can also be used in place of the built-in Keras LinearModel if desired.
If you would like to build your own model instead of using a canned estimator, check out the Keras Sequential model guide. For more information on custom training and optimizers, check out the Custom training: walkthrough guide.
# Create LinearModel and DNN Model as in Examples 1 and 2
optimizer = create_sample_optimizer('tf2')
linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10, verbose=0)
dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
combined_model = tf.compat.v1.keras.experimental.WideDeepModel(linear_model,
                                                               dnn_model)
combined_model.compile(
    optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy'])
combined_model.fit([x_train, x_train], y_train, epochs=10)
combined_model.evaluate(x_eval, y_eval, return_dict=True)
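To make the wide-and-deep structure more concrete, the NumPy sketch below computes a forward pass for a linear part and a small DNN part and then adds the two outputs. It assumes the combined prediction is the sum of the two sub-model outputs with no extra activation. The weights are random placeholders and the variable names are illustrative, not taken from Keras.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
num_features, hidden_units = 7, 128
x = rng.normal(size=(4, num_features))  # hypothetical batch of 4 rows

# "Wide" part: one weight per feature plus a bias.
w_linear = rng.normal(size=(num_features, 1))
b_linear = np.zeros(1)
linear_out = x @ w_linear + b_linear

# "Deep" part: a 128-unit ReLU layer followed by a single output unit,
# matching the layer sizes used in dnn_model.
w1 = rng.normal(size=(num_features, hidden_units))
b1 = np.zeros(hidden_units)
w2 = rng.normal(size=(hidden_units, 1))
b2 = np.zeros(1)
hidden = np.maximum(x @ w1 + b1, 0.0)
dnn_out = hidden @ w2 + b2

# Combined prediction: sum of the two parts (assumption stated above).
combined_out = linear_out + dnn_out
print(combined_out.shape)
```

Because the two parts only meet at the final sum, each can be given its own optimizer, which is what the two-element optimizer list passed to combined_model.compile expresses.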
Example 4: Migrating from BoostedTreesEstimator
TensorFlow 1: Using BoostedTreesEstimator
In TensorFlow 1, you could use tf.estimator.BoostedTreesEstimator to create a baseline Gradient Boosting model using an ensemble of decision trees for regression and classification problems. This functionality is no longer included in TensorFlow 2.
bt_estimator = tf1.estimator.BoostedTreesEstimator(
    head=tf.estimator.BinaryClassHead(),
    n_batches_per_layer=1,
    max_depth=10,
    n_trees=1000,
    feature_columns=feature_columns)
bt_estimator.train(input_fn=_input_fn, steps=1000)
bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)
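For intuition about what a gradient boosting model does, the sketch below boosts depth-1 trees (stumps) in plain Python. It is a deliberately simplified illustration: it uses a squared-error regression loss on one feature, whereas the estimator above is a classifier, and it is not how BoostedTreesEstimator or TensorFlow Decision Forests are implemented. All function names and the toy data are hypothetical.

```python
def fit_stump(xs, residuals):
    """Find the single-feature threshold split that minimizes squared error.

    Assumes xs contains at least two distinct values.
    """
    best = None
    for threshold in sorted(set(xs))[:-1]:
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        error = (sum((r - left_mean) ** 2 for r in left)
                 + sum((r - right_mean) ** 2 for r in right))
        if best is None or error < best[0]:
            best = (error, threshold, left_mean, right_mean)
    _, threshold, left_mean, right_mean = best
    return threshold, left_mean, right_mean


def predict_stump(stump, x):
    threshold, left_mean, right_mean = stump
    return left_mean if x <= threshold else right_mean


def fit_gradient_boosting(xs, ys, n_trees=20, learning_rate=0.3):
    """Fit each new stump to the residuals left by the current ensemble."""
    base = sum(ys) / len(ys)
    predictions = [base] * len(ys)
    stumps = []
    for _ in range(n_trees):
        residuals = [y - p for y, p in zip(ys, predictions)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        predictions = [p + learning_rate * predict_stump(stump, x)
                       for p, x in zip(predictions, xs)]
    return base, stumps, learning_rate


def predict(model, x):
    base, stumps, learning_rate = model
    return base + learning_rate * sum(predict_stump(s, x) for s in stumps)


# Hypothetical toy data: one feature, numeric target.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [1.0, 1.2, 0.9, 3.1, 2.9, 3.0]
model = fit_gradient_boosting(xs, ys)
print([round(predict(model, x), 2) for x in xs])
```

Each added tree corrects part of the error that remains after the previous ones, which is the role n_trees plays in the estimator configuration above.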
TensorFlow 2: Using TensorFlow Decision Forests
In TensorFlow 2, tf.estimator.BoostedTreesEstimator is replaced by tfdf.keras.GradientBoostedTreesModel from the TensorFlow Decision Forests package.
TensorFlow Decision Forests provides various advantages over the tf.estimator.BoostedTreesEstimator, notably regarding quality, speed, ease of use and flexibility. To learn about TensorFlow Decision Forests, start with the beginner colab.
The following example shows how to train a Gradient Boosted Trees model using TensorFlow 2:
Install TensorFlow Decision Forests.
pip install tensorflow_decision_forests
Create a TensorFlow dataset. Note that Decision Forests natively support many types of features and do not need pre-processing.
train_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
eval_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
# Convert the Pandas Dataframes into TensorFlow datasets.
train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_dataframe, label="survived")
eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_dataframe, label="survived")
Train the model on the train_dataset dataset.
# Use the default hyper-parameters of the model.
gbt_model = tfdf.keras.GradientBoostedTreesModel()
gbt_model.fit(train_dataset)
Evaluate the quality of the model on the eval_dataset dataset.
gbt_model.compile(metrics=['accuracy'])
gbt_evaluation = gbt_model.evaluate(eval_dataset, return_dict=True)
print(gbt_evaluation)
Gradient Boosted Trees is just one of the many decision forest algorithms available in TensorFlow Decision Forests. For example, Random Forests (available as tfdf.keras.RandomForestModel) is very resistant to overfitting, while CART (available as tfdf.keras.CartModel) is great for model interpretation.
In the next example, train and evaluate a Random Forest model.
# Train a Random Forest model
rf_model = tfdf.keras.RandomForestModel()
rf_model.fit(train_dataset)
# Evaluate the Random Forest model
rf_model.compile(metrics=['accuracy'])
rf_evaluation = rf_model.evaluate(eval_dataset, return_dict=True)
print(rf_evaluation)
In the final example, train and plot a CART model.
# Train a CART model
cart_model = tfdf.keras.CartModel()
cart_model.fit(train_dataset)
# Plot the CART model
tfdf.model_plotter.plot_model_in_colab(cart_model, max_depth=2)