
Commit b8ffa57

add eagle proposer ut
Signed-off-by: GDzhu01 <[email protected]>
1 parent 31a2c09 commit b8ffa57

1 file changed: 312 additions, 0 deletions
@@ -0,0 +1,312 @@
from unittest.mock import MagicMock, patch

import numpy as np
import torch
from vllm.config import CacheConfig, CompilationMode, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType

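# Construction: verifies EAGLE vs. EAGLE3 setup, the persistent buffer
# shapes derived from max_num_batched_tokens/max_num_seqs, and the
# use_cuda_graph flag under the two compilation modes.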
class TestEagleProposerInitialization(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.cache_config = MagicMock(spec=CacheConfig)
        self.vllm_config.scheduler_config = MagicMock()
        self.vllm_config.model_config = MagicMock()
        self.device = torch.device("cpu")
        self.runner = MagicMock()

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048

    def test_initialization_eagle(self):
        self.vllm_config.speculative_config.method = "eagle"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.name, SpecDcodeType.EAGLE)
        self.assertEqual(proposer.block_size, 16)
        self.assertEqual(proposer.hidden_size, 4096)
        self.assertTrue(proposer.use_cuda_graph)

        self.assertEqual(proposer.input_ids.shape, (1024, ))
        self.assertEqual(proposer.positions.shape, (1024, ))
        self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
        self.assertEqual(proposer.arange.shape, (33, ))

    def test_initialization_eagle3(self):
        self.vllm_config.speculative_config.method = "eagle3"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.NONE
        self.vllm_config.model_config.enforce_eager = True

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))


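# load_model(): checks draft attention-layer discovery and embedding/lm_head
# sharing for pipeline-parallel world sizes 1 and >1, plus the multimodal
# get_language_model() path.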
class TestEagleProposerLoadModel(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.method = "eagle"
        self.device = torch.device("cpu")
        self.runner = MagicMock()

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp1(self, mock_pp_group, mock_get_model,
                            mock_get_layers):
        mock_pp_group.return_value.world_size = 1
        mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()}
        mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()}
        mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]

        mock_model = MagicMock()
        mock_model.model.embed_tokens = MagicMock()
        mock_model.lm_head = MagicMock()
        mock_get_model.return_value = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(mock_model)
        mock_get_model.assert_called_once()
        self.assertEqual(self.proposer.attn_layer_name, "layer3")
        self.assertIs(self.proposer.model.model.embed_tokens,
                      mock_model.model.embed_tokens)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
                               mock_get_layers):
        mock_pp_group.return_value.world_size = 2
        mock_target_layers = {"layer1": MagicMock()}
        mock_draft_layers = {"layer2": MagicMock()}
        mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]

        mock_model = MagicMock()
        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        self.proposer.load_model(mock_model)

        self.assertIsNot(self.proposer.model.model.embed_tokens,
                         mock_model.model.embed_tokens)
        self.assertEqual(self.proposer.attn_layer_name, "layer2")

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal")
    def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
                                   mock_get_model, mock_get_layers):
        mock_model = MagicMock()
        mock_model.get_language_model.return_value.lm_head = MagicMock()
        mock_supports_multi.return_value = True
        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        mock_target_layers = {"layer1": MagicMock()}
        mock_draft_layers = {"layer2": MagicMock()}
        mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
        mock_pp_group.return_value.world_size = 2

        self.proposer.model = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(mock_model)
        mock_model.get_language_model.assert_called_once()
        self.assertIs(self.proposer.model.lm_head,
                      mock_model.get_language_model.return_value.lm_head)


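# dummy_run(): checks that warmup runs enter the Ascend forward context and
# pick a MoE communication method for the given token count.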
class TestEagleProposerDummyRun(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner._select_moe_comm_method.return_value = "alltoall"

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)
        self.proposer.model = MagicMock()

    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_basic(self, mock_context):
        num_tokens = 32
        with_prefill = False

        self.proposer.dummy_run(num_tokens=num_tokens,
                                with_prefill=with_prefill)

        mock_context.assert_called_once()

    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context):
        mock_context.return_value.__enter__.return_value = None
        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)

        self.runner._select_moe_comm_method.assert_called_with(64)
        self.proposer.model.assert_called_once()


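# generate_token_ids(): drives the proposer end to end with _propose (and,
# in the metadata case, _prepare_inputs) mocked out, with and without
# spec-decode metadata.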
class TestEagleProposerGenerateTokenIds(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.method = "eagle"
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        self.runner.requests = {
            0: MagicMock(get_token_id=lambda x: 100),
            1: MagicMock(get_token_id=lambda x: 101),
            2: MagicMock(get_token_id=lambda x: 102),
        }

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)
        self.proposer.attn_layer_name = "layer_0"
        self.proposer._propose = MagicMock(
            return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))

    def test_generate_token_ids_without_metadata(self):
        valid_sampled = [[10, 20], [30], []]
        scheduler_output = MagicMock()
        scheduler_output.num_scheduled_tokens = [2, 1, 3]
        positions = torch.tensor([0, 1, 2, 3, 4, 5])
        hidden_states = torch.randn(6, 4096)
        num_scheduled = 6

        mock_attn_metadata = MagicMock()
        mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6])
        mock_attn_metadata.block_tables = MagicMock()
        self.proposer._get_eagle_atten_dict = MagicMock(
            return_value={"layer_0": mock_attn_metadata})

        result = self.proposer.generate_token_ids(
            valid_sampled_token_ids=valid_sampled,
            scheduler_output=scheduler_output,
            positions=positions,
            num_scheduled_tokens=num_scheduled,
            hidden_states=hidden_states,
        )

        self.proposer._propose.assert_called_once()
        self.assertEqual(result, [[1, 2], [3, 4], [5, 6]])

    def test_generate_token_ids_with_metadata(self):
        valid_sampled = [[5], [6, 7], [8, 9, 10]]
        spec_metadata = MagicMock()
        spec_metadata.num_draft_tokens = [2, 3, 4]

        mock_attn_metadata = MagicMock()
        mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        mock_attn_metadata.query_start_loc = torch.tensor([0, 1, 3, 6])
        mock_attn_metadata.block_tables = MagicMock()
        self.proposer._get_eagle_atten_dict = MagicMock(
            return_value={"layer_0": mock_attn_metadata})
        self.proposer._prepare_inputs = MagicMock(
            return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5])))

        result = self.proposer.generate_token_ids(
            valid_sampled_token_ids=valid_sampled,
            spec_decode_metadata=spec_metadata,
            positions=torch.randn(6, 1),
            hidden_states=torch.randn(6, 4096),
        )

        self.proposer._prepare_inputs.assert_called_once()
        self.assertEqual(self.proposer._propose.call_count, 1)
        self.assertEqual(len(result), 3)


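# Helper methods: _prepare_inputs is patched here, so the assertions cover
# the mocked cumulative-token and index outputs rather than the real kernel.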
class TestEagleProposerHelperMethods(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3)
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        self.runner.arange_np = np.arange(10)
        self.runner.input_batch.num_reqs = 3

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    def test_prepare_inputs(self):
        self.proposer.token_arange_np = np.arange(10)
        mock_attn = MagicMock()
        mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        num_rejected = torch.tensor([1, 0, 1], device=self.device)

        with patch.object(self.proposer,
                          '_prepare_inputs',
                          return_value=(torch.tensor([0, 2, 5]),
                                        torch.tensor([1, 2, 4]))):
            cu_num_tokens, indices = self.proposer._prepare_inputs(
                mock_attn, num_rejected)
            self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5])
            self.assertEqual(indices.tolist(), [1, 2, 4])
