
Commit 9c541b9

Split simulator.go into several files (#199)

Authored by Ira <[email protected]>

* Split simulator.go into several files
* Tests reorganization
* Lint
* Lint
* Test helper function and function renaming
* Lint

Signed-off-by: Ira <[email protected]>

1 parent f5d40a2, commit 9c541b9

File tree

13 files changed (+1052, -1181 lines)


pkg/llm-d-inference-sim/failures_test.go

Lines changed: 12 additions & 72 deletions
@@ -26,7 +26,6 @@ import (
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"github.com/openai/openai-go"
-	"github.com/openai/openai-go/option"

 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
@@ -135,18 +134,8 @@ var _ = Describe("Failures", func() {
 		})

 		It("should always return an error response for chat completions", func() {
-			openaiClient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
-
-			_, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-				Model: model,
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-			})
-
+			openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			_, err := openaiClient.Chat.Completions.New(ctx, params)
 			Expect(err).To(HaveOccurred())

 			var openaiError *openai.Error
@@ -158,18 +147,8 @@ var _ = Describe("Failures", func() {
 		})

 		It("should always return an error response for text completions", func() {
-			openaiClient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
-
-			_, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{
-				Model: openai.CompletionNewParamsModel(model),
-				Prompt: openai.CompletionNewParamsPromptUnion{
-					OfString: openai.String(userMessage),
-				},
-			})
-
+			openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			_, err := openaiClient.Chat.Completions.New(ctx, params)
 			Expect(err).To(HaveOccurred())

 			var openaiError *openai.Error
@@ -194,18 +173,8 @@ var _ = Describe("Failures", func() {
 		})

 		It("should return only rate limit errors", func() {
-			openaiClient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
-
-			_, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-				Model: model,
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-			})
-
+			openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			_, err := openaiClient.Chat.Completions.New(ctx, params)
 			Expect(err).To(HaveOccurred())

 			var openaiError *openai.Error
@@ -230,20 +199,11 @@ var _ = Describe("Failures", func() {
 		})

 		It("should return only specified error types", func() {
-			openaiClient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
+			openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)

 			// Make multiple requests to verify we get the expected error types
 			for i := 0; i < 10; i++ {
-				_, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-					Model: model,
-					Messages: []openai.ChatCompletionMessageParamUnion{
-						openai.UserMessage(userMessage),
-					},
-				})
-
+				_, err := openaiClient.Chat.Completions.New(ctx, params)
 				Expect(err).To(HaveOccurred())

 				var openaiError *openai.Error
@@ -270,18 +230,8 @@ var _ = Describe("Failures", func() {
 		})

 		It("should never return errors and behave like random mode", func() {
-			openaiClient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
-
-			resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-				Model: model,
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-			})
-
+			openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			resp, err := openaiClient.Chat.Completions.New(ctx, params)
 			Expect(err).ToNot(HaveOccurred())
 			Expect(resp.Choices).To(HaveLen(1))
 			Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
@@ -300,18 +250,8 @@ var _ = Describe("Failures", func() {
 			}, nil)
 			Expect(err).ToNot(HaveOccurred())

-			openaiClient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
-
-			_, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-				Model: model,
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-			})
-
+			openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			_, err = openaiClient.Chat.Completions.New(ctx, params)
 			Expect(err).To(HaveOccurred())

 			var openaiError *openai.Error
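
Note that the diff above only shows the call sites; the new shared helper getOpenAIClentAndChatParams is defined elsewhere in the test package (not shown in this extract). A rough, hypothetical sketch of what such a helper can look like, assuming it reuses the package-level baseURL variable and the existing openai/option imports, and guessing that the final boolean selects streaming:

```go
// Hypothetical sketch only (not the commit's actual implementation): a shared
// test helper that builds an OpenAI client pointed at the simulator plus the
// chat-completion params used by the failure tests. The signature and the
// meaning of the streaming flag are assumptions based on the call sites above.
//
// Assumed imports in the test file:
//
//	import (
//		"net/http"
//
//		"github.com/openai/openai-go"
//		"github.com/openai/openai-go/option"
//	)
func getOpenAIClentAndChatParams(client *http.Client, model, userMessage string,
	streaming bool) (openai.Client, openai.ChatCompletionNewParams) {
	openaiClient := openai.NewClient(
		option.WithBaseURL(baseURL),
		option.WithHTTPClient(client),
	)
	params := openai.ChatCompletionNewParams{
		Model: model,
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(userMessage),
		},
	}
	// The streaming flag is ignored in this sketch; the failure tests above
	// always pass false, and its real behavior is not visible in this diff.
	_ = streaming
	return openaiClient, params
}
```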

pkg/llm-d-inference-sim/helpers.go

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package vllmsim implements the vLLM simulator.
package llmdinferencesim

import (
	"encoding/json"
	"fmt"
)

// isValidModel checks if the given model is the base model or one of "loaded" LoRAs
func (s *VllmSimulator) isValidModel(model string) bool {
	for _, name := range s.config.ServedModelNames {
		if model == name {
			return true
		}
	}
	for _, lora := range s.getLoras() {
		if model == lora {
			return true
		}
	}

	return false
}

// isLora returns true if the given model name is one of loaded LoRAs
func (s *VllmSimulator) isLora(model string) bool {
	for _, lora := range s.getLoras() {
		if model == lora {
			return true
		}
	}

	return false
}

// getDisplayedModelName returns the model name that must appear in API
// responses. LoRA adapters keep their explicit name, while all base-model
// requests are surfaced as the first alias from --served-model-name.
func (s *VllmSimulator) getDisplayedModelName(reqModel string) string {
	if s.isLora(reqModel) {
		return reqModel
	}
	return s.config.ServedModelNames[0]
}

func (s *VllmSimulator) showConfig(dp bool) error {
	cfgJSON, err := json.Marshal(s.config)
	if err != nil {
		return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
	}

	var m map[string]interface{}
	err = json.Unmarshal(cfgJSON, &m)
	if err != nil {
		return fmt.Errorf("failed to unmarshal JSON to map: %w", err)
	}
	if dp {
		// remove the port
		delete(m, "port")
	}
	// clean LoraModulesString field
	m["lora-modules"] = m["LoraModules"]
	delete(m, "LoraModules")
	delete(m, "LoraModulesString")

	// clean fake-metrics field
	if field, ok := m["fake-metrics"].(map[string]interface{}); ok {
		delete(field, "LorasString")
	}

	// show in JSON
	cfgJSON, err = json.MarshalIndent(m, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
	}
	s.logger.Info("Configuration:", "", string(cfgJSON))
	return nil
}
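
To make the aliasing rule in getDisplayedModelName concrete, here is a small standalone illustration. It mirrors the logic with plain slices instead of the simulator's config and LoRA registry, and the model and adapter names are invented for the example:

```go
package main

import "fmt"

// Standalone mirror of the display-name rule in helpers.go, for illustration
// only; the real method reads the simulator's config and loaded LoRAs.
func displayedModelName(reqModel string, servedModelNames, loras []string) string {
	for _, lora := range loras {
		if reqModel == lora {
			// LoRA adapters keep their explicit name.
			return reqModel
		}
	}
	// Base-model requests are reported under the first served-model-name alias.
	return servedModelNames[0]
}

func main() {
	served := []string{"llama-8b"} // e.g. --served-model-name llama-8b (invented)
	loras := []string{"my-lora"}   // hypothetical loaded adapter

	fmt.Println(displayedModelName("my-lora", served, loras))                 // my-lora
	fmt.Println(displayedModelName("meta-llama/Llama-3.1-8B", served, loras)) // llama-8b
}
```

In the simulator itself the same decision is driven by the --served-model-name aliases and the currently loaded LoRA adapters.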
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package vllmsim implements the vLLM simulator.
package llmdinferencesim

import "github.com/llm-d/llm-d-inference-sim/pkg/common"

func (s *VllmSimulator) getCurrLoadFactor() float64 {
	if s.config.MaxNumSeqs <= 1 {
		return 1.0
	}
	return 1 + (s.config.TimeFactorUnderLoad-1)*float64(s.nRunningReqs-1)/float64(s.config.MaxNumSeqs-1)
}

func (s *VllmSimulator) getTimeToFirstToken() int {
	return int(float64(s.config.TimeToFirstToken) * s.getCurrLoadFactor())
}

func (s *VllmSimulator) getPrefillOverhead() int {
	return int(float64(s.config.PrefillOverhead) * s.getCurrLoadFactor())
}

func (s *VllmSimulator) getPrefillTimePerToken() int {
	return int(float64(s.config.PrefillTimePerToken) * s.getCurrLoadFactor())
}

// returns time to first token based on the current request's doRemotePrefill
func (s *VllmSimulator) getWaitTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int {
	if doRemotePrefill {
		if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
			// is disaggregated PD and ttft is calculated using number of prompt tokens
			kvCacheTransT := s.config.KVCacheTransferTimePerToken * nPromptTokens
			return common.RandomNorm(kvCacheTransT, s.config.KVCacheTransferTimeStdDev)
		}
		// is disaggregated PD and *not* using number of prompt tokens
		return common.RandomNorm(s.config.KVCacheTransferLatency, s.config.KVCacheTransferLatencyStdDev)
	}
	if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
		// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
		prefillTime := s.getPrefillOverhead() + (nPromptTokens-nCachedPromptTokens)*s.getPrefillTimePerToken()
		return common.RandomNorm(prefillTime, s.config.PrefillTimeStdDev)
	}
	// is aggregated PD and *not* using number of prompt tokens
	return common.RandomNorm(s.getTimeToFirstToken(), s.config.TimeToFirstTokenStdDev)
}

// returns inter token latency
func (s *VllmSimulator) getInterTokenLatency() int {
	latency := int(float64(s.config.InterTokenLatency) * s.getCurrLoadFactor())
	return common.RandomNorm(latency, s.config.InterTokenLatencyStdDev)
}
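
The load-factor formula in getCurrLoadFactor scales the configured latencies linearly, from 1.0 when a single request is running up to TimeFactorUnderLoad when MaxNumSeqs requests are running. A standalone sketch with invented numbers (base time-to-first-token 1000 ms, time-factor-under-load 3.0, max-num-seqs 5) shows the effect before the normally distributed noise from common.RandomNorm is applied:

```go
package main

import "fmt"

// Standalone mirror of getCurrLoadFactor, for illustration; the real method
// reads TimeFactorUnderLoad, MaxNumSeqs and nRunningReqs from the simulator.
func loadFactor(timeFactorUnderLoad float64, nRunningReqs, maxNumSeqs int) float64 {
	if maxNumSeqs <= 1 {
		return 1.0
	}
	return 1 + (timeFactorUnderLoad-1)*float64(nRunningReqs-1)/float64(maxNumSeqs-1)
}

func main() {
	// Invented configuration: base TTFT 1000 ms, time factor under load 3.0,
	// max-num-seqs 5.
	baseTTFT := 1000.0

	for _, running := range []int{1, 3, 5} {
		f := loadFactor(3.0, running, 5)
		fmt.Printf("running=%d  loadFactor=%.2f  ttft=%.0f ms\n", running, f, baseTTFT*f)
	}
	// Output:
	// running=1  loadFactor=1.00  ttft=1000 ms
	// running=3  loadFactor=2.00  ttft=2000 ms
	// running=5  loadFactor=3.00  ttft=3000 ms
}
```

getWaitTimeToFirstToken then picks which base value to scale: per-prompt-token KV-cache transfer time or a fixed transfer latency for disaggregated prefill/decode, and prefill overhead plus per-uncached-token prefill time or a fixed time-to-first-token otherwise.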
