- 
                Notifications
    You must be signed in to change notification settings 
- Fork 8
Enable running PyTorch models #207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
12ce0a3
              7a84f15
              f454b7b
              afde741
              c49ef45
              d9ac72f
              acbf825
              fef750f
              25a567e
              3d06f68
              3cafc8b
              34f77ef
              afb4d4f
              e7212a5
              f27e3b3
              2316e37
              9b985e8
              f2dcc48
              4d73e63
              15a0d3b
              959019d
              b6050d9
              ff8eb27
              8696df5
              0af3a70
              90ffccd
              32686d8
              de2631b
              618ca62
              9ce2f47
              c14c0e9
              4328440
              9fb6358
              0c40fe8
              7d89811
              0bbc41a
              08a63ca
              1dee091
              e098d0b
              02b7c1b
              5cefe97
              c470c36
              a4612da
              4564bd0
              196026c
              de68a84
              b502654
              686780c
              bebd7b2
              ed46b5e
              62918dd
              b98bdce
              e144517
              e6abcc7
              dfbf359
              04da3bb
              5dfecb2
              e4bbad9
              ee9cdc9
              2071749
              15a90d0
              1f56ee9
              90284fa
              52ad1ad
              61b680e
              f128fe6
              568583a
              762012d
              72f3707
              3128329
              dc5fb6e
              ebe0b4e
              4b2de70
              f09d458
              c9ac5ba
              f1cf274
              eaa53a7
              1336fb8
              6186ef2
              2229324
              aa4d477
              8bb96ed
              cf0813d
              f716851
              992b1a0
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -30,15 +30,13 @@ def get_gpu_memory(gpu: int = 0) -> int: | |
|  | ||
|  | ||
| def get_num_cache_blocks( | ||
| model, | ||
| used_memory_bytes, | ||
| block_size, | ||
| seq_lens, | ||
| num_layers, | ||
| num_kv_heads, | ||
| head_size, | ||
| gpu_memory_utilization=0.9, # the default used by vllm | ||
| ): | ||
| used_memory_bytes = model.profile_memory_usage(seq_lens) | ||
| cache_block_size = CacheManager.get_cache_block_size( | ||
| block_size, num_layers, num_kv_heads, head_size | ||
| ) | ||
|  | @@ -85,22 +83,18 @@ def sample_from_logits( | |
| requests: Sequence[RequestType], | ||
| sampling_state: SamplingState, | ||
| vocab_size: int, | ||
| copy_stream: torch.cuda.Stream, | ||
| torch_dtype: torch.dtype, | ||
| torch_dev: str, | ||
| past_decode_tokens: List[List[int]], | ||
| prompt_masks: List[torch.Tensor], | ||
| ) -> List[TextGenerationResult]: | ||
| batch_size = logits.shape[0] | ||
| assert batch_size == len(requests) | ||
|  | ||
| # Convert to torch tensors if logits are in tvm ndarray | ||
| if isinstance(logits, tvm.nd.NDArray): | ||
| logits = torch.from_dlpack(logits) | ||
|  | ||
| # synchronization point for sampling tensors | ||
| # wait until all the tensors are loaded on GPU | ||
| torch.cuda.current_stream().wait_stream(copy_stream) | ||
|         
                  masahi marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
