A classifier for TensorFlow RNN models.

Inherits From: RNNEstimator

Trains a recurrent neural network model to classify instances into one of multiple classes.


```python
token_sequence = sequence_categorical_column_with_hash_bucket(...)
token_emb = embedding_column(categorical_column=token_sequence, ...)

estimator = RNNClassifier(
    sequence_feature_columns=[token_emb],
    units=[32, 16], cell_type='lstm')

# Input builders
def input_fn_train():  # returns x, y
    pass
estimator.train(input_fn=input_fn_train, steps=100)

def input_fn_eval():  # returns x, y
    pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)

def input_fn_predict():  # returns x, None
    pass
predictions = estimator.predict(input_fn=input_fn_predict)
```
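In the example, token_sequence hashes each string token into one of a fixed number of buckets (hash modulo bucket count), and a batch of variable-length token sequences travels as a sparse tensor made of indices, values and a dense shape. The pure-Python sketch below illustrates both ideas. It is not TensorFlow code: hash_bucket and to_sparse are hypothetical helpers, and hashlib.md5 is only a stand-in for TensorFlow's own fingerprint hash, so the ids it produces will not match the real column's.

```python
import hashlib


def hash_bucket(token: str, num_buckets: int) -> int:
    """Map a token to an id in [0, num_buckets).

    Uses md5 as a stand-in hash; TensorFlow uses a different hash function.
    """
    digest = hashlib.md5(token.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little") % num_buckets


def to_sparse(batch, num_buckets):
    """Build a COO-style (indices, values, dense_shape) triple for ragged token lists."""
    indices, values = [], []
    for row, tokens in enumerate(batch):
        for pos, token in enumerate(tokens):
            indices.append((row, pos))                     # (example, time step)
            values.append(hash_bucket(token, num_buckets))  # hashed token id
    max_len = max((len(tokens) for tokens in batch), default=0)
    return indices, values, (len(batch), max_len)


batch = [["the", "cat", "sat"], ["hello"]]
indices, values, dense_shape = to_sparse(batch, num_buckets=1000)
print(indices)      # [(0, 0), (0, 1), (0, 2), (1, 0)]
print(dense_shape)  # (2, 3)
```

Only the positions that hold a token are stored, which is why a shorter sequence (the second example) simply contributes fewer entries instead of padding.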

The input to train and evaluate should have the following features, otherwise there will be a KeyError:

  • if weight_column is not None, a feature with key=weight_column whose value is a Tensor.
  • for each column in sequence_feature_columns:
    • a feature with key=column.name whose value is a SparseTensor.
  • for each column in context_feature_columns:
    • if column is a CategoricalColumn, a feature with key=column.name whose value is a SparseTensor.
    • if column is a WeightedCategoricalColumn, two features: the first with key the id column name, the second with key the weight column name. Both features' values must be SparseTensors.
    • if column is a DenseColumn, a feature with key=column.name whose value is a Tensor.
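The rules above amount to a list of required keys in the features dictionary. The sketch below derives that list and raises KeyError for a missing key. It is a hypothetical illustration: Column, required_keys and check_features are not TensorFlow APIs, it checks only key presence (not whether each value is a Tensor or a SparseTensor), and the real estimator's error message will differ.

```python
from dataclasses import dataclass
from typing import Iterable, Mapping, Optional


@dataclass
class Column:
    """Hypothetical stand-in for a feature column (not a TensorFlow class)."""
    name: str
    kind: str = "dense"  # "sequence", "categorical", "weighted_categorical" or "dense"
    weight_name: Optional[str] = None  # weight key, used for "weighted_categorical" only


def required_keys(sequence_columns: Iterable[Column],
                  context_columns: Iterable[Column],
                  weight_column: Optional[str] = None) -> list:
    """List the feature keys that the rules above require, in order."""
    keys = []
    if weight_column is not None:
        keys.append(weight_column)
    keys.extend(col.name for col in sequence_columns)
    for col in context_columns:
        keys.append(col.name)  # the column name, or the id column name if weighted
        if col.kind == "weighted_categorical":
            keys.append(col.weight_name)
    return keys


def check_features(features: Mapping, sequence_columns, context_columns,
                   weight_column: Optional[str] = None) -> None:
    """Raise KeyError for the first required key missing from `features`."""
    for key in required_keys(sequence_columns, context_columns, weight_column):
        if key not in features:
            raise KeyError(key)


seq = [Column("tokens", kind="sequence")]
ctx = [Column("word_ids", kind="weighted_categorical", weight_name="word_weights")]
# Values would be SparseTensors or Tensors; None is a placeholder since only keys are checked.
check_features({"tokens": None, "word_ids": None, "word_weights": None}, seq, ctx)
```

Dropping "word_weights" from the dictionary would make the last call raise KeyError('word_weights'), matching the two-feature requirement for a WeightedCategoricalColumn.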

Loss is calculated by using softmax cross entropy.
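For a single example with logits z over the classes and true class y, softmax cross entropy equals log(sum_j exp(z_j)) - z_y. The plain-Python sketch below computes that per-example value using the usual max-subtraction trick for numerical stability. It is an illustration of the formula, not the estimator's implementation, and it leaves out how per-example losses are weighted or reduced across a batch.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross entropy between softmax(logits) and a one-hot target at index `label`."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


loss = softmax_cross_entropy([2.0, 1.0, 0.1], label=0)
```

With equal logits for every class the loss is log(number of classes), and it shrinks toward zero as the logit of the true class grows relative to the others.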

sequence_feature_columns An iterable containing the FeatureColumns that represent sequential input. All items in the set should either be sequence columns (e.g. sequence_numeric_column) or constructed from one (e.g. embedding_column with