Extend sampler tests #200
Conversation
LGTM
Hello @sunggg! Could you check the testing of the sampler?
Thank you @Ailurus1 for the great contribution!
Overall, it looks great; I have a couple of suggestions.
As for the comments about the more complicated tests, I don't think we have to address them in this PR, so we can merge it before the migration to the new repo, but I'd like us to follow up on them there.
serve/tests/unittest/test_sampler.py
Outdated
def test_logit_bias_checker():
    # logit bias values must be [-100, 100]
    get_sampling_state([SamplingParams(logit_bias={1: 100, 3: -100, 2: 2})])
    get_sampling_state([SamplingParams(logit_bias={34: 0, 23: -0.5})])
    # TODO(@team): it seems like the valid range is [1, vocab_size]. Double check.
Can we double-check this and remove the TODO line if it turns out to be true?
I didn't find any descriptive API reference for the indices in logit_bias, but I found that in transformers the similar parameter sequence_bias must be >= 0. However, its error message says the value has to be positive, so whether 0 is allowed is not entirely clear. I assume 0 can be used as well, since these are token indices and there is a token (<unk>) with id 0, so I changed the range to [0, vocab_size) in the last commits.
I see. Can we follow up on this after this PR? This could be a rough edge, so I'd like to address it before we run into issues with customers. Maybe we can do a deep dive into transformers or TGI and match our behavior to theirs.
serve/tests/unittest/test_sampler.py
Outdated
temperature = temperatures[i]
if temperature < SAMPLING_EPS:
    temperature = 1.0
rep_pen = torch.where(
Since this is basically the same computation as in adjust_logits, can we implement another approach that does the same thing, to cross-check?
I tried to do it in a somewhat more naive way in the last commits.
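For illustration, the cross-check pattern could look roughly like this (not the repo's code; the penalty convention of dividing positive logits and multiplying negative ones, and the temperature handling, are assumptions for the example):

```python
import torch

SAMPLING_EPS = 1e-5

# Naive, element-by-element reference for the same adjustment, used only to
# cross-check a vectorized torch.where implementation.
def naive_reference(logits: torch.Tensor, rep_pen: float, temperature: float) -> torch.Tensor:
    out = logits.clone()
    for i in range(out.numel()):
        out[i] = out[i] / rep_pen if out[i] > 0 else out[i] * rep_pen
    if temperature < SAMPLING_EPS:
        temperature = 1.0
    return out / temperature

# Vectorized counterpart of the same computation.
def vectorized(logits: torch.Tensor, rep_pen: float, temperature: float) -> torch.Tensor:
    out = torch.where(logits > 0, logits / rep_pen, logits * rep_pen)
    if temperature < SAMPLING_EPS:
        temperature = 1.0
    return out / temperature

logits = torch.randn(32)
assert torch.allclose(naive_reference(logits, 1.3, 0.8), vectorized(logits, 1.3, 0.8))
```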
    ),
    batch_size
):
    sampling_params = [SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks]
ditto
It seems I don't see a test with temperature=0. Let's add this in the next PR.
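For the record, a sketch of what such a check could verify (a standalone illustration, not the repo's SamplingParams API): with temperature=0, sampling should collapse to greedy argmax.

```python
import torch

# Standalone illustration of the temperature=0 expectation: sampling collapses
# to greedy argmax. The real test would go through SamplingParams / get_sampling_state.
def sample(logits: torch.Tensor, temperature: float) -> int:
    if temperature == 0.0:
        return int(torch.argmax(logits).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())

def test_temperature_zero_is_greedy():
    logits = torch.tensor([0.1, 2.5, -1.0, 0.3])
    assert sample(logits, temperature=0.0) == 1
```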
assert isinstance(output.logprob_infos[idx].current_logprob, float)
assert output.logprob_infos[idx].top_token_ids.nelement() == 0
assert output.logprob_infos[idx].top_logprobs.nelement() == 0
Ditto. We also need to test more complicated scenarios, such as:
(1) Requests in the same batch ask for different top-k values, and some of them do not request logprobs.
(2) Logprobs can be used together with other sampling params.
I think this may take some time, so I would like us to follow up on this after this PR.
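A rough sketch of scenario (1), just to pin down the intent (the helper below is a stand-in, not the repo's sampler; the real test would go through SamplingParams):

```python
import pytest
import torch

# Stand-in sampler: per request, return the top-k logprobs when requested,
# or an empty tensor when logprobs are not requested.
def fake_sample_with_logprobs(logits: torch.Tensor, top_logprobs: list):
    logprobs = torch.log_softmax(logits, dim=-1)
    return [
        torch.topk(row, k).values if k > 0 else torch.empty(0)
        for row, k in zip(logprobs, top_logprobs)
    ]

@pytest.mark.parametrize("top_logprobs", [[5, 0, 1], [0, 0], [3]])
def test_mixed_logprob_requests_sketch(top_logprobs):
    logits = torch.randn(len(top_logprobs), 128)
    results = fake_sample_with_logprobs(logits, top_logprobs)
    for k, res in zip(top_logprobs, results):
        assert res.nelement() == k  # requests without logprobs get empty tensors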
    and make sure that _apply_top_p_top_k from sampler.py does not produce too many -inf values
    """)
@pytest.mark.parametrize("batch_size", [1, 4, 8, 12])
def test_mixture_of_requests(batch_size: int):
We should test this more exhaustively. Let's follow up in the next PR.
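One concrete property a follow-up could assert (with an illustrative filter below, not the repo's _apply_top_p_top_k): after top-k filtering, exactly k entries per row should remain finite, so -inf values cannot dominate the distribution.

```python
import torch

# Illustrative top-k filter: mask everything below the k-th largest logit with
# -inf, so exactly k entries per row stay finite (assuming no ties at the
# threshold, which is almost certain with random floats).
def apply_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    return torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)

logits = torch.randn(4, 1000)
filtered = apply_top_k(logits, k=50)
assert torch.isfinite(filtered).sum(dim=-1).eq(50).all()
```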
Hello @sunggg
LGTM, thanks for making the sampler more robust!
Like we discussed, let's address the remaining feedback in the new repo.
@@ -422,6 +547,8 @@ def _test_json_mode(
    _test_stop(sync_engine)
    _test_logprobs(sync_engine)
    _test_logprobs_mixed_requests(sync_engine)
    _test_num_sequences(sync_engine)
    _test_logit_bias(sync_engine)
This test seems to always fail when I run all the tests in test_engine_with_samplers.py, but if I comment out the other tests and run only this one, it works, so there seems to be some strange interaction between the tests. @Ailurus1 Can you take a look and send a fix to the ollm repo?
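For what it's worth, this symptom (a test that fails in the full run but passes in isolation) usually points at shared state between tests. A minimal, hypothetical reproduction of that pattern, not a diagnosis of the actual cause here:

```python
# Hypothetical reproduction of an order-dependent failure: shared module state
# left behind by one test breaks the next one. Illustration only, not the actual
# cause in test_engine_with_samplers.py.
_shared_cache = {}

def test_populates_cache():
    _shared_cache["engine"] = "started"

def test_expects_clean_cache():
    # Passes when run alone, fails when run after test_populates_cache.
    assert "engine" not in _shared_cache
```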
No description provided.