Skip to content

Commit c762855

Browse files
fix support cutting for XQA data preparation
1 parent 5434c1d commit c762855

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

jack/readers/extractive_qa/util.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,19 +109,22 @@ def prepare_data(qa_setting: QASetting,
109109
max_answer = max(max_answer, end)
110110

111111
# cut support whenever there is a maximum allowed length and recompute answer spans
112-
if max_support_length is not None and len(support_tokens) > max_support_length > 0:
112+
support_length = all_support_length[doc_idx]
113+
if max_support_length is not None and support_length > max_support_length > 0:
113114
if max_answer < max_support_length:
114115
# Find new start and end in the flattened support
115116
new_start = 0
116117
new_end = max_support_length
117118
else:
118119
offset = rng.randint(1, 11)
119-
new_end = max_answer + offset
120-
new_start = max(0, min(min_answer - offset, new_end - max_support_length))
121-
while new_end - new_start > max_support_length:
122-
answer_spans = [(s, e) for s, e in answer_spans if e < (new_end - offset)]
123-
new_end = max(answer_spans, key=lambda span: span[1])[1] + offset
124-
new_start = max(0, min(min_answer - offset, new_end - max_support_length))
120+
new_end = max_answer
121+
new_start = max(0, min(min_answer, new_end + 2 * offset - max_support_length))
122+
while new_end - new_start > max_support_length - 2 * offset:
123+
answer_spans = [(s, e) for s, e in answer_spans if e < new_end]
124+
new_end = max(answer_spans, key=lambda span: span[1])[1]
125+
new_start = max(0, min(min_answer, new_end + 2 * offset - max_support_length))
126+
new_end = min(new_end + offset, support_length)
127+
new_start = max(new_start - offset, 0)
125128

126129
# Crop support according to new start and end pointers
127130
all_support_tokens[doc_idx] = support_tokens[new_start:new_end]

0 commit comments

Comments
 (0)