@@ -109,19 +109,22 @@ def prepare_data(qa_setting: QASetting,
109
109
max_answer = max (max_answer , end )
110
110
111
111
# cut support whenever there is a maximum allowed length and recompute answer spans
112
- if max_support_length is not None and len (support_tokens ) > max_support_length > 0 :
112
+ support_length = all_support_length [doc_idx ]
113
+ if max_support_length is not None and support_length > max_support_length > 0 :
113
114
if max_answer < max_support_length :
114
115
# Find new start and end in the flattened support
115
116
new_start = 0
116
117
new_end = max_support_length
117
118
else :
118
119
offset = rng .randint (1 , 11 )
119
- new_end = max_answer + offset
120
- new_start = max (0 , min (min_answer - offset , new_end - max_support_length ))
121
- while new_end - new_start > max_support_length :
122
- answer_spans = [(s , e ) for s , e in answer_spans if e < (new_end - offset )]
123
- new_end = max (answer_spans , key = lambda span : span [1 ])[1 ] + offset
124
- new_start = max (0 , min (min_answer - offset , new_end - max_support_length ))
120
+ new_end = max_answer
121
+ new_start = max (0 , min (min_answer , new_end + 2 * offset - max_support_length ))
122
+ while new_end - new_start > max_support_length - 2 * offset :
123
+ answer_spans = [(s , e ) for s , e in answer_spans if e < new_end ]
124
+ new_end = max (answer_spans , key = lambda span : span [1 ])[1 ]
125
+ new_start = max (0 , min (min_answer , new_end + 2 * offset - max_support_length ))
126
+ new_end = min (new_end + offset , support_length )
127
+ new_start = max (new_start - offset , 0 )
125
128
126
129
# Crop support according to new start and end pointers
127
130
all_support_tokens [doc_idx ] = support_tokens [new_start :new_end ]
0 commit comments