diff --git a/jack/readers/extractive_qa/util.py b/jack/readers/extractive_qa/util.py index 46402d8c..d9c6e44c 100644 --- a/jack/readers/extractive_qa/util.py +++ b/jack/readers/extractive_qa/util.py @@ -109,19 +109,22 @@ def prepare_data(qa_setting: QASetting, max_answer = max(max_answer, end) # cut support whenever there is a maximum allowed length and recompute answer spans - if max_support_length is not None and len(support_tokens) > max_support_length > 0: + support_length = all_support_length[doc_idx] + if max_support_length is not None and support_length > max_support_length > 0: if max_answer < max_support_length: # Find new start and end in the flattened support new_start = 0 new_end = max_support_length else: offset = rng.randint(1, 11) - new_end = max_answer + offset - new_start = max(0, min(min_answer - offset, new_end - max_support_length)) - while new_end - new_start > max_support_length: - answer_spans = [(s, e) for s, e in answer_spans if e < (new_end - offset)] - new_end = max(answer_spans, key=lambda span: span[1])[1] + offset - new_start = max(0, min(min_answer - offset, new_end - max_support_length)) + new_end = max_answer + new_start = max(0, min(min_answer, new_end + 2 * offset - max_support_length)) + while new_end - new_start > max_support_length - 2 * offset: + answer_spans = [(s, e) for s, e in answer_spans if e < new_end] + new_end = max(answer_spans, key=lambda span: span[1])[1] + new_start = max(0, min(min_answer, new_end + 2 * offset - max_support_length)) + new_end = min(new_end + offset, support_length) + new_start = max(new_start - offset, 0) # Crop support according to new start and end pointers all_support_tokens[doc_idx] = support_tokens[new_start:new_end]