diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py
index 1925e6f3f7..b266d7c456 100644
--- a/serve/mlc_serve/model/torch_model.py
+++ b/serve/mlc_serve/model/torch_model.py
@@ -245,6 +245,7 @@ def generate(
     prompt_lens = []
     sampling_params = []
     past_decode_tokens = []
+    prompt_masks = []

     for request in requests:
         if isinstance(request, PrefillRequest):
@@ -260,6 +261,7 @@ def generate(

         all_token_ids.append(request.token_ids)
         sampling_params.append(request.sampling_params)
+        prompt_masks.append(request.prompt_mask)

     selected_token_indices: List[int] = []
@@ -352,6 +354,7 @@ def generate(
     sampling_metadata = SamplingState.from_sampling_params(
         sampling_params,
         past_decode_tokens,
+        prompt_masks,
         torch.float32,
         "cuda",
         vocab_size,
@@ -366,6 +369,7 @@ def generate(
         torch.float32,
         "cuda",
         past_decode_tokens,
+        prompt_masks,
     )
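
The patch above only threads a per-request prompt_mask from each PrefillRequest through generate() into the sampling state; it does not show how the mask itself is built. As a minimal sketch of one plausible shape for that value (build_prompt_mask is a hypothetical helper, not part of this patch): a boolean tensor over the vocabulary marking which token ids occur in the prompt, so the sampler can distinguish prompt tokens from generated tokens, for example when applying repetition-style penalties.

# Illustrative sketch only -- build_prompt_mask is a hypothetical helper,
# not taken from this patch. It shows one plausible representation of the
# per-request prompt_mask threaded through generate() above: a boolean
# vocab-sized tensor that is True at every token id appearing in the prompt.
from typing import List

import torch


def build_prompt_mask(prompt_token_ids: List[int], vocab_size: int) -> torch.Tensor:
    # Start with an all-False mask over the vocabulary.
    mask = torch.zeros(vocab_size, dtype=torch.bool)
    if prompt_token_ids:
        # Mark every vocab position whose token id occurs in the prompt.
        mask[torch.tensor(prompt_token_ids, dtype=torch.long)] = True
    return mask

Under this assumption, collecting one mask per request (as the diff does with prompt_masks.append) gives SamplingState a batch-aligned list it can stack or index when adjusting logits for each sequence.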