|  | ||
| # Logit processing for constraint sampling e.g., JSON Mode | ||
| for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)): | ||
| if request.sampling_params.logits_processor is not None: | ||
|  | @@ -140,6 +134,7 @@ def sample_from_logits( | |
| " or element < 0" | ||
| ) | ||
| logits = torch.from_dlpack(logits) | ||
|  | ||
| for i in range(batch_size): | ||
| sequence_id = sequence_ids[i] | ||
| logits_per_token = logits[i] | ||
|  | @@ -149,16 +144,14 @@ def sample_from_logits( | |
| # NOTE: Rerun the preparation for simplicity. | ||
| # Assume this code path is taken rarely and the recomputation overhead is | ||
| # marginal. | ||
| with torch.cuda.stream(copy_stream): | ||
| new_sampling_state = SamplingState.from_sampling_params( | ||
| [sampling_param], | ||
| [past_decode_tokens_per_request], | ||
| [prompt_mask], | ||
| torch_dtype, | ||
| torch_dev, | ||
| vocab_size, | ||
| ) | ||
| torch.cuda.current_stream().wait_stream(copy_stream) | ||
| new_sampling_state = SamplingState.from_sampling_params( | ||
| [sampling_param], | ||
| [past_decode_tokens_per_request], | ||
| [prompt_mask], | ||
| torch_dtype, | ||
| torch_dev, | ||
| vocab_size, | ||
| ) | ||
|         
                  masahi marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| maybe_sampling_output: Optional[SamplingOutput] = sample( | ||
| torch.unsqueeze(logits_per_token, 0), | ||
| new_sampling_state, | ||
|  | @@ -169,6 +162,7 @@ def sample_from_logits( | |
| logprob_info = maybe_sampling_output.logprob_infos[0] | ||
| # Valid sample | ||
| request = requests[i] | ||
|  | ||
| if maybe_sampling_output is not None: | ||
| outputs.extend( | ||
| prepare_textgen_result( | ||
|  | @@ -200,24 +194,39 @@ def prepare_inputs( | |
| all_decode_block_tables, | ||
| sliding_window, | ||
| is_prefill, | ||
| block_size, | ||
| num_decode_query_tokens=1, | ||
| for_vllm=False, | ||
| ): | ||
| if for_vllm: | ||
| torch_int_dtype = torch.long | ||
| else: | ||
| torch_int_dtype = torch.int | ||
|  | ||
| block_tables = [] | ||
| seq_lens = [] | ||
| input_ids = [] | ||
| slot_mapping = [] | ||
| positions = [] | ||
| max_num_blocks_per_seq = 0 | ||
| indices_within_window = [] | ||
| start_idx = 0 | ||
| max_prompt_len = -1 | ||
| max_context_len = -1 | ||
|  | ||
| for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)): | ||
| if is_prefill: | ||
| input_ids += token_ids | ||
| prompt_len = len(token_ids) | ||
| seq_lens.append(prompt_len) | ||
| positions += range(prompt_len) | ||
| slot_mapping += all_slot_mappings[sequence_id] | ||
| max_prompt_len = max(max_prompt_len, prompt_len) | ||
|  | ||
| if for_vllm: | ||
| input_ids.append(token_ids) | ||
| positions.append(list(range(prompt_len))) | ||
| slot_mapping.append(all_slot_mappings[sequence_id]) | ||
| else: | ||
| input_ids += token_ids | ||
| positions += range(prompt_len) | ||
| slot_mapping += all_slot_mappings[sequence_id] | ||
|  | ||
| if sliding_window: | ||
| indices_within_window += range( | ||
|  | @@ -228,44 +237,65 @@ def prepare_inputs( | |
|  | ||
| else: | ||
| seq_len = prompt_lens[i] + len(token_ids) | ||
| input_ids += token_ids[-num_decode_query_tokens:] | ||
|  | ||
| for i in range(num_decode_query_tokens): | ||
| positions.append(seq_len - (num_decode_query_tokens - i)) | ||
| if for_vllm: | ||
| assert num_decode_query_tokens == 1 | ||
| input_ids.append([token_ids[-1]]) | ||
| positions.append([seq_len - 1]) | ||
| slot_mapping.append([all_slot_mappings[sequence_id][-1]]) | ||
| else: | ||
| input_ids += token_ids[-num_decode_query_tokens:] | ||
|  | ||
| slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:] | ||
| for i in range(num_decode_query_tokens): | ||
| positions.append(seq_len - (num_decode_query_tokens - i)) | ||
|  | ||
| slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:] | ||
|  | ||
| block_table = all_decode_block_tables[sequence_id] | ||
| max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) | ||
| block_tables.append(block_table.get_blocks()) | ||
|  | ||
| if sliding_window: | ||
| seq_lens.append(min(seq_len, sliding_window)) | ||
| else: | ||
| seq_lens.append(seq_len) | ||
|  | ||
| max_context_len = max(max_context_len, seq_lens[-1]) | ||
|  | ||
| def _do_pad( | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that we started considering vllm's tensor layout, what do you think about unifying it? It seems like upstream mlc-llm also uses 2D inputs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And this also could help our cuda graph integration? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We haven't verified if 2D inputs is better for performance, and how much cuda graph actually helps. The upstream input looks like 2D but it is always either  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think it is worth visiting imo. But not now, in the future. Although there might not be performance boost, it would be nice to unify the layout with upstream unless there is reason. | ||
| x: List[List[int]], | ||
| max_len: int, | ||
| pad_val: int, | ||
| ) -> List[List[int]]: | ||
| def _pad_to_max(x: List[int], max_len: int, pad_val: int) -> List[int]: | ||
| assert len(x) <= max_len | ||
| return x + [pad_val] * (max_len - len(x)) | ||
|  | ||
| return [_pad_to_max(x_i, max_len, pad_val) for x_i in x] | ||
|  | ||
| if for_vllm and is_prefill: | ||
| input_ids = _do_pad(input_ids, max_prompt_len, 0) | ||
| positions = _do_pad(positions, max_prompt_len, 0) | ||
| slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1) | ||
|  | ||
| def to_torch(arr, torch_dtype): | ||
| return torch.tensor(arr, dtype=torch_dtype, device="cuda") | ||
|  | ||
| input_ids = to_torch(input_ids, torch.int) | ||
| positions = to_torch(positions, torch.int) | ||
| input_ids = to_torch(input_ids, torch_int_dtype) | ||
| positions = to_torch(positions, torch_int_dtype) | ||
| seq_lens = to_torch(seq_lens, torch.int) | ||
| slot_mapping = to_torch(slot_mapping, torch.int) | ||
| slot_mapping = to_torch(slot_mapping, torch_int_dtype) | ||
|  | ||
| if is_prefill and sliding_window: | ||
| indices_within_window = to_torch(indices_within_window, torch.int) | ||
| else: | ||
| indices_within_window = None | ||
|  | ||
| if not is_prefill: | ||
| max_block_table_len = ( | ||
| max_context_len + block_size - 1 | ||
| ) // block_size | ||
|  | ||
| def _pad_to_max(x: List[int], max_len: int) -> List[int]: | ||
| return x + [0] * (max_len - len(x)) | ||
|  | ||
| padded_block_tables = [ | ||
| _pad_to_max(block_table, max_num_blocks_per_seq) | ||
| for block_table in block_tables | ||
| ] | ||
| padded_block_tables = _do_pad(block_tables, max_block_table_len, 0) | ||
| block_tables = to_torch(padded_block_tables, torch.int) | ||
| else: | ||
| block_tables = None | ||
|  | ||
Uh oh!
There was an error while loading. Please reload this page.