Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 11, 2024
1 parent 3a9f6d6 commit 9fb9261
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
6 changes: 4 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,10 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
requests: list[Union[PrefillRequest, DecodeRequest]] = []
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
Expand Down
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/model_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class TextGenerator(Protocol):

def generate(
self,
requests: List[Union[PrefillRequest, DecodeRequest]],
requests: List[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]],
kv_cache: KVCache,
) -> List[TextGenerationResult]:
"""
Expand Down
8 changes: 4 additions & 4 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def sample_from_logits(
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence):
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
Expand Down Expand Up @@ -532,7 +532,7 @@ def sample_from_logits(
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence):
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(
Expand All @@ -553,7 +553,7 @@ def sample_from_logits(
else:
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence):
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(
Expand All @@ -578,7 +578,7 @@ def generate_multi_query(
self, requests: List[MultiQueryDecodeRequest], cache: KVCache
) -> List[TextGenerationResult]:
sequence_ids = []
last_query_offsets = []
last_query_offsets: List[int] = []
for request in requests:
assert not isinstance(request.queries, DraftTokens)
sequence_ids.append(request.sequence_id)
Expand Down

0 comments on commit 9fb9261

Please sign in to comment.