
Commit ae5f2a8

Author: Valery Chernov
update sampler tests
1 parent a82a090 commit ae5f2a8

File tree

1 file changed: +23 -19 lines changed


serve/tests/unittest/test_sampler.py

Lines changed: 23 additions & 19 deletions
@@ -14,7 +14,9 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=None
     if past_output_tokens is None:
         past_output_tokens = [[] for _ in range(batch_size)]
     if prompt_masks is None:
-        prompt_masks = [[] for _ in range(batch_size)]
+        # Prepare empty prompt mask
+        prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+        prompt_masks = [prompt_mask] * batch_size
     _copy_stream: torch.cuda.Stream = torch.cuda.Stream()
     with torch.cuda.stream(_copy_stream):
         sampling_state = SamplingState.from_sampling_params(
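
The hunk above swaps the empty-list placeholder for an all-False boolean tensor of length vocab_size, which appears to be the shape SamplingState.from_sampling_params now expects for a per-request prompt mask. For reference, a minimal sketch of how a non-empty prompt mask could be derived from a request's prompt tokens is shown below; make_prompt_mask and prompt_token_ids are illustrative names only, not part of this commit.

import torch

def make_prompt_mask(prompt_token_ids, vocab_size):
    # Boolean vector over the vocabulary: True at ids that occur in the prompt.
    mask = torch.zeros((vocab_size,), dtype=torch.bool)
    mask[torch.tensor(prompt_token_ids, dtype=torch.long)] = True
    return mask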
@@ -29,7 +31,7 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=None
     return sampling_state
 
 
-def _test_temperature(temp=0, batch_size=1):
+def test_temperature(temp=0, batch_size=1):
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
     sampling_param = SamplingParams(temperature=temp)
@@ -41,7 +43,7 @@ def _test_temperature(temp=0, batch_size=1):
     assert torch.allclose(expected, new_logits)
 
 
-def _test_logit_bias_checker():
+def test_logit_bias_checker():
     # logit bias must be [-100, 100]
     with pytest.raises(ValueError):
         logit_bias = {1: 2, 3: 105, 2: 2}
@@ -78,7 +80,7 @@ def _test_logit_bias_checker():
     get_sampling_state([sampling_param])
 
 
-def _test_logit_bias():
+def test_logit_bias():
     # test single batch
     batch_size = 1
     shape = (batch_size, vocab_size)
@@ -112,7 +114,7 @@ def _test_logit_bias():
     assert torch.allclose(expected, new_logits)
 
 
-def _test_penalties_checker():
+def test_penalties_checker():
     get_sampling_state([SamplingParams(presence_penalty=-1.0)])
     get_sampling_state([SamplingParams(frequency_penalty=-1.0)])
     get_sampling_state([SamplingParams(repetition_penalty=0.7)])
@@ -143,15 +145,16 @@ def _test_penalties_checker():
         )
 
 
-def _test_penalties():
+def test_penalties():
     # TODO(vvchernov): Add test for repetition penalty
    batch_size = 1
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
     presence_penalties = [0.8]
     frequency_penalties = [0.3]
     past_output_tokens = [[2, 2, 2, 3]]
-    prompt_masks = [[False] * vocab_size] * batch_size
+    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+    prompt_masks = [prompt_mask] * batch_size
 
     def prepare_metadata(past_output_tokens):
         count_map = []
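
test_penalties builds a per-token count map and a presence mask from past_output_tokens and compares the sampler's output against an expected result. A hedged sketch of the presence/frequency penalty arithmetic such a test typically checks, assuming the common OpenAI-style definition (this commit does not change that logic), is:

import torch

def apply_penalties(logits, counts, mask, presence_penalty, frequency_penalty):
    # Each generated token is penalised once for appearing at all (presence)
    # and once per occurrence (frequency); counts and mask are vocab-sized.
    return (logits
            - presence_penalty * mask.to(logits.dtype)
            - frequency_penalty * counts.to(logits.dtype))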
@@ -202,7 +205,8 @@ def get_expected_result(
     presence_penalties = [0.8, 0.7, -0.8]
     frequency_penalties = [-0.3, 2.0, 1.2]
     past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]]
-    prompt_masks = [[False] * vocab_size] * batch_size
+    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+    prompt_masks = [prompt_mask] * batch_size
 
     count_map, mask = prepare_metadata(past_output_tokens)
     expected = get_expected_result(
@@ -225,7 +229,7 @@ def get_expected_result(
     assert torch.allclose(expected, new_logits)
 
 
-def _test_top_p_top_k_checker():
+def test_top_p_top_k_checker():
     get_sampling_state([SamplingParams(top_p=0.8)])
     get_sampling_state([SamplingParams(top_k=3)])
 
@@ -248,7 +252,7 @@ def _test_top_p_top_k_checker():
         get_sampling_state([SamplingParams(top_k=-2)])
 
 
-def _test_top_p_top_k():
+def test_top_p_top_k():
     def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
         """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
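
The nested get_expected_result helper provides a reference top-k / nucleus (top-p) filtering result for the test to compare against. A sketch of that standard filtering recipe for a single row of logits, following the widely used Hugging Face formulation (illustrative only, not the helper defined in this file), looks like:

import torch

def filter_top_p_top_k(logits, top_p=1.0, top_k=0, filter_value=-float("Inf")):
    logits = logits.clone()
    if top_k > 0:
        # Drop everything outside the k largest logits.
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits[logits < kth] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Mask tokens once cumulative probability exceeds top_p, but always
        # keep the single most likely token.
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits[sorted_idx[remove]] = filter_value
    return logits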
@@ -320,7 +324,7 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
     assert torch.allclose(expected, new_logits)
 
 
-def _test_mixture_of_requests():
+def test_mixture_of_requests():
     # Mixed greedy & top_p/top_ks
     batch_size = 6
     shape = (batch_size, vocab_size)
@@ -341,11 +345,11 @@ def _test_mixture_of_requests():
 
 
 if __name__ == "__main__":
-    _test_temperature()
-    _test_logit_bias_checker()
-    _test_logit_bias()
-    _test_penalties_checker()
-    _test_penalties()
-    _test_top_p_top_k_checker()
-    _test_top_p_top_k()
-    _test_mixture_of_requests()
+    test_temperature()
+    test_logit_bias_checker()
+    test_logit_bias()
+    test_penalties_checker()
+    test_penalties()
+    test_top_p_top_k_checker()
+    test_top_p_top_k()
+    test_mixture_of_requests()
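
Dropping the leading underscore lets pytest collect these functions automatically (it discovers test_*-named functions by default), so the suite can now be run either via python -m pytest serve/tests/unittest/test_sampler.py or directly through the __main__ block above.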
