diff --git a/maskbench/maskbench/outlines_engine.py b/maskbench/maskbench/outlines_engine.py
index beb2ea37..e3d9f198 100644
--- a/maskbench/maskbench/outlines_engine.py
+++ b/maskbench/maskbench/outlines_engine.py
@@ -2,6 +2,7 @@
 import json
 from outlines_core import Vocabulary, Guide, Index
 from outlines_core.json_schema import build_regex_from_schema
+from outlines_core.kernels.torch import allocate_token_bitmask, fill_next_token_bitmask
 
 
 class OutlinesEngine(Engine):
@@ -19,6 +20,7 @@ def get_name(self):
 
     def init(self):
         self.vocabulary = Vocabulary.from_pretrained(self.tokenizer.name_or_path)
+        self.mask = allocate_token_bitmask(len(self.vocabulary))
 
     def compile_grammar(self, schema: dict):
         rx = build_regex_from_schema(json.dumps(schema))
@@ -29,10 +31,10 @@ def reset(self):
         self.guide = Guide(self.index)
 
     def compute_mask(self):
-        self.res = self.guide.get_tokens()
+        fill_next_token_bitmask(self.guide, self.mask)
 
     def commit_token(self, t: int) -> bool:
-        ok = t in self.res
+        ok = True
         if ok:
-            self.guide.advance(t)
+            self.guide.advance(t, False)
         return ok