Computes the alpha values in a linear-chain CRF.
tfa.text.crf_forward(
    inputs: tfa.image.color_ops.TensorLike,
    state: tfa.image.color_ops.TensorLike,
    transition_params: tfa.image.color_ops.TensorLike,
    sequence_lengths: tfa.image.color_ops.TensorLike
) -> tf.Tensor
See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
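Written out in log space, the forward recursion takes the form below. The symbols are our own notation, not the library's: x_t(j) is the unary potential of tag j at step t (taken from inputs), T_{ij} is the entry of transition_params, read here as the score of moving from tag i to tag j (an assumption), K is num_tags, and α_0 is the alpha vector passed in as state.

```latex
\alpha_t(j) = x_t(j) + \log \sum_{i=1}^{K} \exp\bigl(\alpha_{t-1}(i) + T_{ij}\bigr),
\qquad j = 1, \dots, K

% If \alpha_0 holds the unary potentials of the first position and \alpha_n
% is taken at the last valid position, the log partition function is
\log Z = \log \sum_{j=1}^{K} \exp\bigl(\alpha_n(j)\bigr)
```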
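| Args | Description |
|---|---|
| inputs | A [batch_size, num_steps, num_tags] tensor of unary potentials, where num_steps is the number of time steps that follow state. The recursion consumes one [batch_size, num_tags] slice per step. |
| state | A [batch_size, num_tags] matrix containing the previous alpha values, i.e. the alphas at the step before the first step of inputs. |
| transition_params | A [num_tags, num_tags] matrix of binary potentials. It is expanded into a [1, num_tags, num_tags] tensor in preparation for the broadcast summation performed at each step of the recursion. |
| sequence_lengths | A [batch_size] vector of true sequence lengths, used to select the alphas at the last valid time step of each sequence. |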
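The transition_params entry above mentions a broadcast summation. The pure-Python sketch below writes one recursion step with explicit loops instead of tensors, to show what that summation and the subsequent log-sum-exp reduction compute. The helper names `forward_step` and `logsumexp` are hypothetical, and reading `transition_params[i][j]` as the score of moving from tag i to tag j is an assumption.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_step(state, step_inputs, transition_params):
    """Hypothetical helper: one step of the log-space forward recursion.

    state:             [batch_size][num_tags] previous alphas
    step_inputs:       [batch_size][num_tags] unary potentials for this step
    transition_params: [num_tags][num_tags]; entry [i][j] scores tag i -> tag j (assumed)
    """
    num_tags = len(transition_params)
    new_alphas = []
    for prev, unary in zip(state, step_inputs):
        # prev[i] + transition_params[i][j] is the loop form of adding state
        # expanded to [batch_size, num_tags, 1] to transition_params expanded
        # to [1, num_tags, num_tags]; logsumexp then reduces over i (axis 1).
        new_alphas.append([
            unary[j] + logsumexp([prev[i] + transition_params[i][j] for i in range(num_tags)])
            for j in range(num_tags)
        ])
    return new_alphas


# A transition matrix that forbids switching tags simply adds each unary
# potential to the alpha of the same tag.
stay = [[0.0, -math.inf], [-math.inf, 0.0]]
print(forward_step([[1.0, 2.0]], [[0.5, 0.5]], stay))  # → [[1.5, 2.5]]
```

With all-zero transitions and unaries, every new alpha becomes log(num_tags) plus the shared previous value, which is a quick sanity check on the reduction axis.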
| Returns | Description |
|---|---|
| new_alphas | A [batch_size, num_tags] matrix containing the new alpha values: for each sequence, the alphas at its last valid time step. |
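To check the returned alphas end to end, the sketch below runs the recursion over a small padded batch and compares the log-sum-exp of each result with a brute-force log partition function over all tag paths. It is a minimal pure-Python sketch, not the library's implementation. Assumptions: inputs are laid out as [batch][step][tag], state holds the alphas at index 0 (here, the first position's unary potentials), `transition_params[i][j]` scores tag i -> tag j, and the result for each sequence is taken at index max(0, length - 1) of the alpha history. The names `crf_forward_sketch`, `logsumexp`, and `brute_force_log_norm` are hypothetical.

```python
import itertools
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def crf_forward_sketch(inputs, state, transition_params, sequence_lengths):
    """Hypothetical sketch of a batched log-space forward recursion.

    inputs:            [batch_size][num_steps][num_tags] unary potentials (assumed layout)
    state:             [batch_size][num_tags] alphas at index 0 (before inputs)
    transition_params: [num_tags][num_tags]; entry [i][j] scores tag i -> tag j
    sequence_lengths:  [batch_size] true lengths, counting index 0
    Returns the [batch_size][num_tags] alphas at index max(0, length - 1).
    """
    num_tags = len(transition_params)
    result = []
    for seq_inputs, alphas, length in zip(inputs, state, sequence_lengths):
        history = [alphas]  # history[k] holds the alphas at index k
        for unary in seq_inputs:
            alphas = [
                unary[j]
                + logsumexp([alphas[i] + transition_params[i][j] for i in range(num_tags)])
                for j in range(num_tags)
            ]
            history.append(alphas)
        # Padded steps beyond the true length are computed but never selected.
        result.append(history[max(0, length - 1)])
    return result


def brute_force_log_norm(unaries, transition_params):
    """Log partition function by enumerating every tag path (tiny inputs only)."""
    num_tags = len(transition_params)
    scores = []
    for path in itertools.product(range(num_tags), repeat=len(unaries)):
        score = sum(unaries[t][tag] for t, tag in enumerate(path))
        score += sum(transition_params[a][b] for a, b in zip(path, path[1:]))
        scores.append(score)
    return logsumexp(scores)


# Usage: two sequences padded to three positions; the second has true length 2.
T = [[0.2, -0.3], [0.1, 0.4]]
seq_a = [[1.0, 0.0], [0.5, 2.0], [0.0, 1.0]]
seq_b = [[0.3, 0.7], [1.0, -1.0], [0.0, 0.0]]  # last row is padding
alphas = crf_forward_sketch(
    inputs=[seq_a[1:], seq_b[1:]],
    state=[seq_a[0], seq_b[0]],
    transition_params=T,
    sequence_lengths=[3, 2],
)
log_norms = [logsumexp(row) for row in alphas]
```

The sketch runs every sequence over all padded steps and uses the length only to pick which alphas to return. That keeps the per-step update identical across the batch, which is what allows a vectorized implementation to process the whole batch at once.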