nsl.estimator.add_graph_regularization

Adds graph regularization to a tf.estimator.Estimator.

nsl.estimator.add_graph_regularization(
    estimator,
    embedding_fn,
    optimizer_fn=None,
    graph_reg_config=None
)

Args:

  • estimator: An object of type tf.estimator.Estimator.
  • embedding_fn: A function that takes two arguments: first, the input layer (a dictionary mapping feature names to their batched tensor values), and second, an instance of tf.estimator.ModeKeys indicating whether the mode is training, evaluation, or prediction. It returns the embeddings or logits to be used for graph regularization.
  • optimizer_fn: A function that accepts no arguments and returns an instance of tf.train.Optimizer.
  • graph_reg_config: An instance of nsl.configs.GraphRegConfig that specifies various hyperparameters for graph regularization.

Returns:

A modified tf.estimator.Estimator object with graph regularization incorporated into its loss.
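To make "graph regularization incorporated into its loss" more concrete, here is a minimal, framework-free sketch of the general idea, assuming a squared L2 distance between each sample's embedding and the embeddings of its graph neighbors, weighted by edge weight and averaged over the batch. It does not use TensorFlow or NSL and is not NSL's implementation; NSL's actual distance type and reduction are set through nsl.configs.GraphRegConfig and may differ. The helper names (squared_l2, graph_loss, regularized_loss) are hypothetical, and the parameters multiplier and max_neighbors are only named after the kind of hyperparameters GraphRegConfig carries. In the wrapped estimator, both the sample and neighbor embeddings would presumably come from embedding_fn; here they are passed in precomputed.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def graph_loss(
    embeddings: List[Vector],
    neighbors: List[List[Tuple[Vector, float]]],
    max_neighbors: int,
) -> float:
    """Hypothetical graph term: batch mean of weighted neighbor distances.

    neighbors[i] holds (neighbor_embedding, edge_weight) pairs for sample i;
    only the first `max_neighbors` pairs of each sample are used.
    """
    if len(embeddings) != len(neighbors):
        raise ValueError("need one neighbor list per sample")
    if not embeddings:
        return 0.0
    total = 0.0
    for emb, nbrs in zip(embeddings, neighbors):
        for nbr_emb, weight in nbrs[:max_neighbors]:
            # Pull the sample's embedding toward each neighbor's embedding.
            total += weight * squared_l2(emb, nbr_emb)
    return total / len(embeddings)


def regularized_loss(
    supervised_loss: float,
    embeddings: List[Vector],
    neighbors: List[List[Tuple[Vector, float]]],
    multiplier: float,
    max_neighbors: int,
) -> float:
    """Hypothetical total loss: supervised loss plus the scaled graph term."""
    return supervised_loss + multiplier * graph_loss(
        embeddings, neighbors, max_neighbors
    )


if __name__ == "__main__":
    emb = [[1.0, 0.0], [0.0, 1.0]]
    nbrs = [[([1.0, 1.0], 1.0)], [([0.0, 1.0], 0.5), ([2.0, 1.0], 1.0)]]
    # Graph term is (1.0 + 0.0) / 2 = 0.5, so the total is about 2.0 + 0.1 * 0.5.
    print(regularized_loss(2.0, emb, nbrs, multiplier=0.1, max_neighbors=1))
```

With max_neighbors=1 only each sample's first neighbor contributes, so the graph term is 0.5 and the total is approximately 2.05. Raising max_neighbors to 2 adds the distant neighbor of the second sample and increases the graph term to 2.5, which shows how the regularizer penalizes embeddings that sit far from their graph neighbors.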