Extend sampler tests #200

Merged · 34 commits merged into octoml:batch-serving on Feb 22, 2024
Conversation

@Ailurus1 commented Feb 9, 2024

No description provided.

serve/mlc_serve/engine/sampling_params.py (review thread resolved, outdated)
serve/tests/unittest/test_sampler.py (review thread resolved)
@Ailurus1 changed the title from "[WIP] Extend sampler tests" to "Extend sampler tests" on Feb 19, 2024
@Ailurus1 marked this pull request as ready for review on February 19, 2024 at 11:33
@vvchernov left a comment

LGTM

@vvchernov commented

Hello @sunggg! Could you check the testing of the sampler?

@sunggg (Member) left a comment

Thank you @Ailurus1 for the great contribution!
Overall it looks great; I have a couple of suggestions.
As for the comments about more complicated tests, I don't think we have to address them in this PR, so that it can be merged before the migration to the new repo, but I'd like us to follow up on them there.

```python
def test_logit_bias_checker():
    # logit bias values must be [-100, 100]
    get_sampling_state([SamplingParams(logit_bias={1: 100, 3: -100, 2: 2})])
    get_sampling_state([SamplingParams(logit_bias={34: 0, 23: -0.5})])
    # TODO(@team): it seems like the valid range is [1, vocab_size]. Double check.
```
@sunggg (Member) commented

Can we double-check this and remove this line if it's true?

@Ailurus1 (Author) commented

I didn't find any descriptive API reference for the indices in logit_bias, but I found that in transformers the similar parameter sequence_bias must be >= 0; however, when it isn't, the error message says it has to be positive, so whether 0 is allowed is not quite clear there. I think 0 can be used as well, since these are all token indices and there is a token (<unk>) with id 0, so I changed the range to [0, vocab_size) in the latest commits.
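
For reference, a minimal sketch of how the [0, vocab_size) index check could be exercised in the tests. The ValueError behavior, the vocabulary size, and the import paths are assumptions here, not something confirmed in this thread:

```python
import pytest

# Assumed imports, mirroring serve/tests/unittest/test_sampler.py (paths not verified):
# from mlc_serve.engine import SamplingParams
# from test_sampler import get_sampling_state

VOCAB_SIZE = 32000  # hypothetical vocabulary size used only for this sketch


def test_logit_bias_index_range():
    # Indices in [0, vocab_size) should be accepted, including token id 0 (<unk>).
    get_sampling_state([SamplingParams(logit_bias={0: 1.0, VOCAB_SIZE - 1: -1.0})])

    # Indices outside the range are assumed to raise ValueError; adjust if the
    # checker reports the error differently.
    with pytest.raises(ValueError):
        get_sampling_state([SamplingParams(logit_bias={-1: 1.0})])
    with pytest.raises(ValueError):
        get_sampling_state([SamplingParams(logit_bias={VOCAB_SIZE: 1.0})])
```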

@sunggg (Member) commented

I see. Can we follow up on this after this PR? I think this could be a rough edge, so I would like to address it before we run into issues with customers. Maybe we can do a deep dive into transformers or TGI to match our behavior.

serve/tests/unittest/test_sampler.py (review thread resolved)
serve/tests/unittest/test_sampler.py (review thread resolved)
```python
temperature = temperatures[i]
if temperature < SAMPLING_EPS:
    temperature = 1.0
rep_pen = torch.where(
```
@sunggg (Member) commented

Since this is basically the same computation as in adjust_logits, can we implement another approach that does the same thing, to cross-check?

@Ailurus1 (Author) commented

I tried to do it in a somewhat more naive way in the latest commits.
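
For context, a loop-based reference of the kind discussed above might look roughly like this. It is only a sketch: the repetition-penalty convention (divide positive logits, multiply negative ones), the constant value, and the argument shapes are assumptions and should be matched to whatever adjust_logits in sampler.py actually does.

```python
import torch

SAMPLING_EPS = 1e-5  # assumed to match the constant used in the sampler


def naive_adjust_logits(logits, temperatures, repetition_penalties, appeared_masks):
    """Per-request, per-token loop: no vectorization, easy to read and cross-check."""
    out = logits.clone().float()
    batch_size, vocab_size = out.shape
    for i in range(batch_size):
        temperature = temperatures[i]
        if temperature < SAMPLING_EPS:
            temperature = 1.0
        penalty = repetition_penalties[i]
        for token_id in range(vocab_size):
            if appeared_masks[i][token_id]:
                # Assumed convention: divide positive logits, multiply negative ones.
                if out[i, token_id] > 0:
                    out[i, token_id] = out[i, token_id] / penalty
                else:
                    out[i, token_id] = out[i, token_id] * penalty
            out[i, token_id] = out[i, token_id] / temperature
    return out
```

The vectorized output could then be compared against this reference with torch.testing.assert_close.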

serve/tests/unittest/test_sampler.py (review thread resolved)
```python
),
batch_size
):
    sampling_params = [SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks]
```
@sunggg (Member) commented

Ditto.
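
As a sketch of what a naive per-request reference for top-p/top-k could look like (only illustrative: the actual _apply_top_p_top_k in sampler.py masks filtered tokens with -inf, so a real cross-check would compare the surviving token sets, and the treatment of a disabled top_k is an assumption here):

```python
import torch


def naive_top_p_top_k_keep_mask(logits_row, top_p, top_k):
    """Boolean mask of tokens kept by top-k followed by top-p, computed with a plain loop."""
    probs = torch.softmax(logits_row.float(), dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)

    # Assumption: a non-positive top_k means "disabled" (keep the whole vocabulary).
    if top_k <= 0:
        top_k = probs.numel()

    keep = torch.zeros_like(probs, dtype=torch.bool)
    cumulative = 0.0
    for rank in range(min(top_k, probs.numel())):
        keep[sorted_ids[rank]] = True
        cumulative += sorted_probs[rank].item()
        if cumulative >= top_p:
            break
    return keep
```

A cross-check could then assert, per request, that exactly the tokens outside this mask end up at -inf after _apply_top_p_top_k.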

@sunggg (Member) commented

It seems I don't see a test with temperature=0. Let's add this in the next PR.
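
A possible shape for that follow-up test, assuming temperature=0 is meant to behave greedily. Every name beyond SamplingParams and get_sampling_state is hypothetical here, in particular sample() and next_tokens:

```python
import pytest
import torch

# Assumed helpers: SamplingParams, get_sampling_state, sample (hypothetical),
# imported as in serve/tests/unittest/test_sampler.py.

VOCAB_SIZE = 32000  # hypothetical


@pytest.mark.parametrize("batch_size", [1, 4])
def test_greedy_when_temperature_is_zero(batch_size):
    sampling_params = [SamplingParams(temperature=0.0) for _ in range(batch_size)]
    sampling_state = get_sampling_state(sampling_params)

    logits = torch.rand(batch_size, VOCAB_SIZE)
    output = sample(logits, sampling_state)  # hypothetical sampling entry point

    expected = torch.argmax(logits, dim=-1)
    for i in range(batch_size):
        # With temperature=0 the sampler is expected to pick the argmax token.
        assert output.next_tokens[i] == expected[i].item()
```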

```python
assert isinstance(output.logprob_infos[idx].current_logprob, float)
assert output.logprob_infos[idx].top_token_ids.nelement() == 0
assert output.logprob_infos[idx].top_logprobs.nelement() == 0
```

@sunggg (Member) commented

Ditto. We also need to test more complicated scenarios, such as:
(1) requests in the same batch ask for different top-k values, and some of them do not request logprobs;
(2) logprobs combined with other sampling params.

I think this may take some time, so I would like us to follow up on this after this PR.
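
A rough sketch of scenario (1): requests in one batch alternate between asking for top-5 logprobs and not asking for logprobs at all. The parameter names logprobs/top_logprobs and the sample() helper are assumptions; the logprob_infos fields follow the assertions quoted above.

```python
import pytest
import torch

# Assumed helpers: SamplingParams, get_sampling_state, sample (hypothetical),
# imported as in serve/tests/unittest/test_sampler.py.

VOCAB_SIZE = 32000  # hypothetical


@pytest.mark.parametrize("batch_size", [4, 8])
def test_mixed_logprobs_requests(batch_size):
    # Even-indexed requests ask for top-5 logprobs, odd-indexed ones do not.
    sampling_params = [
        SamplingParams(logprobs=True, top_logprobs=5) if i % 2 == 0 else SamplingParams()
        for i in range(batch_size)
    ]
    sampling_state = get_sampling_state(sampling_params)

    logits = torch.rand(batch_size, VOCAB_SIZE)
    output = sample(logits, sampling_state)  # hypothetical sampling entry point

    for i in range(batch_size):
        if i % 2 == 0:
            assert output.logprob_infos[i].top_token_ids.nelement() == 5
            assert output.logprob_infos[i].top_logprobs.nelement() == 5
        else:
            assert output.logprob_infos[i].top_token_ids.nelement() == 0
            assert output.logprob_infos[i].top_logprobs.nelement() == 0
```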

```python
and make sure that _apply_top_p_top_k from sampler.py does not produce too many -inf values
""")
@pytest.mark.parametrize("batch_size", [1, 4, 8, 12])
def test_mixture_of_requests(batch_size: int):
```
@sunggg (Member) commented

We should test this more exhaustively. Let's follow up in the next PR.

@Ailurus1 (Author) commented

Hello @sunggg,
Thank you for the review and comments! I have updated the tests according to them. Could you please take a look at the latest changes once again?
I'm going to take a deeper look at the latest test, the one with a mixture of different parameters, in the follow-up PR.

@sunggg (Member) left a comment

LGTM, thanks for making the sampler more robust!
As we discussed, let's address the remaining feedback in the new repo.

@sunggg merged commit d66880c into octoml:batch-serving on Feb 22, 2024
1 check passed
```
@@ -422,6 +547,8 @@ def _test_json_mode(
_test_stop(sync_engine)
_test_logprobs(sync_engine)
_test_logprobs_mixed_requests(sync_engine)
_test_num_sequences(sync_engine)
_test_logit_bias(sync_engine)
```
@masahi (Member) commented on Mar 18, 2024

This test seems to always fail when I run all of the tests in test_engine_with_samplers.py, but if I comment out the other tests and run only this one, it works. So there is a strange issue in the tests. @Ailurus1, can you take a look and send a fix to the ollm repo?
