Enable running PyTorch models #207
Changes from all commits
@@ -30,15 +30,13 @@ def get_gpu_memory(gpu: int = 0) -> int:


 def get_num_cache_blocks(
-    model,
+    used_memory_bytes,
     block_size,
-    seq_lens,
     num_layers,
     num_kv_heads,
     head_size,
     gpu_memory_utilization=0.9,  # the default used by vllm
 ):
-    used_memory_bytes = model.profile_memory_usage(seq_lens)
     cache_block_size = CacheManager.get_cache_block_size(
         block_size, num_layers, num_kv_heads, head_size
     )
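The hunk above appears to stop passing the model into get_num_cache_blocks and instead hands it a pre-measured used_memory_bytes. For orientation, a cache-block count of this kind is typically derived from the memory left inside the utilization budget. The helper below is a hedged sketch with hypothetical names (estimate_num_cache_blocks, total_gpu_memory_bytes), not the function body from this file:

# Hypothetical sketch, not the actual body of get_num_cache_blocks: size the KV
# cache from whatever GPU memory remains once the model's measured usage is
# subtracted from the utilization budget.
def estimate_num_cache_blocks(
    total_gpu_memory_bytes: int,
    used_memory_bytes: int,
    cache_block_size: int,  # bytes for one block across all layers and KV heads
    gpu_memory_utilization: float = 0.9,
) -> int:
    budget = total_gpu_memory_bytes * gpu_memory_utilization
    free_bytes = budget - used_memory_bytes
    return max(int(free_bytes // cache_block_size), 0)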
@@ -85,22 +83,18 @@ def sample_from_logits(
     requests: Sequence[RequestType],
     sampling_state: SamplingState,
     vocab_size: int,
-    copy_stream: torch.cuda.Stream,
     torch_dtype: torch.dtype,
     torch_dev: str,
     past_decode_tokens: List[List[int]],
     prompt_masks: List[torch.Tensor],
 ) -> List[TextGenerationResult]:
     batch_size = logits.shape[0]
     assert batch_size == len(requests)

     # Convert to torch tensors if logits are in tvm ndarray
     if isinstance(logits, tvm.nd.NDArray):
         logits = torch.from_dlpack(logits)

-    # synchronization point for sampling tensors
-    # wait until all the tensors are loaded on GPU
-    torch.cuda.current_stream().wait_stream(copy_stream)
     # Logit processing for constraint sampling e.g., JSON Mode
     for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
         if request.sampling_params.logits_processor is not None:
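The logits arrive either as a torch tensor or as a TVM NDArray; torch.from_dlpack views the same GPU buffer through the DLPack protocol rather than copying it. A minimal standalone sketch of that conversion, assuming a CUDA device and a recent CUDA-enabled TVM build whose NDArray implements the DLPack interface:

import numpy as np
import torch
import tvm

# Build a TVM NDArray on GPU 0 and view it from torch without copying.
tvm_logits = tvm.nd.array(np.zeros((4, 32000), dtype="float32"), tvm.cuda(0))
torch_logits = torch.from_dlpack(tvm_logits)  # zero-copy: same GPU memory

print(torch_logits.shape, torch_logits.dtype, torch_logits.device)
# expected: torch.Size([4, 32000]) torch.float32 cuda:0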
@@ -140,6 +134,7 @@ def sample_from_logits(
                 " or element < 0"
             )
     logits = torch.from_dlpack(logits)
+
     for i in range(batch_size):
         sequence_id = sequence_ids[i]
         logits_per_token = logits[i]
@@ -149,16 +144,14 @@ def sample_from_logits(
             # NOTE: Rerun the preparation for simplicity.
             # Assume this code path is taken rarely and the recomputation overhead is
             # marginal.
-            with torch.cuda.stream(copy_stream):
-                new_sampling_state = SamplingState.from_sampling_params(
-                    [sampling_param],
-                    [past_decode_tokens_per_request],
-                    [prompt_mask],
-                    torch_dtype,
-                    torch_dev,
-                    vocab_size,
-                )
-            torch.cuda.current_stream().wait_stream(copy_stream)
+            new_sampling_state = SamplingState.from_sampling_params(
+                [sampling_param],
+                [past_decode_tokens_per_request],
+                [prompt_mask],
+                torch_dtype,
+                torch_dev,
+                vocab_size,
+            )
             maybe_sampling_output: Optional[SamplingOutput] = sample(
                 torch.unsqueeze(logits_per_token, 0),
                 new_sampling_state,
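For reference, what the hunks above remove is the usual PyTorch side-stream copy idiom: stage host-to-device transfers on a dedicated stream, then block the compute stream on it before the data is consumed. A minimal sketch of that pattern (assumes a CUDA device; the tensor names are made up and not taken from this file):

import torch

copy_stream = torch.cuda.Stream()
host_params = torch.randn(1024, pin_memory=True)  # pinned memory enables async copies

# Launch the copy on the side stream so it can overlap with compute work.
with torch.cuda.stream(copy_stream):
    device_params = host_params.to("cuda", non_blocking=True)

# Synchronization point: work queued on the current stream after this call
# will not start until the copy stream has drained.
torch.cuda.current_stream().wait_stream(copy_stream)

out = device_params * 2  # safe to read device_params here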
@@ -169,6 +162,7 @@ def sample_from_logits(
             logprob_info = maybe_sampling_output.logprob_infos[0]
             # Valid sample
             request = requests[i]
+
             if maybe_sampling_output is not None:
                 outputs.extend(
                     prepare_textgen_result(
@@ -200,24 +194,39 @@ def prepare_inputs(
     all_decode_block_tables,
     sliding_window,
     is_prefill,
+    block_size,
     num_decode_query_tokens=1,
+    for_vllm=False,
 ):
+    if for_vllm:
+        torch_int_dtype = torch.long
+    else:
+        torch_int_dtype = torch.int
+
     block_tables = []
     seq_lens = []
     input_ids = []
     slot_mapping = []
     positions = []
     max_num_blocks_per_seq = 0
     indices_within_window = []
     start_idx = 0
+    max_prompt_len = -1
+    max_context_len = -1

     for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)):
         if is_prefill:
-            input_ids += token_ids
             prompt_len = len(token_ids)
             seq_lens.append(prompt_len)
-            positions += range(prompt_len)
-            slot_mapping += all_slot_mappings[sequence_id]
+            max_prompt_len = max(max_prompt_len, prompt_len)
+
+            if for_vllm:
+                input_ids.append(token_ids)
+                positions.append(list(range(prompt_len)))
+                slot_mapping.append(all_slot_mappings[sequence_id])
+            else:
+                input_ids += token_ids
+                positions += range(prompt_len)
+                slot_mapping += all_slot_mappings[sequence_id]

             if sliding_window:
                 indices_within_window += range(
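To make the two prefill layouts concrete, here is a small worked example with made-up token ids and slot mappings for two prompts of lengths 3 and 2 (hypothetical values, not taken from the PR):

# Hypothetical prefill batch: two prompts of lengths 3 and 2, made-up slot maps.
all_token_ids = [[11, 12, 13], [21, 22]]
all_slot_maps = [[0, 1, 2], [16, 17]]

flat_ids, flat_pos, flat_slots = [], [], []   # for_vllm=False: flattened 1D layout
vllm_ids, vllm_pos, vllm_slots = [], [], []   # for_vllm=True: one ragged row per sequence

for token_ids, slots in zip(all_token_ids, all_slot_maps):
    flat_ids += token_ids
    flat_pos += range(len(token_ids))
    flat_slots += slots

    vllm_ids.append(token_ids)
    vllm_pos.append(list(range(len(token_ids))))
    vllm_slots.append(slots)

print(flat_ids)   # [11, 12, 13, 21, 22]
print(flat_pos)   # [0, 1, 2, 0, 1]
print(vllm_ids)   # [[11, 12, 13], [21, 22]]
print(vllm_pos)   # [[0, 1, 2], [0, 1]]

The ragged rows of the for_vllm layout are padded to max_prompt_len later by _do_pad (see below), which is why max_prompt_len is tracked in this loop.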
@@ -228,44 +237,65 @@ def prepare_inputs(

         else:
             seq_len = prompt_lens[i] + len(token_ids)
-            input_ids += token_ids[-num_decode_query_tokens:]

-            for i in range(num_decode_query_tokens):
-                positions.append(seq_len - (num_decode_query_tokens - i))
+            if for_vllm:
+                assert num_decode_query_tokens == 1
+                input_ids.append([token_ids[-1]])
+                positions.append([seq_len - 1])
+                slot_mapping.append([all_slot_mappings[sequence_id][-1]])
+            else:
+                input_ids += token_ids[-num_decode_query_tokens:]

-            slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]
+                for i in range(num_decode_query_tokens):
+                    positions.append(seq_len - (num_decode_query_tokens - i))
+
+                slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]

             block_table = all_decode_block_tables[sequence_id]
             max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
             block_tables.append(block_table.get_blocks())

             if sliding_window:
                 seq_lens.append(min(seq_len, sliding_window))
             else:
                 seq_lens.append(seq_len)

+            max_context_len = max(max_context_len, seq_lens[-1])
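A quick arithmetic check on the non-vllm decode branch above, with hypothetical numbers: for a context of seq_len = 10 and num_decode_query_tokens = 2, the loop emits the zero-based positions of the last two query tokens. The for_vllm branch asserts num_decode_query_tokens == 1, so it reduces to the single position [seq_len - 1].

seq_len = 10
num_decode_query_tokens = 2

positions = []
for i in range(num_decode_query_tokens):
    positions.append(seq_len - (num_decode_query_tokens - i))

print(positions)  # [8, 9] -- positions of the last two query tokens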
+    def _do_pad(
+        x: List[List[int]],
+        max_len: int,
+        pad_val: int,
+    ) -> List[List[int]]:
+        def _pad_to_max(x: List[int], max_len: int, pad_val: int) -> List[int]:
+            assert len(x) <= max_len
+            return x + [pad_val] * (max_len - len(x))
+
+        return [_pad_to_max(x_i, max_len, pad_val) for x_i in x]
+
+    if for_vllm and is_prefill:
+        input_ids = _do_pad(input_ids, max_prompt_len, 0)
+        positions = _do_pad(positions, max_prompt_len, 0)
+        slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1)

Review discussion attached to the _do_pad hunk:
- Now that we have started considering vllm's tensor layout, what do you think about unifying it? It seems like upstream mlc-llm also uses 2D inputs.
- And this could also help our cuda graph integration?
- We haven't verified whether 2D inputs are better for performance, or how much cuda graph actually helps. The upstream input looks 2D, but it is always either
- Yeah, I think it is worth revisiting, imo, but not now, in the future. Although there might not be a performance boost, it would be nice to unify the layout with upstream unless there is a reason.
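To make the 2D layout discussed in that thread concrete, here is a sketch continuing the hypothetical two-prompt example. The helper below is a condensed re-implementation of _do_pad for illustration, not the code from the diff; the -1 pad value for slot_mapping is presumably chosen so padded positions are distinguishable from real cache slots (an assumption, not stated in the PR):

from typing import List

def do_pad(x: List[List[int]], max_len: int, pad_val: int) -> List[List[int]]:
    # Right-pad every row to max_len, same effect as the _do_pad helper above.
    return [row + [pad_val] * (max_len - len(row)) for row in x]

input_ids = [[11, 12, 13], [21, 22]]
slot_mapping = [[0, 1, 2], [16, 17]]
max_prompt_len = 3

print(do_pad(input_ids, max_prompt_len, 0))      # [[11, 12, 13], [21, 22, 0]]
print(do_pad(slot_mapping, max_prompt_len, -1))  # [[0, 1, 2], [16, 17, -1]]
# After to_torch, each becomes a rectangular (batch, max_prompt_len) tensor.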
     def to_torch(arr, torch_dtype):
         return torch.tensor(arr, dtype=torch_dtype, device="cuda")

-    input_ids = to_torch(input_ids, torch.int)
-    positions = to_torch(positions, torch.int)
+    input_ids = to_torch(input_ids, torch_int_dtype)
+    positions = to_torch(positions, torch_int_dtype)
     seq_lens = to_torch(seq_lens, torch.int)
-    slot_mapping = to_torch(slot_mapping, torch.int)
+    slot_mapping = to_torch(slot_mapping, torch_int_dtype)

     if is_prefill and sliding_window:
         indices_within_window = to_torch(indices_within_window, torch.int)
     else:
         indices_within_window = None

     if not is_prefill:
+        max_block_table_len = (
+            max_context_len + block_size - 1
+        ) // block_size

-        def _pad_to_max(x: List[int], max_len: int) -> List[int]:
-            return x + [0] * (max_len - len(x))
-
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in block_tables
-        ]
+        padded_block_tables = _do_pad(block_tables, max_block_table_len, 0)
         block_tables = to_torch(padded_block_tables, torch.int)
     else:
         block_tables = None
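One arithmetic note on the decode path: max_block_table_len is a ceiling division, so every sequence's block table is now padded to the number of blocks needed for the longest context in the batch rather than to max_num_blocks_per_seq. A small check with hypothetical numbers:

block_size = 16
max_context_len = 300

max_block_table_len = (max_context_len + block_size - 1) // block_size
print(max_block_table_len)  # 19, i.e. ceil(300 / 16)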