nsl.estimator.add_graph_regularization

Adds graph regularization to a tf.estimator.Estimator.

Args

estimator: An object of type tf.estimator.Estimator.
embedding_fn: A function that returns the embeddings or logits to be used for graph regularization. Its first argument is the input layer (a dictionary mapping feature names to the corresponding batched tensor values). Its second argument is an instance of tf.estimator.ModeKeys that indicates whether the mode is training, evaluation, or prediction. It may also take an optional third argument named params, a dict analogous to the params argument of a tf.estimator.Estimator's model_fn. If declared, params receives whatever was passed as the params argument when estimator was created.
optimizer_fn: A function that accepts no arguments and returns an instance of tf.train.Optimizer.
graph_reg_config: An instance of nsl.configs.GraphRegConfig that specifies the hyperparameters for graph regularization.
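The contracts of the two callable arguments can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not NSL code. TensorFlow types are replaced by stand-ins (lists for batched tensors, a hypothetical `StubOptimizer` class for `tf.train.Optimizer`). `call_embedding_fn` is also hypothetical: it shows one plausible way a caller could pass `params` only to an `embedding_fn` that declares it, and it is not a claim about how `add_graph_regularization` does this internally. With TensorFlow installed, a factory such as `lambda: tf.compat.v1.train.AdagradOptimizer(learning_rate=0.05)` would fit the optimizer_fn contract.

```python
import inspect

# Stand-ins for tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT, whose values are these strings.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def embedding_fn(features, mode):
    """Two-argument form: returns one embedding vector per example in the batch."""
    # `mode` would matter for layers such as dropout; this toy function ignores it.
    return [[x, x * x] for x in features["x"]]


def embedding_fn_with_params(features, mode, params):
    """Three-argument form: reads a hyperparameter from `params`."""
    return [[params["scale"] * x] for x in features["x"]]


def call_embedding_fn(fn, features, mode, params):
    """Hypothetical dispatcher: passes `params` only if `fn` declares a parameter with that name."""
    if "params" in inspect.signature(fn).parameters:
        return fn(features, mode, params=params)
    return fn(features, mode)


class StubOptimizer:
    """Hypothetical stand-in for a tf.train.Optimizer so the sketch runs without TensorFlow."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate


def optimizer_fn():
    """Zero-argument factory: every call builds a fresh optimizer."""
    return StubOptimizer(learning_rate=0.05)
```

For a batch `{"x": [1.0, 2.0]}`, the two-argument form yields `[[1.0, 1.0], [2.0, 4.0]]`, and the three-argument form with `params={"scale": 3.0}` yields `[[3.0], [6.0]]`. Asking for a factory rather than a ready-made optimizer presumably lets the wrapped model function construct the optimizer inside whichever graph the Estimator is building.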

Returns

A modified tf.estimator.Estimator object with graph regularization incorporated into its loss.
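Conceptually, the added term penalizes the distance between each example's embedding (the output of embedding_fn) and the embeddings of its graph neighbors. Each distance is weighted by the corresponding edge weight, and the total is scaled by a multiplier before being added to the estimator's original loss. The sketch below illustrates that idea in plain Python. It is an illustration under assumptions, not NSL's implementation: `graph_regularized_loss` is a hypothetical name, the squared Euclidean distance and the averaging over edges are choices made here, and `multiplier` stands in for a hyperparameter of the kind graph_reg_config supplies.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def graph_regularized_loss(supervised_loss, sample_embeddings, neighbor_embeddings,
                           neighbor_weights, multiplier):
    """Hypothetical sketch: supervised loss plus a weighted neighbor-distance penalty.

    sample_embeddings:   list of B vectors, one per example in the batch.
    neighbor_embeddings: list of B lists, each holding that example's neighbor vectors.
    neighbor_weights:    list of B lists of edge weights, parallel to neighbor_embeddings.
    """
    total, count = 0.0, 0
    for sample, neighbors, weights in zip(sample_embeddings, neighbor_embeddings,
                                          neighbor_weights):
        for neighbor, weight in zip(neighbors, weights):
            total += weight * squared_l2(sample, neighbor)
            count += 1
    # Average over all (example, neighbor) edges; with no neighbors the penalty is zero.
    graph_loss = total / count if count else 0.0
    return supervised_loss + multiplier * graph_loss
```

A batch whose examples have no neighbors leaves the supervised loss unchanged. In that case, the only effect of the wrapper is whatever optimizer optimizer_fn supplies.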