diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 4ddcc1dcac..c1bde04e0d 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -309,7 +309,7 @@ def crf_forward( See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. Args: - inputs: A [batch_size, num_tags] matrix of unary potentials. + inputs: A [batch_size, max_seq_len, num_tags] matrix of unary potentials. state: A [batch_size, num_tags] matrix containing the previous alpha values. transition_params: A [num_tags, num_tags] matrix of binary potentials.