import sys, os

# Make the repository root importable when this example is run directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from functools import lru_cache
from lmformatenforcer.integrations.exllamav2 import build_token_enforcer_tokenizer_data
from lmformatenforcer import TokenEnforcer, CharacterLevelParser
from typing import List


# Temporary wrapper for lm-format-enforcer, until the integration in LMFE itself is updated


# Building the tokenizer data is relatively expensive, so cache it per tokenizer instance
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    return build_token_enforcer_tokenizer_data(tokenizer)


class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):

    # Token IDs generated so far, fed back to the TokenEnforcer on each step
    token_sequence: List[int]

    def __init__(
        self,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
        character_level_parser: CharacterLevelParser,
    ):
        super().__init__(model, tokenizer)
        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
        self.token_sequence = []

    def begin(self, prefix_str: str) -> None:
        # Reset the tracked sequence at the start of a new generation
        self.token_sequence = []

    def feed(self, token) -> None:
        # token arrives as a tensor (e.g. shape (1, 1)); record the sampled token ID
        self.token_sequence.append(int(token[0][0]))

    def next(self):
        # Ask the enforcer which token IDs may legally follow the sequence so far.
        # The second return value (end tokens) is left empty; LMFE signals completion
        # by allowing the EOS token instead.
        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
        return sorted(allowed_tokens), []

    def use_background_worker(self):
        # Let the generator evaluate this filter on a background worker thread,
        # overlapping the constraint computation with inference
        return True
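

# Below is a usage sketch, not part of the original wrapper: it shows how this
# filter might be attached to exllamav2's dynamic generator to constrain output
# to a JSON schema via lm-format-enforcer's JsonSchemaParser. The model directory
# is a placeholder and the loading boilerplate assumes the standard exllamav2
# example setup; adjust both to your environment.

if __name__ == "__main__":

    from exllamav2 import ExLlamaV2Config, ExLlamaV2Cache
    from exllamav2.generator import ExLlamaV2DynamicGenerator
    from lmformatenforcer import JsonSchemaParser

    config = ExLlamaV2Config("/path/to/model")  # placeholder: local model directory
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2DynamicGenerator(model=model, cache=cache, tokenizer=tokenizer)

    # Only token sequences that can still form JSON matching this schema are allowed
    schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
    json_filter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, JsonSchemaParser(schema))

    output = generator.generate(
        prompt="Answer in JSON. Who are you?\n\n",
        max_new_tokens=200,
        filters=[json_filter],
    )
    print(output)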