|
from unittest.mock import MagicMock, patch

import numpy as np
import torch
from vllm.config import CacheConfig, CompilationMode, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
| 10 | + |
| 11 | + |
class TestEagleProposerInitialization(TestBase):
    """Constructor behaviour of EagleProposer for the eagle / eagle3 methods."""

    def setUp(self):
        # Mocked VllmConfig carrying only the fields the proposer reads.
        cfg = MagicMock(spec=VllmConfig)
        cfg.speculative_config = MagicMock()
        cfg.cache_config = MagicMock(spec=CacheConfig)
        cfg.scheduler_config = MagicMock()
        cfg.model_config = MagicMock()

        cfg.cache_config.block_size = 16
        cfg.scheduler_config.max_num_batched_tokens = 1024
        cfg.scheduler_config.max_num_seqs = 32
        cfg.model_config.dtype = torch.float16
        cfg.model_config.max_model_len = 2048

        self.vllm_config = cfg
        self.device = torch.device("cpu")
        self.runner = MagicMock()

    def _build_proposer(self):
        # Construct the proposer from the shared mocked config.
        return EagleProposer(vllm_config=self.vllm_config,
                             device=self.device,
                             runner=self.runner)

    def test_initialization_eagle(self):
        spec_cfg = self.vllm_config.speculative_config
        spec_cfg.method = "eagle"
        spec_cfg.draft_model_config.get_hidden_size.return_value = 4096
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False

        proposer = self._build_proposer()

        self.assertEqual(proposer.name, SpecDcodeType.EAGLE)
        self.assertEqual(proposer.block_size, 16)
        self.assertEqual(proposer.hidden_size, 4096)
        self.assertTrue(proposer.use_cuda_graph)

        # Buffer sizes track max_num_batched_tokens (1024) and
        # max_num_seqs + 1 (33).
        self.assertEqual(proposer.input_ids.shape, (1024, ))
        self.assertEqual(proposer.positions.shape, (1024, ))
        self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
        self.assertEqual(proposer.arange.shape, (33, ))

    def test_initialization_eagle3(self):
        spec_cfg = self.vllm_config.speculative_config
        spec_cfg.method = "eagle3"
        spec_cfg.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.NONE
        self.vllm_config.model_config.enforce_eager = True

        proposer = self._build_proposer()

        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
| 63 | + |
| 64 | + |
class TestEagleProposerLoadModel(TestBase):
    """Tests for EagleProposer.load_model under several PP / model setups."""

    def setUp(self):
        cfg = MagicMock(spec=VllmConfig)
        cfg.speculative_config = MagicMock()
        cfg.speculative_config.method = "eagle"
        cfg.cache_config.block_size = 16
        cfg.scheduler_config.max_num_batched_tokens = 1024
        cfg.scheduler_config.max_num_seqs = 32
        cfg.model_config.dtype = torch.float16
        cfg.model_config.max_model_len = 2048

        self.vllm_config = cfg
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.proposer = EagleProposer(vllm_config=cfg,
                                      device=self.device,
                                      runner=self.runner)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp1(self, mock_pp_group, mock_get_model,
                            mock_get_layers):
        # Single pipeline stage: the draft model reuses the target embeddings.
        mock_pp_group.return_value.world_size = 1
        # get_layers_from_vllm_config is queried for target then draft layers.
        mock_get_layers.side_effect = [
            {"layer1": MagicMock(), "layer2": MagicMock()},
            {"layer1": MagicMock(), "layer3": MagicMock()},
        ]
        mock_get_model.return_value = MagicMock()

        target_model = MagicMock()
        target_model.model.embed_tokens = MagicMock()
        target_model.lm_head = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(target_model)

        mock_get_model.assert_called_once()
        # "layer3" is the name that appears only in the draft layer set.
        self.assertEqual(self.proposer.attn_layer_name, "layer3")
        self.assertIs(self.proposer.model.model.embed_tokens,
                      target_model.model.embed_tokens)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
                               mock_get_layers):
        # With more than one pipeline stage the draft keeps its own embeddings.
        mock_pp_group.return_value.world_size = 2
        mock_get_layers.side_effect = [
            {"layer1": MagicMock()},
            {"layer2": MagicMock()},
        ]
        draft_embed = MagicMock()
        mock_get_model.return_value = MagicMock(
            model=MagicMock(embed_tokens=draft_embed))
        target_model = MagicMock()

        self.proposer.load_model(target_model)

        self.assertIsNot(self.proposer.model.model.embed_tokens,
                         target_model.model.embed_tokens)
        self.assertEqual(self.proposer.attn_layer_name, "layer2")

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal")
    def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
                                   mock_get_model, mock_get_layers):
        # Multimodal target: lm_head must come from the language sub-model.
        mock_supports_multi.return_value = True
        mock_pp_group.return_value.world_size = 2
        mock_get_layers.side_effect = [
            {"layer1": MagicMock()},
            {"layer2": MagicMock()},
        ]
        draft_embed = MagicMock()
        mock_get_model.return_value = MagicMock(
            model=MagicMock(embed_tokens=draft_embed))

        target_model = MagicMock()
        target_model.get_language_model.return_value.lm_head = MagicMock()
        self.proposer.model = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(target_model)

        target_model.get_language_model.assert_called_once()
        self.assertIs(self.proposer.model.lm_head,
                      target_model.get_language_model.return_value.lm_head)
