text.max_spanning_tree_gradient

Computes subgradients of the network loss w.r.t. the inputs of the MaximumSpanningTree op.

Note that MaximumSpanningTree is only differentiable w.r.t. its |scores| input and its |max_scores| output.
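The derivation below shows why a one-hot tree indicator is a valid subgradient. It is a sketch under assumed notation that the source does not define: S_b is the [M,M] |scores| matrix for batch entry b, T_b is the set of valid trees over its first num_nodes[b] nodes, and z^T_{t,s} is 1 exactly when tree T contains the arc from source s to target t. The term dL/dm_b corresponds to entry b of |d_loss_d_max_scores|.

```latex
% Assumed: max_scores[b] is the total score of the best tree for entry b.
m_b(S_b) = \max_{T \in \mathcal{T}_b} \sum_{t,s} z^{T}_{t,s}\, S_{b,t,s}

% m_b is a pointwise max of linear functions of S_b, hence convex.
% For any maximizing tree T_b^*, m_b(S') \ge m_b(S_b) + \langle z^{T_b^*}, S' - S_b \rangle,
% so its indicator is a subgradient:
z^{T_b^*} \in \partial m_b(S_b)

% Chaining with the upstream gradient gives the [B,M,M] |scores| subgradient:
g_{b,t,s} = \frac{\partial L}{\partial m_b}\, z^{T_b^*}_{t,s}
```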

Args:
  mst_op: The MaximumSpanningTree op being differentiated.
  d_loss_d_max_scores: [B] vector where entry b is the gradient of the network loss w.r.t. entry b of the |max_scores| output of |mst_op|.
  *_: The gradients w.r.t. the other outputs; ignored.

Returns:
  1. None, since the op is not differentiable w.r.t. its |num_nodes| input.
  2. A [B,M,M] tensor where entry b,t,s is a subgradient of the network loss w.r.t. entry b,t,s of the |scores| input, with the same dtype as |d_loss_d_max_scores|.
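The second return value can be illustrated with a minimal pure-Python sketch. It assumes the op also exposes, for each batch entry b and target node t, the index of the source node chosen in the maximal tree (called argmax_sources here), and that padded targets carry an out-of-range index such as -1. The helper name mst_scores_subgradient is hypothetical, and nested lists stand in for tensors; this is not the library's implementation. Each row is a one-hot indicator of the chosen arc, scaled by entry b of |d_loss_d_max_scores|. The None gradient for |num_nodes| is omitted.

```python
from typing import List, Sequence


def mst_scores_subgradient(
    argmax_sources: Sequence[Sequence[int]],
    d_loss_d_max_scores: Sequence[float],
) -> List[List[List[float]]]:
    """Hypothetical helper: builds a [B,M,M] subgradient w.r.t. |scores|.

    Assumes argmax_sources[b][t] is the source node s selected for target
    node t in batch entry b's maximal tree. An index outside [0, M), such
    as -1 for a padded node, produces an all-zero row.
    """
    if len(argmax_sources) != len(d_loss_d_max_scores):
        raise ValueError("batch sizes of argmax_sources and d_loss_d_max_scores differ")
    grads = []
    for sources, upstream in zip(argmax_sources, d_loss_d_max_scores):
        m = len(sources)  # M, the padded number of nodes
        batch_grad = []
        for s_star in sources:
            # One-hot row: only the chosen arc (t, s_star) receives the
            # upstream gradient for this batch entry; all other arcs get 0.
            batch_grad.append([upstream if s == s_star else 0.0 for s in range(m)])
        grads.append(batch_grad)
    return grads


# One batch entry with 3 padded nodes; the last node is padding (-1).
grad = mst_scores_subgradient([[0, 0, -1]], [2.0])
print(grad)  # [[[2.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
```

Representing the subgradient as a scaled one-hot indicator keeps the output shape equal to |scores|, so it can be added directly to any other gradient flowing into that input.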