Estimator class to train and evaluate TensorFlow models.
tf.compat.v1.estimator.Estimator( model_fn, model_dir=None, config=None, params=None, warm_start_from=None )
The Estimator object wraps a model which is specified by a model_fn,
which, given inputs and a number of other parameters, returns the ops
necessary to perform training, evaluation, or predictions.
All outputs (checkpoints, event files, etc.) are written to model_dir, or a
subdirectory thereof. If model_dir is not set, a temporary directory is used.
The config argument can be passed a tf.estimator.RunConfig object containing
information about the execution environment. It is passed on to the
model_fn if the model_fn has a parameter named "config" (and to the input
functions in the same manner). If the config parameter is not passed, it is
instantiated by the Estimator. Not passing config means that defaults useful
for local execution are used. Estimator makes config available to the model
(for instance, to allow specialization based on the number of workers
available), and also uses some of its fields to control internals, especially
regarding checkpointing.
The params argument contains hyperparameters. It is passed to the
model_fn if the model_fn has a parameter named "params", and to the input
functions in the same manner. Estimator only passes params along; it does
not inspect it. The structure of params is therefore entirely up to the
developer.
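The conditional passing described above ("if the model_fn has a parameter named ...") can be illustrated with a small signature check. This is a minimal stdlib-only sketch of the idea, not TensorFlow's actual implementation; `call_with_optional_args`, `model_fn_plain`, and `model_fn_with_params` are hypothetical names introduced for illustration.

```python
import inspect


def call_with_optional_args(fn, features, labels, params=None, config=None):
    """Call fn, forwarding params/config only if fn declares them (hypothetical helper)."""
    accepted = inspect.signature(fn).parameters
    kwargs = {}
    if "params" in accepted:
        kwargs["params"] = params
    if "config" in accepted:
        kwargs["config"] = config
    return fn(features, labels, **kwargs)


def model_fn_plain(features, labels):
    # Declares neither "params" nor "config", so neither is passed.
    return ("plain", None)


def model_fn_with_params(features, labels, params):
    # Declares "params", so the hyperparameters are passed through untouched.
    return ("with_params", params["learning_rate"])


hyperparams = {"learning_rate": 0.01}
plain_result = call_with_optional_args(
    model_fn_plain, [1.0], [0], params=hyperparams)
params_result = call_with_optional_args(
    model_fn_with_params, [1.0], [0], params=hyperparams)
```

Note that the helper never looks inside `params`; its structure is left entirely to the caller, which mirrors the behavior described above.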
None of Estimator's methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use model_fn to configure
the base class, and may add methods implementing specialized functionality.
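One way a constructor can enforce a no-override rule is to compare each public method on the subclass with the one defined on the base class. The following is a generic Python sketch of that pattern under stated assumptions; `SealedBase` and its subclasses are hypothetical, and this is not a copy of how Estimator performs its check.

```python
class SealedBase:
    """Hypothetical base class whose public methods must not be overridden."""

    def __init__(self):
        cls = type(self)
        for name, attr in vars(SealedBase).items():
            # Skip private/dunder names and non-method attributes.
            if name.startswith("_") or not callable(attr):
                continue
            # If the subclass resolves the name to a different function,
            # it has overridden a base-class method.
            if getattr(cls, name) is not attr:
                raise ValueError(f"Subclasses must not override {name!r}")

    def train(self):
        return "base train"


class AddsMethod(SealedBase):
    # Adding new methods is allowed.
    def export_report(self):
        return "report"


class OverridesTrain(SealedBase):
    # Overriding an existing public method is rejected at construction time.
    def train(self):
        return "custom train"


ok = AddsMethod()
try:
    OverridesTrain()
    override_rejected = False
except ValueError:
    override_rejected = True
```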
See estimators for more information.
To warm-start an Estimator:
estimator = tf.estimator.DNNClassifier( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], hidden_units=[1024, 512, 256], warm_start_from="/path/to/checkpoint/dir")
For more details on warm-start configuration, see tf.estimator.WarmStartSettings.
model_fn: Model function. Follows the signature:
model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator to
continue training a previously saved model. If model_dir is not set, a
temporary directory is used.
warm_start_from: Optional string filepath to a checkpoint or SavedModel to
warm-start from, or a tf.estimator.WarmStartSettings object to fully
configure warm-starting.
ValueError: if this is called via a subclass and if that class overrides
a member of Estimator.
Calling methods of Estimator will work while eager execution is enabled.
However, the model_fn and input_fn are not executed eagerly; Estimator
will switch to graph mode before calling all user-provided functions (incl.
hooks), so their code has to be compatible with graph mode execution. Note
that input_fn code using tf.data generally works in both graph and eager
modes.
eval_dir( name=None )
Shows the directory name where evaluation metrics are dumped.
|name|Name of the evaluation, if the user needs to run multiple evaluations on different data sets, such as on training data vs. test data. Metrics for different evaluations are saved in separate folders and appear separately in TensorBoard.|
|A string which is the path of the directory that contains evaluation metrics.|
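To make the "separate folders" behavior concrete, here is a small sketch of how a per-evaluation directory could be derived from model_dir and name. The "eval" / "eval_<name>" naming is an assumption made for this illustration and is not stated in the text above; `eval_dir_sketch` is a hypothetical helper, not the library method.

```python
import os


def eval_dir_sketch(model_dir, name=None):
    """Return a per-evaluation directory under model_dir (assumed layout)."""
    # Assumption: the default run uses "eval", named runs use "eval_<name>".
    subdir = "eval" if not name else "eval_" + name
    return os.path.join(model_dir, subdir)


default_dir = eval_dir_sketch("model_out")
train_dir = eval_dir_sketch("model_out", name="train")
```

Under this layout, evaluations named "train" and "test" would write to different subdirectories and therefore show up as separate runs in TensorBoard.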
evaluate( input_fn, steps=None, hooks=None, checkpoint_path=None, name=None )
Evaluates the model given evaluation data input_fn.
For each step, calls input_fn, which returns one batch of data.
Evaluates until steps batches are processed, or
input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration).
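The two stopping conditions can be sketched as a plain Python loop. This is an illustrative stand-in that uses an ordinary iterator and StopIteration rather than TensorFlow ops; `run_eval_steps` is a hypothetical helper and the "metric" is just a running sum.

```python
def run_eval_steps(batches, steps=None):
    """Consume batches until `steps` are processed or the input is exhausted."""
    iterator = iter(batches)
    processed = 0
    total = 0
    # steps=None means "run until the input signals end-of-input".
    while steps is None or processed < steps:
        try:
            batch = next(iterator)
        except StopIteration:
            # End-of-input: stop even if fewer than `steps` batches were seen.
            break
        total += sum(batch)
        processed += 1
    return processed, total


data = [[1, 2], [3, 4], [5, 6]]
limited = run_eval_steps(data, steps=2)
exhausted = run_eval_steps(data)
capped_by_input = run_eval_steps(data, steps=10)
```

With `steps=2` the loop stops after two batches; with `steps=None` or a `steps` value larger than the data, it stops when the iterator is exhausted.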
input_fn: A function that constructs the input data for evaluation. See
Premade Estimators for more information. The function should construct and return one of