Commit 7a3b6b6

cleanup
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 6802cab commit 7a3b6b6

4 files changed: +111 −172 lines

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 11 additions & 28 deletions
@@ -24,7 +24,7 @@
     dbo_register_recv_hook,
     dbo_yield,
 )
-from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager
+from vllm.v1.worker.workspace import current_workspace_manager
 
 #
 # This file defines a set of base classes used to make MoE kernels more modular.
@@ -766,48 +766,31 @@ def _allocate_buffers(
             local_num_experts,
             None,  # Pass None to avoid using sampled token counts
         )
-        max_workspace13_spec = WorkspaceSpec(
-            shape=max_workspace13_shape,
-            dtype=workspace_dtype,
-            name="moe.workspace13",
-        )
-        max_workspace2_spec = WorkspaceSpec(
-            shape=max_workspace2_shape,
-            dtype=workspace_dtype,
-            name="moe.workspace2",
-        )
-        max_fused_out_spec = WorkspaceSpec(
-            shape=max_fused_out_shape, dtype=out_dtype, name="moe.fused_out"
-        )
-        current_workspace_manager().reserve_simultaneous(
-            max_workspace13_spec, max_workspace2_spec, max_fused_out_spec
+
+        current_workspace_manager().get_simultaneous(
+            (max_workspace13_shape, workspace_dtype),
+            (max_workspace2_shape, workspace_dtype),
+            (max_fused_out_shape, out_dtype),
         )
 
     # We can reuse the memory between cache1 and cache3 because by the
     # time we need cache3, we're done with cache1.
-    workspace13_spec = WorkspaceSpec(
-        shape=workspace13_shape, dtype=workspace_dtype, name="moe.workspace13"
-    )
-    workspace2_spec = WorkspaceSpec(
-        shape=workspace2_shape, dtype=workspace_dtype, name="moe.workspace2"
-    )
-
     # Construct the entire output that can then be processed in chunks.
     # Reuse workspace13 for the output in the non-chunked case as long
     # as it is large enough. This will not always be the case for standard
     # format experts and with experts that have empty workspaces.
     if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
         workspace13, workspace2 = current_workspace_manager().get_simultaneous(
-            workspace13_spec, workspace2_spec
+            (workspace13_shape, workspace_dtype),
+            (workspace2_shape, workspace_dtype),
         )
         fused_out = _resize_cache(workspace13, fused_out_shape)
     else:
-        fused_out_spec = WorkspaceSpec(
-            shape=fused_out_shape, dtype=out_dtype, name="moe.fused_out"
-        )
         workspace13, workspace2, fused_out = (
             current_workspace_manager().get_simultaneous(
-                workspace13_spec, workspace2_spec, fused_out_spec
+                (workspace13_shape, workspace_dtype),
+                (workspace2_shape, workspace_dtype),
+                (fused_out_shape, out_dtype),
             )
         )

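Taken together, these hunks drop the named WorkspaceSpec objects in favor of passing (shape, dtype) tuples straight to the workspace manager. The sketch below is not part of the commit; it only restates the call pattern the diff adopts, assumes it runs inside a vLLM worker where current_workspace_manager() is active, and stands in for the _resize_cache helper with a plain view. The helper name, shapes, and type hints are illustrative.

    # Minimal sketch of the new call pattern; not part of the commit. Assumes a
    # workspace manager is active (i.e. running inside a vLLM worker).
    from math import prod

    import torch

    from vllm.v1.worker.workspace import current_workspace_manager


    def allocate_moe_buffers(
        workspace13_shape: tuple[int, ...],
        workspace2_shape: tuple[int, ...],
        fused_out_shape: tuple[int, ...],
        workspace_dtype: torch.dtype,
        out_dtype: torch.dtype,
        num_chunks: int,
    ):
        if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
            # Each (shape, dtype) tuple yields one tensor from the shared workspace.
            workspace13, workspace2 = current_workspace_manager().get_simultaneous(
                (workspace13_shape, workspace_dtype),
                (workspace2_shape, workspace_dtype),
            )
            # Reuse workspace13 for the fused output (plain-view stand-in for
            # the _resize_cache helper used in the real code).
            fused_out = workspace13.view(-1)[: prod(fused_out_shape)].view(fused_out_shape)
        else:
            # Otherwise request a dedicated output buffer as well.
            workspace13, workspace2, fused_out = (
                current_workspace_manager().get_simultaneous(
                    (workspace13_shape, workspace_dtype),
                    (workspace2_shape, workspace_dtype),
                    (fused_out_shape, out_dtype),
                )
            )
        return workspace13, workspace2, fused_out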
vllm/model_executor/models/deepseek_v2.py

Lines changed: 8 additions & 14 deletions
@@ -86,7 +86,7 @@
     DeepseekV32IndexerMetadata,
 )
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
-from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager
+from vllm.v1.worker.workspace import current_workspace_manager
 
 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (
@@ -520,20 +520,13 @@ def sparse_attn_indexer(
     # careful! this will be None in dummy run
     attn_metadata = get_forward_context().attn_metadata
 
-    k_fp8_spec = WorkspaceSpec(
-        shape=(total_seq_lens, head_dim),
-        dtype=torch.float8_e4m3fn,
-        name="sparse_attn_indexer.k_fp8",
-    )
-    k_scale_spec = WorkspaceSpec(
-        shape=(total_seq_lens, 4),
-        dtype=torch.uint8,
-        name="sparse_attn_indexer.k_scale",
-    )
-
     # assert isinstance(attn_metadata, dict)
     if not isinstance(attn_metadata, dict):
-        current_workspace_manager().reserve_simultaneous(k_fp8_spec, k_scale_spec)
+        # Reserve workspace for indexer during profiling run
+        current_workspace_manager().get_simultaneous(
+            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
+            ((total_seq_lens, 4), torch.uint8),
+        )
 
         return sparse_attn_indexer_fake(
             hidden_states,
@@ -572,7 +565,8 @@ def sparse_attn_indexer(
     # Get the full shared workspace buffers once (will allocate on first use)
     workspace_manager = current_workspace_manager()
     k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
-        k_fp8_spec, k_scale_spec
+        ((total_seq_lens, head_dim), torch.float8_e4m3fn),
+        ((total_seq_lens, 4), torch.uint8),
    )
 
     for chunk in prefill_metadata.chunks:

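In deepseek_v2.py the same tuple-based requests now do double duty: the dummy/profiling path issues them once so the workspace is sized, and the real prefill path issues identical requests to obtain the shared k_fp8/k_scale buffers. A rough sketch of that pattern under the same assumptions as above; the function names are illustrative, not the model's actual methods.

    import torch

    from vllm.v1.worker.workspace import current_workspace_manager


    def reserve_indexer_workspace(total_seq_lens: int, head_dim: int) -> None:
        # Dummy/profiling run: issue the requests once so the workspace is sized
        # for the worst case; the returned tensors are discarded.
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )


    def get_indexer_workspace(total_seq_lens: int, head_dim: int):
        # Real prefill path: identical (shape, dtype) requests return the shared
        # k_fp8 / k_scale buffers, which are then filled chunk by chunk.
        k_fp8_full, k_scale_full = current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )
        return k_fp8_full, k_scale_full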
vllm/v1/attention/backends/mla/flashmla_sparse.py

Lines changed: 7 additions & 8 deletions
@@ -31,7 +31,7 @@
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager
+from vllm.v1.worker.workspace import current_workspace_manager
 
 if TYPE_CHECKING:
     from vllm.model_executor.models.deepseek_v2 import Indexer
@@ -636,14 +636,13 @@ def __init__(
         vllm_config = get_current_vllm_config()
         prefill_workspace_size = get_prefill_workspace_size(vllm_config)
 
-        self.prefill_workspace_spec = WorkspaceSpec(
-            shape=(prefill_workspace_size, head_size),
-            dtype=torch.bfloat16,
-            name="FlashMLASparseImpl.prefill_workspace",
-        )
+        self.prefill_workspace_shape = (prefill_workspace_size, head_size)
 
         if kv_cache_dtype == "fp8_ds_mla":
-            current_workspace_manager().reserve(self.prefill_workspace_spec)
+            # Reserve workspace during initialization
+            current_workspace_manager().get(
+                self.prefill_workspace_shape, torch.bfloat16
+            )
 
     def _forward_bf16_kv(
         self,
@@ -810,7 +809,7 @@ def forward(
         # Process prefill chunks
         assert attn_metadata.prefill_chunks is not None
         prefill_bf16_workspace = current_workspace_manager().get(
-            self.prefill_workspace_spec
+            self.prefill_workspace_shape, torch.bfloat16
         )
 
         for chunk in attn_metadata.prefill_chunks:

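flashmla_sparse.py uses the single-buffer form: get(shape, dtype) replaces the old reserve(spec)/get(spec) pair, so one call in __init__ reserves the bf16 prefill scratch space and a later call in forward fetches it. A minimal sketch, assuming get behaves as the hunks above imply; the class name and methods are illustrative.

    import torch

    from vllm.v1.worker.workspace import current_workspace_manager


    class SparsePrefillScratch:
        """Illustrative holder mirroring the FlashMLASparseImpl pattern."""

        def __init__(self, prefill_workspace_size: int, head_size: int) -> None:
            self.prefill_workspace_shape = (prefill_workspace_size, head_size)
            # First request reserves the bf16 scratch buffer in the shared
            # workspace so the memory is accounted for up front.
            current_workspace_manager().get(self.prefill_workspace_shape, torch.bfloat16)

        def prefill_workspace(self) -> torch.Tensor:
            # A later call with the same shape/dtype fetches the buffer used
            # while processing prefill chunks.
            return current_workspace_manager().get(
                self.prefill_workspace_shape, torch.bfloat16
            )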