Finds the maximum directed spanning tree of a digraph.
text.max_spanning_tree(
num_nodes, scores, forest=False, name=None
)
Given a batch of directed graphs with scored arcs and root selections, solves
for the maximum spanning tree of each digraph, where the score of a tree is
defined as the sum of the scores of the arcs and roots making up the tree.
Returns the score of the maximum spanning tree of each digraph, as well as the
arcs and roots in that tree. Each digraph in a batch may contain a different
number of nodes, so the sizes of the digraphs must be provided as an input.
Note that this operation is only differentiable w.r.t. its |scores| input and
its |max_scores| output.
The code here is intended for NLP applications, but attempts to remain
agnostic to particular NLP tasks (such as dependency parsing).
Args |
num_nodes
|
A Tensor of type int32 .
[B] vector where entry b is number of nodes in the b'th digraph.
|
scores
|
A Tensor . Must be one of the following types: int32 , float32 , float64 .
[B,M,M] tensor where entry b,t,s is the score of the arc from node s to
node t in the b'th directed graph if s!=t, or the score of selecting
node t as a root in the b'th digraph if s==t. This uniform tenosor
requires that M is >= num_nodes[b] for all b (ie. all graphs in the
batch), and ignores entries b,s,t where s or t is >= num_nodes[b].
Arcs or root selections with non-finite score are treated as
nonexistent.
|
forest
|
An optional bool . Defaults to False .
If true, solves for a maximum spanning forest instead of a maximum
spanning tree, where a spanning forest is a set of disjoint trees that
span the nodes of the digraph.
|
name
|
A name for the operation (optional).
|
Returns |
A tuple of Tensor objects (max_scores, argmax_sources).
|
max_scores
|
A Tensor . Has the same type as scores . [B] vector where entry b is the score of the maximum spanning tree
of the b'th digraph.
|
argmax_sources
|
A Tensor of type int32 . [B,M] matrix where entry b,t is the source of the arc inbound to
t in the maximum spanning tree of the b'th digraph, or t if t is
a root. Entries b,t where t is >= num_nodes[b] are set to -1.
Quickly finding the roots can be done as:
tf.equal(tf.map_fn(lambda x: tf.range(tf.size(x)),
argmax_sources), argmax_sources)
|