Module: tfl.estimators

TF Lattice canned estimators implement typical monotonic model architectures.

TFL canned estimators let you build commonly used monotonic model architectures with little code. To construct one, create a model configuration from tfl.configs and pass it to the canned estimator constructor. For automated quantile calculation, the constructor also needs a feature_analysis_input_fn, which is similar to the training input function but runs for a single epoch or over a subset of the data. To create a Crystals ensemble model with tfl.configs.CalibratedLatticeEnsembleConfig, you must additionally pass a prefitting_input_fn to the estimator constructor.

import tensorflow_lattice as tfl

# create_input_fn is a user-defined helper that returns an input function.
feature_columns = ...
model_config = tfl.configs.CalibratedLatticeConfig(...)
# One pass over the data is enough for feature analysis (quantiles).
feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)
train_input_fn = create_input_fn(num_epochs=100, ...)
estimator = tfl.estimators.CannedClassifier(
    feature_columns=feature_columns,
    model_config=model_config,
    feature_analysis_input_fn=feature_analysis_input_fn)
estimator.train(input_fn=train_input_fn)
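
To make the automated quantile calculation described above more concrete, here is a minimal pure-Python sketch of the idea: read the data once, collect one feature's values, and place calibration keypoints at evenly spaced quantiles. It is not the library's implementation. The names quantile_keypoints, one_epoch_batches, and keypoints_from_single_pass are hypothetical, and the "age" feature is made-up sample data.

```python
import statistics


def quantile_keypoints(values, num_keypoints):
    """Place `num_keypoints` keypoints at evenly spaced quantiles of `values`."""
    if num_keypoints < 2:
        raise ValueError("num_keypoints must be at least 2")
    ordered = sorted(values)
    # n = num_keypoints - 1 intervals yields num_keypoints - 2 interior cut points.
    interior = statistics.quantiles(
        ordered, n=num_keypoints - 1, method="inclusive")
    # The first and last keypoints are the observed minimum and maximum.
    return [ordered[0]] + interior + [ordered[-1]]


def one_epoch_batches(rows, batch_size):
    """Yield the rows exactly once, in batches (hypothetical stand-in for a
    single-epoch feature analysis input function)."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def keypoints_from_single_pass(batches, feature_name, num_keypoints):
    """Collect one feature over a single pass and compute its keypoints."""
    values = [row[feature_name] for batch in batches for row in batch]
    return quantile_keypoints(values, num_keypoints)


# Made-up sample data: ages 0..100, read once in batches of 32.
rows = [{"age": float(a)} for a in range(0, 101)]
keypoints = keypoints_from_single_pass(one_epoch_batches(rows, 32), "age", 5)
print(keypoints)
```

A single pass is sufficient here because quantiles depend only on the distribution of values, not on how many times the data is repeated, which is why the feature analysis input function can use one epoch or a subset.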

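Supported models are defined in tfl.configs. Each model architecture can be used with any of the canned estimators listed under Classes: tfl.estimators.CannedClassifier, tfl.estimators.CannedRegressor, or the general tfl.estimators.CannedEstimator.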

This module also provides tfl.estimators.get_model_graph as a mechanism to extract abstract model graphs and layer parameters from saved models. The resulting graph (not a TF graph) can be used by the tfl.visualization module for plotting and other visualization and analysis.

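from tensorflow_lattice import estimators, visualization

# saved_model_path points to a saved model exported from a TFL estimator.
model_graph = estimators.get_model_graph(saved_model_path)
visualization.plot_feature_calibrator(model_graph, "feature_name")
visualization.plot_all_calibrators(model_graph)
visualization.draw_model_graph(model_graph)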

Classes

class CannedClassifier: Canned classifier for TensorFlow lattice models.

class CannedEstimator: An estimator for TensorFlow lattice models.

class CannedRegressor: A regressor for TensorFlow lattice models.

class WaitTimeOutError: Timeout error when waiting for a file.
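
The source does not say when the library raises WaitTimeOutError. As a general illustration of the poll-with-deadline pattern that such an error usually signals, here is a minimal sketch using only the standard library. FileWaitTimeout and wait_for_file are hypothetical names, not part of tfl.estimators.

```python
import os
import tempfile
import time


class FileWaitTimeout(Exception):
    """Hypothetical stand-in for a timeout error raised while waiting for a file."""


def wait_for_file(path, timeout_secs=5.0, poll_secs=0.1):
    """Poll until `path` exists; raise FileWaitTimeout once the deadline passes."""
    deadline = time.monotonic() + timeout_secs
    while not os.path.exists(path):
        if time.monotonic() >= deadline:
            raise FileWaitTimeout(
                f"Timed out after {timeout_secs}s waiting for {path}")
        time.sleep(poll_secs)
    return path


# The file exists while the context manager is open, so this returns at once.
with tempfile.NamedTemporaryFile() as f:
    found = wait_for_file(f.name, timeout_secs=1.0)

# This path never appears, so the call times out.
try:
    wait_for_file(f.name + ".missing", timeout_secs=0.2, poll_secs=0.05)
    timed_out = False
except FileWaitTimeout:
    timed_out = True
```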

Functions

get_model_graph(...): Returns all layers and parameters used in a saved model as a graph.

transform_features(...): Parses the input features using the given feature columns.

Other Members

FEATURES_SCOPE: 'features'

OUTPUT_NAME: 'output'

absolute_import: Instance of __future__._Feature

division: Instance of __future__._Feature

print_function: Instance of __future__._Feature