Migrate tf.feature_columns to Keras preprocessing layers


Training a model usually comes with some amount of feature preprocessing, particularly when dealing with structured data. When training a tf.estimator.Estimator in TensorFlow 1, you usually perform feature preprocessing with the tf.feature_column API. In TensorFlow 2, you can do this directly with Keras preprocessing layers.

This migration guide demonstrates common feature transformations using both feature columns and preprocessing layers, followed by training a complete model with both APIs.

First, start with a couple of necessary imports:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import math

Now, add a utility function for calling a feature column for demonstration:

def call_feature_columns(feature_columns, inputs):
  # This is a convenient way to call a `feature_column` outside of an estimator
  # to display its output.
  feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)
  return feature_layer(inputs)

Input handling

To use feature columns with an estimator, model inputs are always expected to be a dictionary of tensors:

input_dict = {
  'foo': tf.constant([1]),
  'bar': tf.constant([0]),
  'baz': tf.constant([-1])
}

Each feature column needs to be created with a key to index into the source data. The output of all feature columns is concatenated and used by the estimator model.

columns = [
  tf1.feature_column.numeric_column('foo'),
  tf1.feature_column.numeric_column('bar'),
  tf1.feature_column.numeric_column('baz'),
]
call_feature_columns(columns, input_dict)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/3124623333.py:2: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[ 0., -1.,  1.]], dtype=float32)>

In Keras, model input is much more flexible: a tf.keras.Model can handle a single tensor input, a list of tensors, or a dictionary of tensors. You can handle dictionary input by passing a dictionary of tf.keras.Input on model creation. Inputs are not concatenated automatically, which allows them to be used in more flexible ways; when concatenation is needed, use tf.keras.layers.Concatenate.

inputs = {
  'foo': tf.keras.Input(shape=()),
  'bar': tf.keras.Input(shape=()),
  'baz': tf.keras.Input(shape=()),
}
# Inputs are typically transformed by preprocessing layers before concatenation.
outputs = tf.keras.layers.Concatenate()(inputs.values())
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model(input_dict)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 1.,  0., -1.], dtype=float32)>
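
A list of tensors works just as well; here is a minimal sketch mirroring the dictionary example above, where position (rather than a key) identifies each feature:

list_inputs = [tf.keras.Input(shape=()) for _ in range(3)]
# With list input, feature order takes the place of dictionary keys.
outputs = tf.keras.layers.Concatenate()(list_inputs)
list_model = tf.keras.Model(inputs=list_inputs, outputs=outputs)
list_model([tf.constant([1.]), tf.constant([0.]), tf.constant([-1.])])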

One-hot encoding integer IDs

A common feature transformation is one-hot encoding integer inputs of a known range. Here is an example using feature columns:

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'type', num_buckets=3)
indicator_col = tf1.feature_column.indicator_column(categorical_col)
call_feature_columns(indicator_col, {'type': [0, 1, 2]})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/1369923821.py:1: categorical_column_with_identity (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/1369923821.py:3: indicator_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

Using Keras preprocessing layers, these columns can be replaced by a single tf.keras.layers.CategoryEncoding layer with output_mode set to 'one_hot':

one_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=3, output_mode='one_hot')
one_hot_layer([0, 1, 2])
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

Normalizing numeric features

When handling continuous, floating-point features with feature columns, you need to use tf.feature_column.numeric_column. If the input is already normalized, converting this to Keras is trivial: you can simply use a tf.keras.Input directly in your model, as shown above.

A numeric_column can also be used to normalize input:

def normalize(x):
  mean, variance = (2.0, 1.0)
  return (x - mean) / math.sqrt(variance)
numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)
call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-2.],
       [-1.],
       [ 0.]], dtype=float32)>

In contrast, with Keras, this normalization can be done with tf.keras.layers.Normalization.

normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)
normalization_layer(tf.constant([[0.], [1.], [2.]]))
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-2.],
       [-1.],
       [ 0.]], dtype=float32)>

Bucketizing and one-hot encoding numeric features

Another common transformation of continuous, floating-point inputs is to bucketize them to integers of a fixed range.

In feature columns, this can be achieved with a tf.feature_column.bucketized_column:

numeric_col = tf1.feature_column.numeric_column('col')
bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])
call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/3043215186.py:2: bucketized_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)>

In Keras, this can be replaced by tf.keras.layers.Discretization:

discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])
one_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=4, output_mode='one_hot')
one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)>

One-hot encoding string data with a vocabulary

Handling string features often requires a vocabulary lookup to translate strings into indices. Here is an example using feature columns to look up strings and then one-hot encode the indices:

vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'sizes',
    vocabulary_list=['small', 'medium', 'large'],
    num_oov_buckets=0)
indicator_col = tf1.feature_column.indicator_column(vocab_col)
call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/2845961037.py:1: categorical_column_with_vocabulary_list (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

With Keras preprocessing layers, use the tf.keras.layers.StringLookup layer with output_mode set to 'one_hot':

string_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'],
    num_oov_indices=0,
    output_mode='one_hot')
string_lookup_layer(['small', 'medium', 'large'])
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

Embedding string data with a vocabulary

For larger vocabularies, an embedding is often needed for good performance. Here is an example embedding a string feature using feature columns:

vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'col',
    vocabulary_list=['small', 'medium', 'large'],
    num_oov_buckets=0)
embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)
call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/999372599.py:5: embedding_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[ 0.19490445, -0.41044798, -0.5603343 ,  0.32616043],
       [-0.03856034,  0.45223498,  0.15508214,  0.02957107],
       [-0.20157112,  0.27702358, -0.079571  , -0.6845111 ]],
      dtype=float32)>

Using Keras preprocessing layers, this can be achieved by combining a tf.keras.layers.StringLookup layer and a tf.keras.layers.Embedding layer. The default output of the StringLookup is integer indices, which can be fed directly into an embedding.

string_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'], num_oov_indices=0)
embedding = tf.keras.layers.Embedding(3, 4)
embedding(string_lookup_layer(['small', 'medium', 'large']))
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[-0.01309206, -0.01742087,  0.04137394,  0.03209827],
       [ 0.009628  ,  0.03649005,  0.0101964 , -0.02682712],
       [ 0.02947212,  0.04911048,  0.00290644, -0.00840418]],
      dtype=float32)>

Summing weighted categorical data

In some cases, you need to deal with categorical data where each occurrence of a category comes with an associated weight. In feature columns, this is handled with tf.feature_column.weighted_categorical_column. When paired with an indicator_column, this has the effect of summing weights per category.

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
    categorical_col, 'weights')
indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)
call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/3529191023.py:6: weighted_categorical_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4033: sparse_merge (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.
Instructions for updating:
No similar op available at this time.
<tf.Tensor: shape=(1, 20), dtype=float32, numpy=
array([[0. , 0. , 0. , 0. , 0. , 1.2, 0. , 0. , 0. , 0. , 0. , 1.5, 0. ,
        0. , 0. , 0. , 0. , 2. , 0. , 0. ]], dtype=float32)>

In Keras, this can be done by passing a count_weights input to tf.keras.layers.CategoryEncoding with output_mode='count'.

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

# Using sparse output is more efficient when `num_tokens` is large.
count_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=20, output_mode='count', sparse=True)
tf.sparse.to_dense(count_layer(ids, count_weights=weights))
<tf.Tensor: shape=(1, 20), dtype=float32, numpy=
array([[0. , 0. , 0. , 0. , 0. , 1.2, 0. , 0. , 0. , 0. , 0. , 1.5, 0. ,
        0. , 0. , 0. , 0. , 2. , 0. , 0. ]], dtype=float32)>

Embedding weighted categorical data

You might alternatively want to embed weighted categorical inputs. In feature columns, the embedding_column contains a combiner argument. If any sample contains multiple entries for a category, they will be combined according to the argument setting (by default 'mean').

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
    categorical_col, 'weights')
embedding_col = tf1.feature_column.embedding_column(
    weighted_categorical_col, 4, combiner='mean')
call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})
<tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[ 0.3714184 , -0.01248874, -0.3219157 , -0.39734876]],
      dtype=float32)>

In Keras, tf.keras.layers.Embedding has no combiner option, but you can achieve the same effect with tf.keras.layers.Dense: the embedding_column above simply combines embedding vectors linearly according to category weight. Though not obvious at first, this is exactly equivalent to representing your categorical inputs as a sparse weight vector of size (num_tokens) and multiplying it by a Dense kernel of shape (num_tokens, embedding_size).

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line
# would be equivalent to an `embedding_column` with `combiner='sum'`.
weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)

count_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=20, output_mode='count', sparse=True)
embedding_layer = tf.keras.layers.Dense(4, use_bias=False)
embedding_layer(count_layer(ids, count_weights=weights))
<tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.24909994, -0.02287644,  0.06182466, -0.11370523]],
      dtype=float32)>

Complete training example

To show a complete training workflow, first prepare some data with three features of different types:

features = {
    'type': [0, 1, 1],
    'size': ['small', 'small', 'medium'],
    'weight': [2.7, 1.8, 1.6],
}
labels = [1, 1, 0]
predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}

Define some common constants for both TensorFlow 1 and TensorFlow 2 workflows:

vocab = ['small', 'medium', 'large']
one_hot_dims = 3
embedding_dims = 4
weight_mean = 2.0
weight_variance = 1.0

With feature columns

Feature columns must be passed as a list to the estimator on creation, and will be called implicitly during training.

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'type', num_buckets=one_hot_dims)
# Convert index to one-hot; e.g., [2] -> [0,0,1].
indicator_col = tf1.feature_column.indicator_column(categorical_col)

# Convert strings to indices; e.g., ['small'] -> [1].
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'size', vocabulary_list=vocab, num_oov_buckets=1)
# Embed the indices.
embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)

normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)
# Normalize the numeric inputs; e.g., [2.0] -> [0.0].
numeric_col = tf1.feature_column.numeric_column(
    'weight', normalizer_fn=normalizer_fn)

estimator = tf1.estimator.DNNClassifier(
    feature_columns=[indicator_col, embedding_col, numeric_col],
    hidden_units=[1])

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

estimator.train(_input_fn)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/1997355744.py:17: DNNClassifier.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:807: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpm02fidwe
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpm02fidwe', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpm02fidwe/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 0.6289518, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpm02fidwe/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Loss for final step: 0.81661654.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f4edc48e8b0>

The feature columns will also be used to transform input data when running inference on the model.

def _predict_fn():
  return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)

next(estimator.predict(_predict_fn))
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:596: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:1307: RegressionOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:1309: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpm02fidwe/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
{'logits': array([0.57964283], dtype=float32),
 'logistic': array([0.6409852], dtype=float32),
 'probabilities': array([0.35901475, 0.6409852 ], dtype=float32),
 'class_ids': array([1]),
 'classes': array([b'1'], dtype=object),
 'all_class_ids': array([0, 1], dtype=int32),
 'all_classes': array([b'0', b'1'], dtype=object)}

With Keras preprocessing layers

Keras preprocessing layers are more flexible in where they can be called. A layer can be applied directly to tensors, used inside a tf.data input pipeline, or built directly into a trainable Keras model.
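
For example, here is a minimal sketch of the last option, building the lookup for a string feature directly into a small trainable model (the 'size' feature and vocabulary match the training example below):

# A sketch: preprocessing built directly into a trainable model.
size_input = tf.keras.Input(shape=(), dtype='string', name='size')
# The lookup runs inside the model, so it accepts raw strings at inference time.
size_ids = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'])(size_input)
# `len(vocabulary) + 1` rows: index 0 is reserved for out-of-vocabulary tokens.
size_embedded = tf.keras.layers.Embedding(4, 4)(size_ids)
output = tf.keras.layers.Dense(1)(size_embedded)
in_model_preprocessing = tf.keras.Model(size_input, output)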

In this example, you will apply preprocessing layers inside a tf.data input pipeline. To do this, you can define a separate tf.keras.Model to preprocess your input features. This model is not trainable, but is a convenient way to group preprocessing layers.

inputs = {
  'type': tf.keras.Input(shape=(), dtype='int64'),
  'size': tf.keras.Input(shape=(), dtype='string'),
  'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Convert index to one-hot; e.g., [2] -> [0,0,1].
type_output = tf.keras.layers.CategoryEncoding(
      one_hot_dims, output_mode='one_hot')(inputs['type'])
# Convert size strings to indices; e.g., ['small'] -> [1].
size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])
# Normalize the numeric inputs; e.g., [2.0] -> [0.0].
weight_output = tf.keras.layers.Normalization(
      axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])
outputs = {
  'type': type_output,
  'size': size_output,
  'weight': weight_output,
}
preprocessing_model = tf.keras.Model(inputs, outputs)

You can now apply this model inside a call to tf.data.Dataset.map. Note that the function passed to map is automatically converted into a tf.function, and the usual caveats for writing tf.function code apply (for example, no Python side effects).

# Apply the preprocessing in tf.data.Dataset.map.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),
                      num_parallel_calls=tf.data.AUTOTUNE)
# Display a preprocessed input sample.
next(dataset.take(1).as_numpy_iterator())
({'type': array([[1., 0., 0.]], dtype=float32),
  'size': array([1]),
  'weight': array([0.70000005], dtype=float32)},
 array([1], dtype=int32))

Next, you can define a separate Model containing the trainable layers. Note how the inputs to this model now reflect the preprocessed feature types and shapes.

inputs = {
  'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),
  'size': tf.keras.Input(shape=(), dtype='int64'),
  'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Since the embedding is trainable, it needs to be part of the training model.
# Note: the `StringLookup` above reserves index 0 for out-of-vocabulary tokens,
# so the embedding needs `len(vocab) + 1` rows to cover all possible indices.
embedding = tf.keras.layers.Embedding(len(vocab) + 1, embedding_dims)
outputs = tf.keras.layers.Concatenate()([
  inputs['type'],
  embedding(inputs['size']),
  tf.expand_dims(inputs['weight'], -1),
])
outputs = tf.keras.layers.Dense(1)(outputs)
training_model = tf.keras.Model(inputs, outputs)

You can now train the training_model with tf.keras.Model.fit.

# Train on the preprocessed data.
training_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
training_model.fit(dataset)
3/3 [==============================] - 1s 5ms/step - loss: 0.8194
<keras.src.callbacks.History at 0x7f4e700e7a30>

Finally, at inference time, it can be useful to combine these separate stages into a single model that handles raw feature inputs.

inputs = preprocessing_model.input
outputs = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outputs)

predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)
inference_model.predict(predict_dataset)
1/1 [==============================] - 0s 102ms/step
array([[-0.95717776]], dtype=float32)

This composed model can be saved as a .keras file for later use.

inference_model.save('model.keras')
restored_model = tf.keras.models.load_model('model.keras')
restored_model.predict(predict_dataset)
1/1 [==============================] - 0s 79ms/step
array([[-0.95717776]], dtype=float32)

Feature column equivalence table

For reference, here is an approximate correspondence between feature columns and Keras preprocessing layers:

Feature column → Keras layer
tf.feature_column.bucketized_column → tf.keras.layers.Discretization
tf.feature_column.categorical_column_with_hash_bucket → tf.keras.layers.Hashing
tf.feature_column.categorical_column_with_identity → tf.keras.layers.CategoryEncoding
tf.feature_column.categorical_column_with_vocabulary_file → tf.keras.layers.StringLookup or tf.keras.layers.IntegerLookup
tf.feature_column.categorical_column_with_vocabulary_list → tf.keras.layers.StringLookup or tf.keras.layers.IntegerLookup
tf.feature_column.crossed_column → tf.keras.layers.experimental.preprocessing.HashedCrossing
tf.feature_column.embedding_column → tf.keras.layers.Embedding
tf.feature_column.indicator_column → output_mode='one_hot' or output_mode='multi_hot'*
tf.feature_column.numeric_column → tf.keras.layers.Normalization
tf.feature_column.sequence_categorical_column_with_hash_bucket → tf.keras.layers.Hashing
tf.feature_column.sequence_categorical_column_with_identity → tf.keras.layers.CategoryEncoding
tf.feature_column.sequence_categorical_column_with_vocabulary_file → tf.keras.layers.StringLookup, tf.keras.layers.IntegerLookup, or tf.keras.layers.TextVectorization
tf.feature_column.sequence_categorical_column_with_vocabulary_list → tf.keras.layers.StringLookup, tf.keras.layers.IntegerLookup, or tf.keras.layers.TextVectorization
tf.feature_column.sequence_numeric_column → tf.keras.layers.Normalization
tf.feature_column.weighted_categorical_column → tf.keras.layers.CategoryEncoding

* The output_mode can be passed to tf.keras.layers.CategoryEncoding, tf.keras.layers.StringLookup, tf.keras.layers.IntegerLookup, and tf.keras.layers.TextVectorization.
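
For instance, a quick sketch of multi-hot encoding with tf.keras.layers.CategoryEncoding (the category IDs below are illustrative):

multi_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=4, output_mode='multi_hot')
# Each output row marks every category ID present in the sample with a 1.
multi_hot_layer([[0, 2], [1, 3]])  # -> [[1., 0., 1., 0.], [0., 1., 0., 1.]]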

tf.keras.layers.TextVectorization can handle freeform text input directly (for example, entire sentences or paragraphs). This is not a one-to-one replacement for categorical sequence handling in TensorFlow 1, but it may offer a convenient replacement for ad-hoc text preprocessing.
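
As a brief sketch (the example strings are illustrative), TextVectorization learns a vocabulary from raw text via adapt, then maps sentences to integer token sequences:

text_layer = tf.keras.layers.TextVectorization(output_mode='int')
# `adapt` builds the vocabulary from example text.
text_layer.adapt(['small dogs', 'medium dogs', 'large cats'])
# Unseen words (such as 'tiny') map to the out-of-vocabulary index 1.
text_layer(['large dogs', 'tiny cats'])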

Next steps