tf.feature_columns to Keras preprocessing layers
Training a model usually comes with some amount of feature preprocessing, particularly when dealing with structured data. When training a tf.estimator.Estimator in TensorFlow 1, you usually perform feature preprocessing with the tf.feature_column API. In TensorFlow 2, you can do this directly with Keras preprocessing layers.
This migration guide demonstrates common feature transformations using both feature columns and preprocessing layers, followed by training a complete model with both APIs.
First, start with a couple of necessary imports:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import math
Now, add a utility function for calling a feature column for demonstration:
def call_feature_columns(feature_columns, inputs):
  # This is a convenient way to call a `feature_column` outside of an estimator
  # to display its output.
  feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)
  return feature_layer(inputs)
Input handling
To use feature columns with an estimator, model inputs are always expected to be a dictionary of tensors:
input_dict = {
'foo': tf.constant([1]),
'bar': tf.constant([0]),
'baz': tf.constant([-1])
}
Each feature column needs to be created with a key to index into the source data. The output of all feature columns is concatenated and used by the estimator model.
columns = [
tf1.feature_column.numeric_column('foo'),
tf1.feature_column.numeric_column('bar'),
tf1.feature_column.numeric_column('baz'),
]
call_feature_columns(columns, input_dict)
In Keras, model input is much more flexible. A tf.keras.Model can handle a single tensor input, a list of tensor features, or a dictionary of tensor features. You can handle dictionary input by passing a dictionary of tf.keras.Input on model creation. Inputs will not be concatenated automatically, which allows them to be used in much more flexible ways. They can be concatenated with tf.keras.layers.Concatenate.
inputs = {
'foo': tf.keras.Input(shape=()),
'bar': tf.keras.Input(shape=()),
'baz': tf.keras.Input(shape=()),
}
# Inputs are typically transformed by preprocessing layers before concatenation.
outputs = tf.keras.layers.Concatenate()(inputs.values())
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model(input_dict)
One-hot encoding integer IDs
A common feature transformation is one-hot encoding integer inputs of a known range. Here is an example using feature columns:
categorical_col = tf1.feature_column.categorical_column_with_identity(
'type', num_buckets=3)
indicator_col = tf1.feature_column.indicator_column(categorical_col)
call_feature_columns(indicator_col, {'type': [0, 1, 2]})
Using Keras preprocessing layers, these columns can be replaced by a single tf.keras.layers.CategoryEncoding layer with output_mode set to 'one_hot':
one_hot_layer = tf.keras.layers.CategoryEncoding(
num_tokens=3, output_mode='one_hot')
one_hot_layer([0, 1, 2])
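Both versions return the same dense one-hot matrix, one row per input ID. As a framework-free illustration (a NumPy sketch of the result, not how either API is implemented), one-hot encoding IDs from a known range amounts to selecting rows of an identity matrix:

```python
import numpy as np

def one_hot(ids, num_tokens):
    """Return a (len(ids), num_tokens) float32 matrix with a 1 at each ID."""
    ids = np.asarray(ids)
    # Negative IDs would silently wrap around to the last rows of the identity
    # matrix, so check the range explicitly.
    if ids.min() < 0 or ids.max() >= num_tokens:
        raise ValueError("ids must lie in [0, num_tokens)")
    return np.eye(num_tokens, dtype=np.float32)[ids]

print(one_hot([0, 1, 2], num_tokens=3))  # prints the 3x3 identity matrix
```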
Normalizing numeric features
When handling continuous, floating-point features with feature columns, you need to use a tf.feature_column.numeric_column. If the input is already normalized, converting this to Keras is trivial: you can pass a tf.keras.Input directly into your model, as shown above.
A numeric_column can also be used to normalize input:
def normalize(x):
  mean, variance = (2.0, 1.0)
  return (x - mean) / math.sqrt(variance)
numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)
call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})
In contrast, with Keras, this normalization can be done with tf.keras.layers.Normalization.
normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)
normalization_layer(tf.constant([[0.], [1.], [2.]]))
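With mean=2.0 and variance=1.0, both snippets standardize each value as (x - mean) / sqrt(variance). A plain-Python sketch of that arithmetic on the same inputs (it assumes a positive variance and is only an illustration, not the layer's implementation):

```python
import math

def standardize(values, mean, variance):
    """Standardize each value as (x - mean) / sqrt(variance); assumes variance > 0."""
    std = math.sqrt(variance)
    return [(x - mean) / std for x in values]

# Same inputs and statistics as the two examples above.
print(standardize([0.0, 1.0, 2.0], mean=2.0, variance=1.0))  # [-2.0, -1.0, 0.0]
```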
Bucketizing and one-hot encoding numeric features
Another common transformation of continuous, floating-point inputs is to bucketize them into integers of a fixed range.
In feature columns, this can be achieved with a tf.feature_column.bucketized_column:
numeric_col = tf1.feature_column.numeric_column('col')
bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])
call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})
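In Keras, this can be replaced by tf.keras.layers.Discretization, followed by a tf.keras.layers.CategoryEncoding layer that one-hot encodes the resulting bucket indices (three boundaries yield four buckets, hence num_tokens=4):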
discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])
one_hot_layer = tf.keras.layers.CategoryEncoding(
num_tokens=4, output_mode='one_hot')
one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))
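Both APIs use the same bucketing rule: each bucket includes its left boundary and excludes its right one, so [1, 4, 5] defines the buckets (-inf, 1), [1, 4), [4, 5) and [5, +inf). A plain-Python sketch of that rule using the standard library's bisect module (an illustration, not the TensorFlow implementation; bucketize and one_hot are hypothetical helper names):

```python
import bisect

def bucketize(values, boundaries):
    """Map each value to a bucket index in [0, len(boundaries)].

    bisect_right puts a value equal to a boundary into the bucket above it,
    matching left-inclusive, right-exclusive buckets.
    """
    return [bisect.bisect_right(boundaries, x) for x in values]

def one_hot(index, depth):
    """One-hot encode a single bucket index as a list of floats."""
    return [1.0 if i == index else 0.0 for i in range(depth)]

boundaries = [1, 4, 5]
indices = bucketize([1., 2., 3., 4., 5.], boundaries)
print(indices)  # [1, 1, 1, 2, 3]
# Three boundaries define four buckets, so the one-hot depth is 4.
print([one_hot(i, len(boundaries) + 1) for i in indices])
```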
One-hot encoding string data with a vocabulary
Handling string features often requires a vocabulary lookup to translate strings into indices. Here is an example using feature columns to look up strings and then one-hot encode the indices:
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'sizes',
vocabulary_list=['small', 'medium', 'large'],
num_oov_buckets=0)
indicator_col = tf1.feature_column.indicator_column(vocab_col)
call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})
With Keras preprocessing layers, use the tf.keras.layers.StringLookup layer with output_mode set to 'one_hot':
string_lookup_layer = tf.keras.layers.StringLookup(
vocabulary=['small', 'medium', 'large'],
num_oov_indices=0,
output_mode='one_hot')
string_lookup_layer(['small', 'medium', 'large'])
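With OOV handling disabled (num_oov_buckets=0 and num_oov_indices=0, as above), both APIs map 'small', 'medium' and 'large' to 0, 1 and 2. They diverge once out-of-vocabulary (OOV) handling is enabled: categorical_column_with_vocabulary_list places its OOV buckets after the vocabulary, while StringLookup with its default num_oov_indices=1 (and no mask token) reserves index 0 for OOV strings and shifts the vocabulary up by one. A plain-Python sketch of both conventions for a single OOV slot (the helper names are hypothetical):

```python
def feature_column_index(token, vocab):
    """Like categorical_column_with_vocabulary_list with num_oov_buckets=1:
    vocabulary terms get 0..len(vocab)-1 and the OOV bucket comes last."""
    return vocab.index(token) if token in vocab else len(vocab)

def string_lookup_index(token, vocab):
    """Like StringLookup with the default num_oov_indices=1 and no mask token:
    index 0 is the OOV slot and vocabulary terms get 1..len(vocab)."""
    return vocab.index(token) + 1 if token in vocab else 0

vocab = ['small', 'medium', 'large']
tokens = ['small', 'medium', 'large', 'foo']
print([feature_column_index(t, vocab) for t in tokens])  # [0, 1, 2, 3]
print([string_lookup_index(t, vocab) for t in tokens])   # [1, 2, 3, 0]
```

Either way there are len(vocab) + 1 distinct indices, so any embedding table fed by such a lookup needs len(vocab) + 1 rows.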
Embedding string data with a vocabulary
For larger vocabularies, an embedding is often needed for good performance. Here is an example embedding a string feature using feature columns:
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'col',
vocabulary_list=['small', 'medium', 'large'],
num_oov_buckets=0)
embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)
call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})
Using Keras preprocessing layers, this can be achieved by combining a tf.keras.layers.StringLookup layer and a tf.keras.layers.Embedding layer. By default, StringLookup outputs integer indices, which can be fed directly into an embedding.
string_lookup_layer = tf.keras.layers.StringLookup(
vocabulary=['small', 'medium', 'large'], num_oov_indices=0)
embedding = tf.keras.layers.Embedding(3, 4)
embedding(string_lookup_layer(['small', 'medium', 'large']))
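Conceptually, an embedding layer is a trainable lookup table with one row per index. The NumPy sketch below (a random table stands in for learned weights) shows that gathering rows by index gives the same result as multiplying one-hot vectors by that table, which is why an embedding is a cheaper way to apply a linear layer to one-hot inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embedding_dims = 3, 4
# Random stand-in for the embedding layer's trainable table.
table = rng.normal(size=(vocab_size, embedding_dims)).astype(np.float32)

ids = np.array([0, 1, 2])  # e.g. lookup output for ['small', 'medium', 'large']

# An embedding lookup gathers rows of the table...
looked_up = table[ids]
# ...which equals multiplying one-hot vectors by the same table.
one_hot = np.eye(vocab_size, dtype=np.float32)[ids]
via_matmul = one_hot @ table

print(looked_up.shape)                     # (3, 4)
print(np.allclose(looked_up, via_matmul))  # True
```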
Summing weighted categorical data
In some cases, you need to deal with categorical data where each occurrence of a category comes with an associated weight. In feature columns, this is handled with tf.feature_column.weighted_categorical_column. When paired with an indicator_column, this has the effect of summing weights per category.
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
categorical_col, 'weights')
indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)
call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})
In Keras, this can be done by passing a count_weights input to tf.keras.layers.CategoryEncoding with output_mode='count'.
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# Using sparse output is more efficient when `num_tokens` is large.
count_layer = tf.keras.layers.CategoryEncoding(
num_tokens=20, output_mode='count', sparse=True)
tf.sparse.to_dense(count_layer(ids, count_weights=weights))
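Both versions produce a length-20 vector per sample in which position i holds the sum of the weights attached to ID i: 0.5 + 0.7 for ID 5, 1.5 for ID 11, and 1.8 + 0.2 for ID 17. A plain-Python sketch of that accumulation for a single sample (weighted_counts is a hypothetical helper, not a TensorFlow function):

```python
def weighted_counts(ids, weights, num_tokens):
    """Sum the weights of each ID into a dense list of length num_tokens."""
    totals = [0.0] * num_tokens
    for i, w in zip(ids, weights):
        totals[i] += w
    return totals

counts = weighted_counts([5, 11, 5, 17, 17], [0.5, 1.5, 0.7, 1.8, 0.2], num_tokens=20)
# Show only the non-zero positions, rounded to hide float noise.
print({i: round(c, 6) for i, c in enumerate(counts) if c})  # {5: 1.2, 11: 1.5, 17: 2.0}
```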
Embedding weighted categorical data
You might alternatively want to embed weighted categorical inputs. In feature columns, the embedding_column has a combiner argument. If any sample contains multiple entries for a category, they will be combined according to the argument setting ('mean' by default).
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
categorical_col, 'weights')
embedding_col = tf1.feature_column.embedding_column(
weighted_categorical_col, 4, combiner='mean')
call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})
In Keras, there is no combiner option to tf.keras.layers.Embedding, but you can achieve the same effect with tf.keras.layers.Dense. The embedding_column above simply combines embedding vectors linearly according to category weight. Though not obvious at first, this is exactly equivalent to representing your categorical inputs as a sparse weight vector of size (num_tokens) and multiplying it by a Dense kernel of shape (num_tokens, embedding_size).
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line
# would be equivalent to an `embedding_column` with `combiner='sum'`.
weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
count_layer = tf.keras.layers.CategoryEncoding(
num_tokens=20, output_mode='count', sparse=True)
embedding_layer = tf.keras.layers.Dense(4, use_bias=False)
embedding_layer(count_layer(ids, count_weights=weights))
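The equivalence can be checked numerically. In the NumPy sketch below, a random matrix stands in for the learned weights (playing the role of both the embedding table and the Dense kernel). It computes the weighted mean of the looked-up rows, which is what combiner='mean' does, and compares it with the normalized weighted-count vector multiplied by a (num_tokens, embedding_dims) kernel, as in the Keras snippet above. Skipping the normalization corresponds to combiner='sum'.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, embedding_dims = 20, 4
# Random stand-in for the learned table / Dense kernel of shape (num_tokens, embedding_dims).
kernel = rng.normal(size=(num_tokens, embedding_dims))

ids = np.array([5, 11, 5, 17, 17])
weights = np.array([0.5, 1.5, 0.7, 1.8, 0.2])

# embedding_column-style combining: weighted sum and weighted mean of looked-up rows.
sum_combined = (weights[:, None] * kernel[ids]).sum(axis=0)
mean_combined = sum_combined / weights.sum()

# Dense-style equivalent: a weighted count vector times the kernel (no bias).
counts = np.zeros(num_tokens)
np.add.at(counts, ids, weights)  # accumulates repeated IDs; counts[ids] += weights would not
dense_sum = counts @ kernel
dense_mean = (counts / counts.sum()) @ kernel

print(np.allclose(sum_combined, dense_sum), np.allclose(mean_combined, dense_mean))  # True True
```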
Complete training example
To show a complete training workflow, first prepare some data with three features of different types:
features = {
'type': [0, 1, 1],
'size': ['small', 'small', 'medium'],
'weight': [2.7, 1.8, 1.6],
}
labels = [1, 1, 0]
predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}
Define some common constants for both TensorFlow 1 and TensorFlow 2 workflows:
vocab = ['small', 'medium', 'large']
one_hot_dims = 3
embedding_dims = 4
weight_mean = 2.0
weight_variance = 1.0
With feature columns
Feature columns must be passed as a list to the estimator on creation, and will be called implicitly during training.
categorical_col = tf1.feature_column.categorical_column_with_identity(
'type', num_buckets=one_hot_dims)
# Convert index to one-hot; e.g. [2] -> [0,0,1].
indicator_col = tf1.feature_column.indicator_column(categorical_col)
# Convert strings to indices; e.g. ['small'] -> [0].
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'size', vocabulary_list=vocab, num_oov_buckets=1)
# Embed the indices.
embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)
normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
numeric_col = tf1.feature_column.numeric_column(
'weight', normalizer_fn=normalizer_fn)
estimator = tf1.estimator.DNNClassifier(
feature_columns=[indicator_col, embedding_col, numeric_col],
hidden_units=[1])
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
estimator.train(_input_fn)
The feature columns will also be used to transform input data when running inference on the model.
def _predict_fn():
  return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)
next(estimator.predict(_predict_fn))
With Keras preprocessing layers
Keras preprocessing layers are more flexible in where they can be called. A layer can be applied directly to tensors, used inside a tf.data input pipeline, or built directly into a trainable Keras model.
In this example, you will apply preprocessing layers inside a tf.data input pipeline. To do this, you can define a separate tf.keras.Model to preprocess your input features. This model is not trainable, but is a convenient way to group preprocessing layers.
inputs = {
'type': tf.keras.Input(shape=(), dtype='int64'),
'size': tf.keras.Input(shape=(), dtype='string'),
'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Convert index to one-hot; e.g. [2] -> [0,0,1].
type_output = tf.keras.layers.CategoryEncoding(
one_hot_dims, output_mode='one_hot')(inputs['type'])
# Convert size strings to indices; e.g. ['small'] -> [1].
size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
weight_output = tf.keras.layers.Normalization(
axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])
outputs = {
'type': type_output,
'size': size_output,
'weight': weight_output,
}
preprocessing_model = tf.keras.Model(inputs, outputs)
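To make the preprocessing model's outputs concrete, here is a plain-Python sketch that mirrors its three steps for individual raw examples. It reuses the guide's constants and data, but preprocess is a hypothetical helper for illustration only; it assumes StringLookup's default of reserving index 0 for out-of-vocabulary strings.

```python
import math

vocab = ['small', 'medium', 'large']
one_hot_dims = 3
weight_mean, weight_variance = 2.0, 1.0

def preprocess(example):
    """Mirror the three preprocessing layers for one raw example (illustrative only)."""
    type_one_hot = [1.0 if i == example['type'] else 0.0 for i in range(one_hot_dims)]
    # Default StringLookup convention: OOV -> 0, vocabulary -> 1..len(vocab).
    size = example['size']
    size_index = vocab.index(size) + 1 if size in vocab else 0
    weight = (example['weight'] - weight_mean) / math.sqrt(weight_variance)
    return {'type': type_one_hot, 'size': size_index, 'weight': weight}

features = {
    'type': [0, 1, 1],
    'size': ['small', 'small', 'medium'],
    'weight': [2.7, 1.8, 1.6],
}
# Split the column-oriented dict into per-example dicts, as from_tensor_slices does.
examples = [dict(zip(features, values)) for values in zip(*features.values())]
for ex in examples:
    print(preprocess(ex))
print(preprocess({'type': 0, 'size': 'foo', 'weight': -0.7}))  # 'foo' maps to OOV index 0
```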
You can now apply this model inside a call to tf.data.Dataset.map. Note that the function passed to map is automatically converted into a tf.function, so the usual caveats for writing tf.function code apply (no side effects).
# Apply the preprocessing in tf.data.Dataset.map.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),
num_parallel_calls=tf.data.AUTOTUNE)
# Display a preprocessed input sample.
next(dataset.take(1).as_numpy_iterator())
Next, you can define a separate Model containing the trainable layers. Note how the inputs to this model now reflect the preprocessed feature types and shapes.
inputs = {
'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),
'size': tf.keras.Input(shape=(), dtype='int64'),
'weight': tf.keras.Input(shape=(), dtype='float32'),
}
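# Since the embedding is trainable, it needs to be part of the training model.
# StringLookup reserves index 0 for out-of-vocabulary strings, so the embedding
# needs len(vocab) + 1 rows to cover every possible index.
embedding = tf.keras.layers.Embedding(len(vocab) + 1, embedding_dims)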
outputs = tf.keras.layers.Concatenate()([
inputs['type'],
embedding(inputs['size']),
tf.expand_dims(inputs['weight'], -1),
])
outputs = tf.keras.layers.Dense(1)(outputs)
training_model = tf.keras.Model(inputs, outputs)
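The concatenated feature vector has one_hot_dims + embedding_dims + 1 = 3 + 4 + 1 = 8 entries, which the final Dense(1) layer maps to a single logit. A NumPy sketch of that forward pass for one preprocessed example, with random values standing in for trained weights and an embedding table of len(vocab) + 1 rows so that every lookup index, including OOV index 0, has a row:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size = 3                      # len(vocab)
one_hot_dims, embedding_dims = 3, 4

# Random stand-ins for trained weights.
embedding_table = rng.normal(size=(vocab_size + 1, embedding_dims))  # +1 row for OOV index 0
dense_kernel = rng.normal(size=(one_hot_dims + embedding_dims + 1, 1))
dense_bias = np.zeros(1)

# One preprocessed example: type 1 (one-hot), size index 2 ('medium'), normalized weight -0.4.
type_one_hot = np.array([0.0, 1.0, 0.0])
size_index = 2
weight = -0.4

features = np.concatenate([type_one_hot, embedding_table[size_index], [weight]])
logit = features @ dense_kernel + dense_bias
# The model outputs a logit (from_logits=True); a sigmoid maps it to a probability.
prob = 1.0 / (1.0 + np.exp(-logit))
print(features.shape, logit.shape)  # (8,) (1,)
```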
You can now train the training_model with tf.keras.Model.fit.
# Train on the preprocessed data.
training_model.compile(
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
training_model.fit(dataset)
Finally, at inference time, it can be useful to combine these separate stages into a single model that handles raw feature inputs.
inputs = preprocessing_model.input
outputs = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outputs)
predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)
inference_model.predict(predict_dataset)
This composed model can be saved as a .keras file for later use.
inference_model.save('model.keras')
restored_model = tf.keras.models.load_model('model.keras')
restored_model.predict(predict_dataset)
Feature column equivalence table
For reference, here is an approximate correspondence between feature columns and Keras preprocessing layers:
* The output_mode can be passed to tf.keras.layers.CategoryEncoding, tf.keras.layers.StringLookup, tf.keras.layers.IntegerLookup, and tf.keras.layers.TextVectorization.
† tf.keras.layers.TextVectorization can handle freeform text input directly (for example, entire sentences or paragraphs). This is not a one-to-one replacement for categorical sequence handling in TensorFlow 1, but may offer a convenient replacement for ad-hoc text preprocessing.
Next steps
- For more information on Keras preprocessing layers, go to the Working with preprocessing layers guide.
- For a more in-depth example of applying preprocessing layers to structured data, refer to the Classify structured data using Keras preprocessing layers tutorial.