Commit

update sampler tests
Valery Chernov committed Feb 16, 2024
1 parent a82a090 commit ae5f2a8
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions serve/tests/unittest/test_sampler.py
@@ -14,7 +14,9 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=No
     if past_output_tokens is None:
         past_output_tokens = [[] for _ in range(batch_size)]
     if prompt_masks is None:
-        prompt_masks = [[] for _ in range(batch_size)]
+        # Prepare empty prompt mask
+        prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+        prompt_masks = [prompt_mask] * batch_size
     _copy_stream: torch.cuda.Stream = torch.cuda.Stream()
     with torch.cuda.stream(_copy_stream):
         sampling_state = SamplingState.from_sampling_params(
@@ -29,7 +31,7 @@ def _test_temperature(temp=0, batch_size=1):
     return sampling_state


-def _test_temperature(temp=0, batch_size=1):
+def test_temperature(temp=0, batch_size=1):
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
     sampling_param = SamplingParams(temperature=temp)
@@ -41,7 +43,7 @@ def _test_temperature(temp=0, batch_size=1):
     assert torch.allclose(expected, new_logits)


-def _test_logit_bias_checker():
+def test_logit_bias_checker():
     # logit bias must be [-100, 100]
     with pytest.raises(ValueError):
         logit_bias = {1: 2, 3: 105, 2: 2}
@@ -78,7 +80,7 @@ def _test_logit_bias_checker():
         get_sampling_state([sampling_param])


-def _test_logit_bias():
+def test_logit_bias():
     # test single batch
     batch_size = 1
     shape = (batch_size, vocab_size)
@@ -112,7 +114,7 @@ def _test_logit_bias():
     assert torch.allclose(expected, new_logits)


-def _test_penalties_checker():
+def test_penalties_checker():
     get_sampling_state([SamplingParams(presence_penalty=-1.0)])
     get_sampling_state([SamplingParams(frequency_penalty=-1.0)])
     get_sampling_state([SamplingParams(repetition_penalty=0.7)])
@@ -143,15 +145,16 @@ def _test_penalties_checker():
         )


-def _test_penalties():
+def test_penalties():
     # TODO(vvchernov): Add test for repetition penalty
     batch_size = 1
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
     presence_penalties = [0.8]
     frequency_penalties = [0.3]
     past_output_tokens = [[2, 2, 2, 3]]
-    prompt_masks = [[False] * vocab_size] * batch_size
+    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+    prompt_masks = [prompt_mask] * batch_size

     def prepare_metadata(past_output_tokens):
         count_map = []
@@ -202,7 +205,8 @@ def get_expected_result(
     presence_penalties = [0.8, 0.7, -0.8]
     frequency_penalties = [-0.3, 2.0, 1.2]
     past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]]
-    prompt_masks = [[False] * vocab_size] * batch_size
+    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+    prompt_masks = [prompt_mask] * batch_size

     count_map, mask = prepare_metadata(past_output_tokens)
     expected = get_expected_result(
@@ -225,7 +229,7 @@ def get_expected_result(
     assert torch.allclose(expected, new_logits)


-def _test_top_p_top_k_checker():
+def test_top_p_top_k_checker():
     get_sampling_state([SamplingParams(top_p=0.8)])
     get_sampling_state([SamplingParams(top_k=3)])

@@ -248,7 +252,7 @@ def _test_top_p_top_k_checker():
         get_sampling_state([SamplingParams(top_k=-2)])


-def _test_top_p_top_k():
+def test_top_p_top_k():
     def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
         """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
            Args:
@@ -320,7 +324,7 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
     assert torch.allclose(expected, new_logits)


-def _test_mixture_of_requests():
+def test_mixture_of_requests():
     # Mixed greedy & top_p/top_ks
     batch_size = 6
     shape = (batch_size, vocab_size)
@@ -341,11 +345,11 @@ def _test_mixture_of_requests():


 if __name__ == "__main__":
-    _test_temperature()
-    _test_logit_bias_checker()
-    _test_logit_bias()
-    _test_penalties_checker()
-    _test_penalties()
-    _test_top_p_top_k_checker()
-    _test_top_p_top_k()
-    _test_mixture_of_requests()
+    test_temperature()
+    test_logit_bias_checker()
+    test_logit_bias()
+    test_penalties_checker()
+    test_penalties()
+    test_top_p_top_k_checker()
+    test_top_p_top_k()
+    test_mixture_of_requests()
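
The renamed test_penalties compares the sampler's output against an expected result built from token occurrence counts, and this commit switches prompt_masks from plain Python lists to one boolean tensor of shape (vocab_size,) per request. The following is a minimal, standalone sketch of that expected-result computation under the usual presence/frequency penalty formula. It is an illustration only, not the repository's implementation: vocab_size, dtype, dev, and the helper name expected_penalized_logits are assumed here (the real tests run on a CUDA device).

# Hypothetical sketch mirroring what test_penalties appears to check; not the repo's code.
import torch

vocab_size = 32000                 # assumed value; the test module defines its own
dtype, dev = torch.float32, "cpu"  # CPU only for illustration; the tests use CUDA


def expected_penalized_logits(logits, past_output_tokens, presence_penalties, frequency_penalties):
    """Subtract frequency_penalty * count + presence_penalty * (count > 0) for generated tokens."""
    batch_size = logits.shape[0]
    counts = torch.zeros((batch_size, vocab_size), dtype=dtype, device=logits.device)
    for i, tokens in enumerate(past_output_tokens):
        for tok in tokens:
            counts[i, tok] += 1            # how often each vocab id was generated so far
    appeared = (counts > 0).to(dtype)      # presence indicator per vocab id
    freq = torch.tensor(frequency_penalties, dtype=dtype, device=logits.device).unsqueeze(1)
    pres = torch.tensor(presence_penalties, dtype=dtype, device=logits.device).unsqueeze(1)
    return logits - counts * freq - appeared * pres


# Usage mirroring the single-batch case in the test, including the new per-request prompt mask.
logits = torch.rand((1, vocab_size), dtype=dtype, device=dev)
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)  # empty prompt mask, one per request
prompt_masks = [prompt_mask] * 1
expected = expected_penalized_logits(logits, [[2, 2, 2, 3]], presence_penalties=[0.8], frequency_penalties=[0.3])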

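test_top_p_top_k builds its expected result with a nested get_expected_result helper whose body is collapsed in this view. For reference, a generic top-k plus nucleus (top-p) filter in the commonly used style looks roughly like the sketch below; filter_top_p_top_k is a hypothetical name, and this is not necessarily how the test's helper is written.

# Generic top-k / top-p (nucleus) filtering sketch; see assumptions above.
import torch


def filter_top_p_top_k(logits, top_p=1.0, top_k=0, filter_value=-float("inf")):
    # Top-k: mask everything below the k-th largest logit in each row.
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    # Top-p: mask the tail once cumulative probability exceeds top_p.
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses the threshold
        remove[..., 0] = False                      # always keep the most likely token
        remove = remove.scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, filter_value)
    return logits


# Example: keep at most 3 tokens and roughly 0.9 probability mass per row.
filtered = filter_top_p_top_k(torch.rand(2, 10), top_p=0.9, top_k=3)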