Adds graph regularization to a tf.estimator.Estimator.
nsl.estimator.add_graph_regularization(
estimator, embedding_fn, optimizer_fn=None, graph_reg_config=None
)
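Conceptually, graph regularization adds a neighbor-similarity penalty to the base estimator's training loss, so that each sample's embedding is pulled toward the embeddings of its graph neighbors. The plain-Python sketch below is illustrative only and is not NSL's implementation. It assumes a squared-Euclidean distance, per-neighbor edge weights, averaging over the batch, and a scalar multiplier. The helper names squared_l2, graph_reg_loss, and regularized_loss are hypothetical, and the distance and reduction options NSL actually provides may differ.

```python
from typing import List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def graph_reg_loss(
    embeddings: List[List[float]],
    neighbor_embeddings: List[List[List[float]]],
    neighbor_weights: List[List[float]],
) -> float:
    """Weighted sample-to-neighbor distance, averaged over the batch (hypothetical).

    embeddings[i] is sample i's embedding; neighbor_embeddings[i][j] and
    neighbor_weights[i][j] describe its j-th neighbor and that edge's weight.
    """
    if not embeddings:
        return 0.0
    total = 0.0
    for emb, nbrs, weights in zip(embeddings, neighbor_embeddings, neighbor_weights):
        # Each sample contributes the weighted distance to each of its neighbors.
        total += sum(w * squared_l2(emb, nbr) for nbr, w in zip(nbrs, weights))
    return total / len(embeddings)


def regularized_loss(base_loss: float, reg_loss: float, multiplier: float) -> float:
    """Base supervised loss plus the graph-regularization term scaled by multiplier."""
    return base_loss + multiplier * reg_loss


# Two samples with one neighbor each.
reg = graph_reg_loss(
    embeddings=[[1.0, 0.0], [0.0, 1.0]],
    neighbor_embeddings=[[[1.0, 1.0]], [[0.0, 1.0]]],
    neighbor_weights=[[0.5], [1.0]],
)  # (0.5 * 1.0 + 1.0 * 0.0) / 2 == 0.25
total = regularized_loss(base_loss=1.0, reg_loss=reg, multiplier=0.1)
```

In this sketch, a multiplier of 0 recovers the base loss exactly, and larger values trade supervised fit for smoother embeddings across graph edges.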
Args | Description
--- | ---
estimator | An object of type tf.estimator.Estimator.
embedding_fn | A function that accepts the input layer (a dictionary mapping feature names to the corresponding batched tensor values) as its first argument and an instance of tf.estimator.ModeKeys as its second argument, indicating whether the mode is training, evaluation, or prediction. It returns the corresponding embeddings or logits to be used for graph regularization.
optimizer_fn | A function that accepts no arguments and returns an instance of tf.train.Optimizer.
graph_reg_config | An instance of nsl.configs.GraphRegConfig that specifies various hyperparameters for graph regularization.
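To make the embedding_fn contract concrete, here is a plain-Python sketch of how a function with that signature might be exercised: it is called once on a batch's own features and once per neighbor slot, and the resulting embeddings are paired up for a regularizer to compare. This is not NSL's implementation and uses no TensorFlow: batched tensors are nested lists, and the mode is a plain string standing in for a tf.estimator.ModeKeys value. The neighbor naming scheme (an "NL_nbr_<i>_" feature prefix and a "_weight" suffix for edge weights) is an assumption for illustration, and embedding_fn, gather_embedding_pairs, and its max_neighbors argument are hypothetical.

```python
from typing import Dict, List, Tuple

Batch = List[List[float]]  # one row of floats per example in the batch

# Assumed neighbor naming, e.g. "NL_nbr_0_x" and "NL_nbr_0_weight" (illustrative).
NBR_PREFIX = "NL_nbr_"
WEIGHT_SUFFIX = "_weight"

# Plain-string stand-ins mirroring tf.estimator.ModeKeys values.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def embedding_fn(features: Dict[str, Batch], mode: str) -> Batch:
    """Hypothetical embedding_fn: doubles feature 'x'; mode is accepted but unused."""
    return [[2.0 * v for v in row] for row in features["x"]]


def gather_embedding_pairs(
    features: Dict[str, Batch], mode: str, max_neighbors: int
) -> List[Tuple[Batch, Batch, List[float]]]:
    """Embeds the samples and up to max_neighbors neighbor slots with embedding_fn.

    Returns one (sample_embeddings, neighbor_embeddings, neighbor_weights)
    triple per neighbor slot present in `features`.
    """
    # The sample's own features are everything without the neighbor prefix.
    sample_features = {
        k: v for k, v in features.items() if not k.startswith(NBR_PREFIX)
    }
    sample_emb = embedding_fn(sample_features, mode)
    pairs = []
    for i in range(max_neighbors):
        slot = f"{NBR_PREFIX}{i}_"
        weight_key = f"{NBR_PREFIX}{i}{WEIGHT_SUFFIX}"
        # Strip the slot prefix so embedding_fn sees the same feature names.
        nbr_features = {
            k[len(slot):]: v
            for k, v in features.items()
            if k.startswith(slot) and k != weight_key
        }
        if not nbr_features:
            break  # fewer neighbors present than max_neighbors
        weights = [row[0] for row in features[weight_key]]
        pairs.append((sample_emb, embedding_fn(nbr_features, mode), weights))
    return pairs
```

Under these assumptions, a batch with keys "x", "NL_nbr_0_x", and "NL_nbr_0_weight" yields a single triple, because only neighbor slot 0 is present. Passing the same embedding_fn for samples and neighbors keeps both sets of embeddings in one space, which is what makes the distance between them meaningful.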