Estimator class to train and evaluate TensorFlow models.
tf.estimator.Estimator( model_fn, model_dir=None, config=None, params=None, warm_start_from=None )
The Estimator object wraps a model which is specified by a model_fn, which, given inputs and a number of other parameters, returns the ops necessary to perform training, evaluation, or predictions.
All outputs (checkpoints, event files, etc.) are written to model_dir, or a subdirectory thereof. If model_dir is not set, a temporary directory is used.
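The model_dir fallback can be illustrated with a small stand-alone sketch. The helper name resolve_model_dir is hypothetical and is not part of the TensorFlow API; it only assumes, as the text states, that a temporary directory is chosen when no model_dir is given.

```python
import os
import tempfile
from typing import Optional


def resolve_model_dir(model_dir: Optional[str] = None) -> str:
    """Hypothetical helper: pick the directory for checkpoints and event files.

    Mirrors the documented rule: an explicit model_dir is used as-is,
    otherwise a fresh temporary directory is created.
    """
    if model_dir is None:
        # No directory given: fall back to a new temporary directory.
        model_dir = tempfile.mkdtemp()
    return model_dir


explicit = resolve_model_dir("/tmp/my_model")
fallback = resolve_model_dir()
print(explicit)                 # /tmp/my_model
print(os.path.isdir(fallback))  # True
```

Because the fallback directory is temporary, passing an explicit model_dir is what makes checkpoints reusable across runs.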
The config argument can be passed a tf.estimator.RunConfig object containing information about the execution environment. It is passed on to the model_fn if the model_fn has a parameter named "config" (and to the input functions in the same manner). If the config parameter is not passed, it is instantiated by the Estimator; not passing config means that defaults useful for local execution are used. The Estimator makes config available to the model (for instance, to allow specialization based on the number of available workers), and also uses some of its fields to control internals, especially regarding checkpointing.
The params argument contains hyperparameters. It is passed to the model_fn if the model_fn has a parameter named "params", and to the input functions in the same manner. The Estimator only passes params along; it does not inspect it. The structure of params is therefore entirely up to the developer.
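The rule that config and params are forwarded only when the model_fn declares a parameter with that name can be sketched in plain Python with inspect.signature. This is a minimal sketch, not TensorFlow's implementation: call_model_fn is a hypothetical name, and the empty dict used as a default config is a stand-in for a default-constructed tf.estimator.RunConfig.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def call_model_fn(model_fn: Callable[..., Any], features: Any, labels: Any,
                  mode: str, config: Optional[Any] = None,
                  params: Optional[Dict[str, Any]] = None) -> Any:
    """Hypothetical dispatcher: forward optional arguments by parameter name."""
    if config is None:
        # Stand-in for a default RunConfig built when none is passed.
        config = {}
    accepted = inspect.signature(model_fn).parameters
    kwargs: Dict[str, Any] = {}
    if "mode" in accepted:
        kwargs["mode"] = mode
    if "config" in accepted:
        kwargs["config"] = config
    if "params" in accepted:
        # params is handed over unchanged; its structure is the caller's choice.
        kwargs["params"] = params
    return model_fn(features, labels, **kwargs)


def minimal_model_fn(features, labels, mode):
    # Declares neither config nor params, so neither is passed in.
    return ("minimal", mode)


def tuned_model_fn(features, labels, mode, params):
    # Reads a hypothetical "learning_rate" hyperparameter from params.
    return ("tuned", params["learning_rate"])


hparams = {"learning_rate": 0.1}
print(call_model_fn(minimal_model_fn, None, None, "train", params=hparams))
# ('minimal', 'train')
print(call_model_fn(tuned_model_fn, None, None, "train", params=hparams))
# ('tuned', 0.1)
```

Dispatching on parameter names lets a model_fn opt in to only the arguments it needs, which is why a minimal model_fn does not have to accept config or params at all.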
None of the Estimator's methods can be overridden in subclasses (its constructor enforces this). Subclasses should use model_fn to configure the base class, and may add methods implementing specialized functionality.
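The constructor check mentioned above, which rejects subclasses that override base-class methods while still allowing new ones, can be sketched in plain Python. All class names here (BaseEstimator, SpecializedEstimator, BrokenEstimator) are hypothetical, and this is one possible way to write such a check, not TensorFlow's actual implementation.

```python
class BaseEstimator:
    """Hypothetical base class whose public methods may not be overridden."""

    def __init__(self, model_fn):
        self._assert_no_overrides()
        self._model_fn = model_fn

    def train(self):
        return "train"

    def evaluate(self):
        return "evaluate"

    def _assert_no_overrides(self):
        # Compare each public base-class method with what the subclass resolves to.
        for name, member in vars(BaseEstimator).items():
            if name.startswith("_") or not callable(member):
                continue
            if getattr(type(self), name) is not member:
                raise ValueError(f"Subclasses of BaseEstimator cannot override {name!r}")


class SpecializedEstimator(BaseEstimator):
    def export_summary(self):
        # Adding new methods is allowed.
        return "summary"


class BrokenEstimator(BaseEstimator):
    def train(self):
        # Overriding a base-class method is rejected at construction time.
        return "custom train"


ok = SpecializedEstimator(model_fn=lambda: None)
print(ok.export_summary())  # summary
try:
    BrokenEstimator(model_fn=lambda: None)
except ValueError as err:
    print(err)  # Subclasses of BaseEstimator cannot override 'train'
```

Running the check in the constructor means an invalid subclass fails as soon as it is instantiated rather than later, during training.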
See estimators for more information.
To warm-start an Estimator:
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
For more details on warm-start configuration, see tf.estimator.WarmStartSettings.
model_fn: Model function. Follows the signature: