@@ -18,6 +18,7 @@ package llmdinferencesim
1818
1919import (
2020 "context"
21+ "encoding/json"
2122 "errors"
2223 "fmt"
2324 "io"
@@ -29,6 +30,8 @@ import (
2930
3031 "github.com/llm-d/llm-d-inference-sim/pkg/common"
3132 kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
33+ vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
34+ "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
3235 . "github.com/onsi/ginkgo/v2"
3336 . "github.com/onsi/gomega"
3437 "github.com/openai/openai-go"
@@ -39,6 +42,7 @@ import (
3942)
4043
4144const model = "my_model"
45+ const qwenModelName = "Qwen/Qwen2-0.5B"
4246const baseURL = "http://localhost/v1"
4347const userMessage = "This is a test."
4448const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
@@ -97,8 +101,17 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
97101 return nil , err
98102 }
99103
104+ tokenizationConfig := tokenization .DefaultConfig ()
105+ if s .config .TokenizersCacheDir != "" {
106+ tokenizationConfig .TokenizersCacheDir = s .config .TokenizersCacheDir
107+ }
108+ s .tokenizer , err = tokenization .NewCachedHFTokenizer (tokenizationConfig .HFTokenizerConfig )
109+ if err != nil {
110+ return nil , fmt .Errorf ("failed to create tokenizer: %w" , err )
111+ }
112+
100113 if s .config .EnableKVCache {
101- s .kvcacheHelper , err = kvcache .NewKVCacheHelper (s .config , s .logger , s .kvCacheUsageChan )
114+ s .kvcacheHelper , err = kvcache .NewKVCacheHelper (s .config , s .logger , s .kvCacheUsageChan , s . tokenizer )
102115 if err != nil {
103116 return nil , err
104117 }
@@ -1065,7 +1078,71 @@ var _ = Describe("Simulator", func() {
10651078 Expect (factor ).To (BeNumerically (">" , 1.0 ))
10661079 Expect (factor ).To (BeNumerically ("<" , simulator .config .TimeFactorUnderLoad ))
10671080 })
1068-
10691081 })
10701082
1083+ Context ("tokenize" , Ordered , func () {
1084+ tmpDir := "./tests-tmp/"
1085+ AfterAll (func () {
1086+ err := os .RemoveAll (tmpDir )
1087+ Expect (err ).NotTo (HaveOccurred ())
1088+ })
1089+
1090+ It ("Should return correct response to /tokenize chat" , func () {
1091+ ctx := context .TODO ()
1092+ args := []string {"cmd" , "--model" , qwenModelName , "--mode" , common .ModeRandom ,
1093+ "--tokenizers-cache-dir" , tmpDir , "--max-model-len" , "2048" }
1094+ client , err := startServerWithArgs (ctx , common .ModeRandom , args , nil )
1095+ Expect (err ).NotTo (HaveOccurred ())
1096+
1097+ reqBody := `{
1098+ "messages": [{"role": "user", "content": "This is a test"}],
1099+ "model": "Qwen/Qwen2-0.5B"
1100+ }`
1101+ resp , err := client .Post ("http://localhost/tokenize" , "application/json" , strings .NewReader (reqBody ))
1102+ Expect (err ).NotTo (HaveOccurred ())
1103+ defer func () {
1104+ err := resp .Body .Close ()
1105+ Expect (err ).NotTo (HaveOccurred ())
1106+ }()
1107+
1108+ body , err := io .ReadAll (resp .Body )
1109+ Expect (err ).NotTo (HaveOccurred ())
1110+
1111+ var tokenizeResp vllmapi.TokenizeResponse
1112+ err = json .Unmarshal (body , & tokenizeResp )
1113+ Expect (err ).NotTo (HaveOccurred ())
1114+ Expect (tokenizeResp .Count ).To (Equal (4 ))
1115+ Expect (tokenizeResp .Tokens ).To (HaveLen (4 ))
1116+ Expect (tokenizeResp .MaxModelLen ).To (Equal (2048 ))
1117+ })
1118+
1119+ It ("Should return correct response to /tokenize text" , func () {
1120+ ctx := context .TODO ()
1121+ args := []string {"cmd" , "--model" , qwenModelName , "--mode" , common .ModeRandom ,
1122+ "--tokenizers-cache-dir" , tmpDir , "--max-model-len" , "2048" }
1123+ client , err := startServerWithArgs (ctx , common .ModeRandom , args , nil )
1124+ Expect (err ).NotTo (HaveOccurred ())
1125+
1126+ reqBody := `{
1127+ "prompt": "This is a test",
1128+ "model": "Qwen/Qwen2-0.5B"
1129+ }`
1130+ resp , err := client .Post ("http://localhost/tokenize" , "application/json" , strings .NewReader (reqBody ))
1131+ Expect (err ).NotTo (HaveOccurred ())
1132+ defer func () {
1133+ err := resp .Body .Close ()
1134+ Expect (err ).NotTo (HaveOccurred ())
1135+ }()
1136+
1137+ body , err := io .ReadAll (resp .Body )
1138+ Expect (err ).NotTo (HaveOccurred ())
1139+
1140+ var tokenizeResp vllmapi.TokenizeResponse
1141+ err = json .Unmarshal (body , & tokenizeResp )
1142+ Expect (err ).NotTo (HaveOccurred ())
1143+ Expect (tokenizeResp .Count ).To (Equal (4 ))
1144+ Expect (tokenizeResp .Tokens ).To (HaveLen (4 ))
1145+ Expect (tokenizeResp .MaxModelLen ).To (Equal (2048 ))
1146+ })
1147+ })
10711148})
0 commit comments