| 155 | + |
| 156 | + |
class TestEagleProposerDummyRun(TestBase):
    """Tests for EagleProposer.dummy_run."""

    def setUp(self):
        cfg = MagicMock(spec=VllmConfig)
        cfg.speculative_config = MagicMock()
        cfg.cache_config.block_size = 16
        cfg.scheduler_config.max_num_batched_tokens = 1024
        cfg.scheduler_config.max_num_seqs = 32
        cfg.model_config.dtype = torch.float16
        cfg.model_config.max_model_len = 2048
        self.vllm_config = cfg

        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner._select_moe_comm_method.return_value = "alltoall"

        self.proposer = EagleProposer(vllm_config=cfg,
                                      device=self.device,
                                      runner=self.runner)
        # Draft model replaced by a mock so dummy_run can "call" it.
        self.proposer.model = MagicMock()

    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_basic(self, mock_context):
        # A decode-only dummy run enters the ascend forward context once.
        self.proposer.dummy_run(num_tokens=32, with_prefill=False)

        mock_context.assert_called_once()

    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context):
        mock_context.return_value.__enter__.return_value = None

        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)

        # MoE comm method is selected from the token count, and the draft
        # model itself is invoked exactly once.
        self.runner._select_moe_comm_method.assert_called_with(64)
        self.proposer.model.assert_called_once()
| 194 | + |
| 195 | + |
class TestEagleProposerGenerateTokenIds(TestBase):
    """Tests for EagleProposer.generate_token_ids with and without
    spec-decode metadata."""

    def setUp(self):
        cfg = MagicMock(spec=VllmConfig)
        cfg.speculative_config = MagicMock()
        cfg.speculative_config.method = "eagle"
        cfg.cache_config.block_size = 16
        cfg.scheduler_config.max_num_batched_tokens = 1024
        cfg.scheduler_config.max_num_seqs = 32
        cfg.model_config.dtype = torch.float16
        cfg.model_config.max_model_len = 2048
        self.vllm_config = cfg

        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        # Requests 0/1/2 report fixed token ids 100/101/102.
        self.runner.requests = {
            req_id: MagicMock(get_token_id=lambda x, t=100 + req_id: t)
            for req_id in (0, 1, 2)
        }

        self.proposer = EagleProposer(vllm_config=cfg,
                                      device=self.device,
                                      runner=self.runner)
        self.proposer.attn_layer_name = "layer_0"
        self.proposer._propose = MagicMock(
            return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))

    def _stub_attn_metadata(self, query_start_loc):
        # Install a minimal attention-metadata mock keyed by the attn layer.
        attn = MagicMock()
        attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        attn.query_start_loc = torch.tensor(query_start_loc)
        attn.block_tables = MagicMock()
        self.proposer._get_eagle_atten_dict = MagicMock(
            return_value={"layer_0": attn})

    def test_generate_token_ids_without_metadata(self):
        self._stub_attn_metadata([0, 2, 3, 6])
        scheduler_output = MagicMock()
        scheduler_output.num_scheduled_tokens = [2, 1, 3]

        result = self.proposer.generate_token_ids(
            valid_sampled_token_ids=[[10, 20], [30], []],
            scheduler_output=scheduler_output,
            positions=torch.tensor([0, 1, 2, 3, 4, 5]),
            num_scheduled_tokens=6,
            hidden_states=torch.randn(6, 4096),
        )

        self.proposer._propose.assert_called_once()
        self.assertEqual(result, [[1, 2], [3, 4], [5, 6]])

    def test_generate_token_ids_with_metadata(self):
        self._stub_attn_metadata([0, 1, 3, 6])
        spec_metadata = MagicMock()
        spec_metadata.num_draft_tokens = [2, 3, 4]
        self.proposer._prepare_inputs = MagicMock(
            return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5])))

        result = self.proposer.generate_token_ids(
            valid_sampled_token_ids=[[5], [6, 7], [8, 9, 10]],
            spec_decode_metadata=spec_metadata,
            positions=torch.randn(6, 1),
            hidden_states=torch.randn(6, 4096),
        )

        self.proposer._prepare_inputs.assert_called_once()
        self.assertEqual(self.proposer._propose.call_count, 1)
        self.assertEqual(len(result), 3)
| 275 | + |
| 276 | + |
class TestEagleProposerHelperMethods(TestBase):
    """Tests for EagleProposer helper methods."""

    def setUp(self):
        cfg = MagicMock(spec=VllmConfig)
        cfg.scheduler_config = MagicMock(max_num_seqs=3)
        cfg.cache_config.block_size = 16
        cfg.scheduler_config.max_num_batched_tokens = 1024
        # Overrides the max_num_seqs=3 set on the scheduler mock above.
        cfg.scheduler_config.max_num_seqs = 32
        cfg.model_config.dtype = torch.float16
        cfg.model_config.max_model_len = 2048
        self.vllm_config = cfg

        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        self.runner.input_batch.num_reqs = 3
        self.runner.arange_np = np.arange(10)

        self.proposer = EagleProposer(vllm_config=cfg,
                                      device=self.device,
                                      runner=self.runner)

    def test_prepare_inputs(self):
        # NOTE(review): _prepare_inputs is patched out below, so the
        # assertions only verify the stub's return value rather than the
        # real implementation — consider exercising the real method.
        self.proposer.token_arange_np = np.arange(10)
        attn = MagicMock()
        attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        num_rejected = torch.tensor([1, 0, 1], device=self.device)

        stub_result = (torch.tensor([0, 2, 5]), torch.tensor([1, 2, 4]))
        with patch.object(self.proposer,
                          '_prepare_inputs',
                          return_value=stub_result):
            cu_num_tokens, token_indices = self.proposer._prepare_inputs(
                attn, num_rejected)
            self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5])
            self.assertEqual(token_indices.tolist(), [1, 2, 4])