Defined in tensorflow/contrib/estimator/python/estimator/extenders.py.

Forward features to predictions dictionary.

In some cases, users want to see some of the features in an estimator's prediction output. As an example, consider a batch prediction service: the service simply runs inference on the user's graph and returns the results. Keys are essential because there is no ordering guarantee on the outputs, so they must be rejoined to the inputs either via keys or by transcluding the inputs into the outputs.


  def input_fn():
    features, labels = ...
    features['unique_example_id'] = ...
    return features, labels

  estimator = tf.estimator.LinearClassifier(...)
  estimator = tf.contrib.estimator.forward_features(
      estimator, 'unique_example_id')
  # predict() yields one predictions dict per example, so inspect the first.
  assert 'unique_example_id' in next(estimator.predict(...))
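
Once the key is forwarded, a caller can rejoin predictions to their inputs even when results arrive out of order. The sketch below uses plain Python dictionaries to stand in for the per-example prediction dicts; `rejoin_by_key`, the sample data, and the `'score'` prediction key are all hypothetical and chosen only for illustration. Real predictions carry NumPy values, which may need converting before they can serve as dictionary keys.

```python
def rejoin_by_key(inputs, predictions, key='unique_example_id'):
    """Pair each prediction with the input that shares its key value."""
    # Index the inputs by key so lookup does not depend on output order.
    inputs_by_id = {example[key]: example for example in inputs}
    joined = []
    for prediction in predictions:
        example = inputs_by_id[prediction[key]]
        joined.append((example, prediction))
    return joined


# Hypothetical inputs, and predictions returned in a different order.
inputs = [
    {'unique_example_id': 'a', 'x': 1.0},
    {'unique_example_id': 'b', 'x': 2.0},
]
predictions = [
    {'unique_example_id': 'b', 'score': 0.9},
    {'unique_example_id': 'a', 'score': 0.1},
]

pairs = rejoin_by_key(inputs, predictions)
```

Each element of `pairs` is an `(input, prediction)` tuple matched on the forwarded key, regardless of the order in which the predictions were produced.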


  • estimator: A tf.estimator.Estimator object.
  • keys: A string or a list of strings. If it is None, all of the features in the dict are forwarded to the predictions. If it is a string, only the given key is forwarded. If it is a list of strings, all of the given keys are forwarded.


Returns a new tf.estimator.Estimator that forwards the selected features to its predictions.


  • ValueError:
    • if keys is already part of predictions. Overriding an existing prediction is not allowed.
    • if keys does not exist in features.
    • if a feature key refers to a SparseTensor, since SparseTensor is not supported in predictions even though it is common in features.
  • TypeError: if keys is not a string or a list/tuple of strings.
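
The argument handling and error conditions above can be sketched in plain Python. This is an illustration of the documented behavior, not the library's implementation: `select_forwarded_keys` and `forward` are hypothetical helpers, plain dicts stand in for the features and predictions dictionaries, and the SparseTensor check is omitted because it requires TensorFlow.

```python
def select_forwarded_keys(keys, features):
    """Normalize `keys` into a list of feature names to forward."""
    if keys is None:
        return list(features)  # None means: forward every feature.
    if isinstance(keys, str):
        return [keys]
    if isinstance(keys, (list, tuple)) and all(
            isinstance(k, str) for k in keys):
        return list(keys)
    raise TypeError(
        'keys must be a string or a list/tuple of strings, got %r' % (keys,))


def forward(features, predictions, keys=None):
    """Return a copy of `predictions` with the selected features added."""
    result = dict(predictions)
    for key in select_forwarded_keys(keys, features):
        if key not in features:
            raise ValueError('%r does not exist in features' % key)
        if key in result:
            raise ValueError('%r is already part of predictions' % key)
        result[key] = features[key]
    return result


# Hypothetical single-example features and predictions.
features = {'unique_example_id': 'a', 'x': 1.0}
predictions = {'score': 0.1}
forwarded = forward(features, predictions, 'unique_example_id')
```

Calling `forward(features, predictions)` with no keys forwards every feature, while a key that is missing from the features or already present in the predictions raises ValueError.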