tf.compat.v1.estimator.LinearClassifier

View source on GitHub

Linear classifier model.

Inherits From: Estimator

Train a linear model to classify instances into one of multiple possible classes. When number of possible classes is 2, this is binary classification.

Example:

categorical_column_a = categorical_column_with_hash_bucket(...)
categorical_column_b = categorical_column_with_hash_bucket(...)

categorical_feature_a_x_categorical_feature_b = crossed_column(...)

# Estimator using the default optimizer.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[categorical_column_a,
                     categorical_feature_a_x_categorical_feature_b])

# Or estimator using the FTRL optimizer with regularization.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[categorical_column_a,
                     categorical_feature_a_x_categorical_feature_b],
    optimizer=tf.keras.optimizers.Ftrl(
      learning_rate=0.1,
      l1_regularization_strength=0.001
    ))

# Or estimator using an optimizer with a learning rate decay.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[categorical_column_a,
                     categorical_feature_a_x_categorical_feature_b],
    optimizer=lambda: tf.keras.optimizers.Ftrl(
        learning_rate=tf.exponential_decay(
            learning_rate=0.1,
            global_step=tf.get_global_step(),
            decay_steps=10000,
            decay_rate=0.96))

# Or estimator with warm-starting from a previous checkpoint.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[categorical_column_a,
                     categorical_feature_a_x_categorical_feature_b],
    warm_start_from="/path/to/checkpoint/dir")


# Input builders
def input_fn_train:
  # Returns tf.data.Dataset of (x, y) tuple where y represents label's class
  # index.
  pass
def input_fn_eval:
  # Returns tf.data.Dataset of (x, y) tuple where y represents label's class
  # index.
  pass
def input_fn_predict:
  # Returns tf.data.Dataset of (x, None) tuple.
  pass
estimator.train(input_fn=input_fn_train)
metrics = estimator.evaluate(input_fn=input_fn_eval)
predictions = estimator.predict(input_fn=input_fn_predict)

Input of train and evaluate should have following features, otherwise there will be a KeyError:

  • if weight_column is not None, a feature with key=weight_column whose value is a Tensor.
  • for each column in feature_columns:
    • if column is a SparseColumn, a feature with key=column.name whose value is a SparseTensor.
    • if column is a WeightedSparseColumn, two features: the first with key the id column name, the second with key the weight column name. Both features' value must be a SparseTensor.
    • if column is a RealValuedColumn, a feature with key=column.name whose value is a Tensor.

Loss is calculated by using softmax cross entropy.

model_fn Model function. Follows the signature:

  • features -- This is the first item returned from the input_fn passed to train, evaluate, and predict. This should be a single tf.Tensor or dict of same.
  • labels -- This is the second item returned from the input_fn passed to train, evaluate, and predict. This should be a single tf.Tensor or dict of same (for multi-head models). If mode is tf.estimator.ModeKeys.PREDICT, labels=None will be passed. If the model_fn's signature does not accept mode, the model_fn must still be able to handle labels=None.
  • mode -- Optional. Specifies if this is training, evaluation or prediction. See tf.estimator.ModeKeys. params -- Optional dict of hyperparameters. Will receive what is passed to Estimator in params parameter. This allows to configure Estimators from hyper parameter tuning.
  • config -- Optional estimator.RunConfig object. Will receive what is passed to Estimator as its config parameter, or a default value. Allows setting up things in your model_fn based on configuration such as num_ps_replicas, or model_dir.
  • Returns -- tf.estimator.EstimatorSpec
model_dir Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. If PathLike object, the path will be resolved. If None, the model_dir in config will be used if set. If both are set, they must be same. If both are None, a temporary directory will be used.
config estimator.RunConfig configuration object.
params dict of hyper parameters that will be passed into model_fn. Keys are names of parameters, values are basic python types.
warm_start_from Optional string filepath to a checkpoint or SavedModel to warm-start from, or a tf.estimator.WarmStartSettings object to fully configure warm-starting. If None, only TRAINABLE variables are warm-started. If the string filepath is provided instead of a tf.estimator.WarmStartSettings, then all variables are warm-started, and it is assumed that vocabularies and tf.Tensor names are unchanged.

ValueError parameters of model_fn don't match params.
ValueError if this is called via a subclass and if that class overrides a member of Estimator.

Eager Compatibility

Estimators can be used while eager execution is enabled. Note that input_fn and all hooks are executed inside a graph context, so they have to be written to be compatible with graph mode. Note that input_fn code using tf.data generally works in both graph and eager modes.

config

model_dir

model_fn Returns the model_fn which is bound to self.params.
params

Methods

eval_dir

View source

Shows the directory name where evaluation metrics are dumped.