@@ -14,7 +14,9 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=No
     if past_output_tokens is None:
         past_output_tokens = [[] for _ in range(batch_size)]
     if prompt_masks is None:
-        prompt_masks = [[] for _ in range(batch_size)]
+        # Prepare empty prompt mask
+        prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+        prompt_masks = [prompt_mask] * batch_size
     _copy_stream: torch.cuda.Stream = torch.cuda.Stream()
     with torch.cuda.stream(_copy_stream):
         sampling_state = SamplingState.from_sampling_params(
@@ -29,7 +31,7 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=No
     return sampling_state


-def _test_temperature(temp=0, batch_size=1):
+def test_temperature(temp=0, batch_size=1):
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
     sampling_param = SamplingParams(temperature=temp)
@@ -41,7 +43,7 @@ def _test_temperature(temp=0, batch_size=1):
     assert torch.allclose(expected, new_logits)


-def _test_logit_bias_checker():
+def test_logit_bias_checker():
     # logit bias must be [-100, 100]
     with pytest.raises(ValueError):
         logit_bias = {1: 2, 3: 105, 2: 2}
@@ -78,7 +80,7 @@ def _test_logit_bias_checker():
     get_sampling_state([sampling_param])


-def _test_logit_bias():
+def test_logit_bias():
     # test single batch
     batch_size = 1
     shape = (batch_size, vocab_size)
@@ -112,7 +114,7 @@ def _test_logit_bias():
     assert torch.allclose(expected, new_logits)


-def _test_penalties_checker():
+def test_penalties_checker():
     get_sampling_state([SamplingParams(presence_penalty=-1.0)])
     get_sampling_state([SamplingParams(frequency_penalty=-1.0)])
     get_sampling_state([SamplingParams(repetition_penalty=0.7)])
@@ -143,15 +145,16 @@ def _test_penalties_checker():
         )


-def _test_penalties():
+def test_penalties():
     # TODO(vvchernov): Add test for repetition penalty
     batch_size = 1
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
     presence_penalties = [0.8]
     frequency_penalties = [0.3]
     past_output_tokens = [[2, 2, 2, 3]]
-    prompt_masks = [[False] * vocab_size] * batch_size
+    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+    prompt_masks = [prompt_mask] * batch_size

     def prepare_metadata(past_output_tokens):
         count_map = []
@@ -202,7 +205,8 @@ def get_expected_result(
     presence_penalties = [0.8, 0.7, -0.8]
     frequency_penalties = [-0.3, 2.0, 1.2]
     past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]]
-    prompt_masks = [[False] * vocab_size] * batch_size
+    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+    prompt_masks = [prompt_mask] * batch_size

     count_map, mask = prepare_metadata(past_output_tokens)
     expected = get_expected_result(
@@ -225,7 +229,7 @@ def get_expected_result(
     assert torch.allclose(expected, new_logits)


-def _test_top_p_top_k_checker():
+def test_top_p_top_k_checker():
     get_sampling_state([SamplingParams(top_p=0.8)])
     get_sampling_state([SamplingParams(top_k=3)])

@@ -248,7 +252,7 @@ def _test_top_p_top_k_checker():
         get_sampling_state([SamplingParams(top_k=-2)])


-def _test_top_p_top_k():
+def test_top_p_top_k():
     def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
         """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
@@ -320,7 +324,7 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
     assert torch.allclose(expected, new_logits)


-def _test_mixture_of_requests():
+def test_mixture_of_requests():
     # Mixed greedy & top_p/top_ks
     batch_size = 6
     shape = (batch_size, vocab_size)
@@ -341,11 +345,11 @@ def _test_mixture_of_requests():


 if __name__ == "__main__":
-    _test_temperature()
-    _test_logit_bias_checker()
-    _test_logit_bias()
-    _test_penalties_checker()
-    _test_penalties()
-    _test_top_p_top_k_checker()
-    _test_top_p_top_k()
-    _test_mixture_of_requests()
+    test_temperature()
+    test_logit_bias_checker()
+    test_logit_bias()
+    test_penalties_checker()
+    test_penalties()
+    test_top_p_top_k_checker()
+    test_top_p_top_k()
+    test_mixture_of_requests()
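
A minimal usage sketch, assuming the SamplingParams and get_sampling_state helpers shown in the diff above and a CUDA-capable torch setup; the vocab_size value here is a hypothetical stand-in for the module-level constant. It only illustrates the shape the reworked prompt_masks argument now expects: one boolean tensor per request rather than a Python list of booleans.

    import torch

    vocab_size = 32000  # hypothetical; the test module defines its own constant

    # One boolean mask per request; True marks token ids that appeared in the prompt.
    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
    prompt_mask[[2, 3]] = True  # pretend token ids 2 and 3 occurred in the prompt

    sampling_state = get_sampling_state(
        [SamplingParams(presence_penalty=0.8, frequency_penalty=0.3)],
        past_output_tokens=[[2, 2, 2, 3]],
        prompt_masks=[prompt_mask],
    )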