tf.contrib.estimator.boosted_trees_regressor_train_in_memory

tf.contrib.estimator.boosted_trees_regressor_train_in_memory(
    train_input_fn,
    feature_columns,
    model_dir=None,
    label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT,
    weight_column=None,
    n_trees=100,
    max_depth=6,
    learning_rate=0.1,
    l1_regularization=0.0,
    l2_regularization=0.0,
    tree_complexity=0.0,
    min_node_weight=0.0,
    config=None,
    train_hooks=None,
    center_bias=False,
    pruning_mode='none'
)

Defined in tensorflow/contrib/estimator/python/estimator/boosted_trees.py.

Trains a boosted tree regressor with an in-memory dataset.

Example:

bucketized_feature_1 = bucketized_column(
  numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
bucketized_feature_2 = bucketized_column(
  numeric_column('feature_2'), BUCKET_BOUNDARIES_2)

def train_input_fn():
  dataset = ...  # Placeholder: create the dataset from the training data.
  # This is a tf.data.Dataset of (feature dict, label) tuples,
  #   e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
  #                     Dataset.from_tensors(label_array)))
  # The returned Dataset shouldn't be batched.
  # If the Dataset repeats, only the first repetition is used for training.
  return dataset

regressor = boosted_trees_regressor_train_in_memory(
    train_input_fn,
    feature_columns=[bucketized_feature_1, bucketized_feature_2],
    n_trees=100,
    # ... <some other params>
)

def input_fn_eval():
  ...
  return dataset

metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10)

Args:

  • train_input_fn: an input function that returns a dataset containing a single epoch of unbatched features and labels.
  • feature_columns: An iterable containing all the feature columns used by the model. All items in the set should be instances of classes derived from FeatureColumn.
  • model_dir: Directory to save model parameters, the graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
  • label_dimension: Number of regression targets per example. Multi-dimensional support is not yet implemented.
  • weight_column: A string or a _NumericColumn created by tf.feature_column.numeric_column defining the feature column that represents weights. It is used to downweight or boost examples during training, and is multiplied by the loss of the example. If it is a string, it is used as a key to fetch the weight tensor from the features. If it is a _NumericColumn, the raw tensor is fetched by key weight_column.key, then weight_column.normalizer_fn is applied to it to get the weight tensor.
  • n_trees: number of trees to be created.
  • max_depth: maximum depth of the tree to grow.
  • learning_rate: shrinkage parameter to be used when a tree is added to the model.
  • l1_regularization: regularization multiplier applied to the absolute weights of the tree leaves.
  • l2_regularization: regularization multiplier applied to the squared weights of the tree leaves.
  • tree_complexity: regularization factor to penalize trees with more leaves.
  • min_node_weight: minimum hessian a node must have for a split to be considered. The value will be compared with sum(leaf_hessian) / (batch_size * n_batches_per_layer).
  • config: RunConfig object to configure the runtime settings.
  • train_hooks: a list of Hook instances to be passed to estimator.train().
  • center_bias: Whether bias centering needs to occur. Bias centering refers to the first node in the very first tree returning a prediction that is aligned with the original label distribution. For example, for regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1.
  • pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-pruning (do not split a node if not enough gain is observed) and post-pruning (build the tree up to a max depth and then prune branches with negative gain). For pre- and post-pruning, you MUST provide tree_complexity > 0.
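
To make the center_bias description above concrete, here is a minimal pure-Python sketch of the value the first node would return in each case: the label mean for regression, and the logit of the prior probability of label 1 for binary classification. The helper names initial_bias_regression and initial_bias_binary are hypothetical and are not part of TensorFlow; the sketch only illustrates the arithmetic, not how the library computes it internally.

```python
import math


def initial_bias_regression(labels):
    """Hypothetical helper: the mean of the labels (regression case)."""
    return sum(labels) / len(labels)


def initial_bias_binary(labels):
    """Hypothetical helper: logit of the prior probability of label 1.

    Assumes labels are 0/1 and that both classes are present, so the
    prior is strictly between 0 and 1.
    """
    prior = sum(labels) / len(labels)
    return math.log(prior / (1.0 - prior))


regression_bias = initial_bias_regression([1.0, 2.0, 3.0])
binary_bias = initial_bias_binary([1, 0, 0, 0])
```

With center_bias=False (the default in the signature above), no such initial prediction is fitted.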

Returns:

a boosted trees regressor estimator instance created with the given arguments and trained with the data loaded into memory from train_input_fn.

Raises:

  • ValueError: when wrong arguments are given or unsupported functionalities are requested.
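
As an illustration of the kind of argument check that can lead to a ValueError, the sketch below validates the pruning_mode and tree_complexity combination described in Args. The function check_pruning_args and its error messages are hypothetical and written for this example; they are not the library's own validation code, which may check more conditions or word its errors differently.

```python
VALID_PRUNING_MODES = ('none', 'pre', 'post')


def check_pruning_args(pruning_mode='none', tree_complexity=0.0):
    """Hypothetical pre-flight check mirroring the documented constraints."""
    if pruning_mode not in VALID_PRUNING_MODES:
        raise ValueError(
            "pruning_mode must be one of %s, got %r"
            % (VALID_PRUNING_MODES, pruning_mode))
    # Pre- and post-pruning both require a positive tree_complexity.
    if pruning_mode != 'none' and tree_complexity <= 0:
        raise ValueError(
            "tree_complexity must be > 0 when pruning_mode is %r"
            % pruning_mode)
    return pruning_mode, tree_complexity
```

Running such a check before calling the training function can surface a bad argument combination early, without first building the input pipeline.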