tf.estimator.experimental.RNNEstimator

An Estimator for TensorFlow RNN models with a user-specified head.

Inherits From: Estimator

Example:

token_sequence = sequence_categorical_column_with_hash_bucket(...)
token_emb = embedding_column(categorical_column=token_sequence, ...)

estimator = RNNEstimator(
    head=tf.estimator.RegressionHead(),
    sequence_feature_columns=[token_emb],
    units=[32, 16], cell_type='lstm')

# Or, with a custom RNN cell:
def rnn_cell_fn(_):
  """Build the custom cell: LSTM cells of 32 and 16 units, stacked.

  The single positional argument is accepted but not used.
  """
  layer_sizes = (32, 16)
  lstm_layers = [tf.keras.layers.LSTMCell(units) for units in layer_sizes]
  return tf.keras.layers.StackedRNNCells(lstm_layers)

# Here the cell comes from `rnn_cell_fn`, so `units` and `cell_type` are
# omitted; the head and the sequence features are unchanged.
estimator = RNNEstimator(
    head=tf.estimator.RegressionHead(),
    sequence_feature_columns=[token_emb],
    rnn_cell_fn=rnn_cell_fn,
)

# Input builders
def input_fn_train():
  """Placeholder training input_fn; a real one returns (x, y)."""
  # `def name:` without a parameter list is a SyntaxError; the empty
  # parentheses make the example valid Python.
  pass
# Run 100 training steps using the training input function.
estimator.train(
    input_fn=input_fn_train,
    steps=100,
)

def input_fn_eval():
  """Placeholder evaluation input_fn; a real one returns (x, y)."""
  # `def name:` without a parameter list is a SyntaxError; the empty
  # parentheses make the example valid Python.
  pass
# Evaluate for 10 steps and keep whatever `evaluate` returns in `metrics`.
metrics = estimator.evaluate(
    input_fn=input_fn_eval,
    steps=10,
)
def input_fn_predict: # returns x, None
  pas