
Commit fe62017

add ut
1 parent ab7d5c4 commit fe62017

1 file changed: +43 -1 lines changed

tests/pytorch/paging/test_block_trie.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
from lmdeploy.pytorch.config import CacheConfig
5-
from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta
5+
from lmdeploy.pytorch.messages import SamplingParam, SchedulerSession, SequenceManager, SequenceMeta
66
from lmdeploy.pytorch.paging.block_manager import build_block_manager
77
from lmdeploy.pytorch.paging.block_trie import BlockTrie
88

@@ -37,13 +37,55 @@ def block_mgr(self, cache_config):
3737
def block_trie(self, cache_config, block_mgr):
3838
yield BlockTrie(cache_config, block_mgr)
3939

40+
@pytest.fixture
41+
def num_moe_layers(self):
42+
yield 4
43+
44+
@pytest.fixture
45+
def experts_topk(self):
46+
yield 4
47+
4048
@pytest.fixture
4149
def seq_manager(self, block_size):
4250
from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
4351
strategy = ARSequenceStrategy()
4452
seq_meta = SequenceMeta(block_size, strategy=strategy)
4553
yield SequenceManager(seq_meta)
4654

55+
def test_with_routed_experts(self, block_trie, block_mgr, seq_manager, num_moe_layers, experts_topk):
56+
57+
def _get_routed_experts(size, value):
58+
return np.full((size, num_moe_layers, experts_topk), value, dtype=np.int32)
59+
60+
sess = SchedulerSession(0, seq_manager)
61+
block_size = sess.seq_meta.block_size
62+
token_ids = ([1] * block_size + [2] * block_size)
63+
all_routed_experts = [_get_routed_experts(block_size, 1), _get_routed_experts(block_size, 2)]
64+
token_ids += [3] * (block_size // 2)
65+
all_routed_experts += [_get_routed_experts(block_size // 2, 3)]
66+
seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
67+
all_routed_experts += [_get_routed_experts(block_size - 1, 4)]
68+
routed_experts = np.concatenate(all_routed_experts, axis=0)
69+
seq.update_token_ids([4] * block_size, routed_experts=routed_experts)
70+
71+
# test allocate
72+
block_mgr.allocate(seq)
73+
block_trie.allocate(seq)
74+
node = getattr(seq.logical_blocks, 'last_shared_node', None)
75+
assert node is not None
76+
assert node.routed_experts is not None
77+
target_routed_experts = np.concatenate(
78+
[_get_routed_experts(block_size // 2, 3),
79+
_get_routed_experts(block_size // 2, 4)], axis=0)
80+
assert np.array_equal(node.routed_experts, target_routed_experts)
81+
82+
# test match
83+
seq_query = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
84+
block_trie.match(seq_query)
85+
assert seq_query.all_routed_experts is not None
86+
assert len(seq_query.all_routed_experts) == block_size * 2
87+
assert np.array_equal(seq_query.all_routed_experts.get_real(), np.concatenate(all_routed_experts[:2], axis=0))
88+
4789
def test_allocate(self, block_trie, block_mgr, seq_manager):
4890
allocator = block_trie.allocator
4991
sess = SchedulerSession(0, seq_manager)
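Note on the expected values (not part of the commit): a minimal sketch of the block accounting behind target_routed_experts, assuming a hypothetical block_size fixture value of 16 (the real fixture is defined outside this hunk). After update_token_ids the sequence holds 3.5 blocks of tokens, so the last full, cacheable block mixes the trailing 3s with the leading 4s:

    import numpy as np

    # Assumed values: block_size comes from a fixture not shown in this hunk.
    block_size, num_moe_layers, experts_topk = 16, 4, 4

    def experts(size, value):
        # one (num_moe_layers, experts_topk) record per token, mirroring _get_routed_experts
        return np.full((size, num_moe_layers, experts_topk), value, dtype=np.int32)

    # prompt = 2.5 blocks (1s, 2s, half a block of 3s); the update appends one block of 4s
    num_tokens = 2 * block_size + block_size // 2 + block_size
    assert num_tokens == 3 * block_size + block_size // 2  # three full blocks plus a half

    # the third full block spans tokens [2 * block_size, 3 * block_size):
    # half a block of 3s followed by half a block of 4s, exactly the
    # array the test asserts for node.routed_experts
    third_block = np.concatenate([experts(block_size // 2, 3), experts(block_size // 2, 4)], axis=0)
    assert third_block.shape == (block_size, num_moe_layers, experts_topk)

The match assertions follow from the same block granularity: the query prompt is 2.5 blocks long, so at most its two full blocks can hit in the trie, giving block_size * 2 matched tokens whose experts equal np.concatenate(all_routed_experts[:2], axis=0).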
