diff --git a/pkg/common/logprobs_test.go b/pkg/common/logprobs_test.go new file mode 100644 index 0000000..443afeb --- /dev/null +++ b/pkg/common/logprobs_test.go @@ -0,0 +1,253 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Logprobs", func() { + + Context("GenerateTextLogprobs", func() { + It("should generate correct text logprobs structure", func() { + tokens := []string{" Paris", ",", " the", " capital"} + logprobsCount := 2 + + logprobs := GenerateTextLogprobs(tokens, logprobsCount) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Tokens).To(HaveLen(len(tokens))) + Expect(logprobs.TokenLogprobs).To(HaveLen(len(tokens))) + Expect(logprobs.TopLogprobs).To(HaveLen(len(tokens))) + Expect(logprobs.TextOffset).To(HaveLen(len(tokens))) + + // Check that each top logprobs entry has the expected number of alternatives + for i, topLogprob := range logprobs.TopLogprobs { + Expect(topLogprob).To(HaveLen(logprobsCount)) + // Check that the main token is included in the alternatives + Expect(topLogprob).To(HaveKey(tokens[i])) + } + + // Check text offsets are calculated correctly (byte-based) + expectedOffsets := []int{0, 6, 7, 11} // " Paris" - 6, "," - 1, " the" -4, " capital" - 11 + for i, expected := range expectedOffsets { + Expect(logprobs.TextOffset[i]).To(Equal(expected)) + } + + // Check deterministic logprobs + expectedLogprob0 := -1.0 // defaultLogprob - float64(0%3)*0.1 + Expect(logprobs.TokenLogprobs[0]).To(Equal(expectedLogprob0)) + }) + }) + + Context("GenerateChatLogprobs", func() { + It("should generate correct chat logprobs structure", func() { + tokens := []string{"4"} + topLogprobsCount := 3 + + logprobs := GenerateChatLogprobs(tokens, topLogprobsCount) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Content).To(HaveLen(len(tokens))) + + content := logprobs.Content[0] + Expect(content.Token).To(Equal(tokens[0])) + Expect(content.Bytes).To(HaveLen(len(tokens[0]))) + Expect(content.TopLogprobs).To(HaveLen(topLogprobsCount)) + + // Check that the main token is the first in top logprobs + Expect(content.TopLogprobs[0].Token).To(Equal(tokens[0])) + + // Check alternative tokens follow the pattern + expectedAlt1 := "4_1" + Expect(content.TopLogprobs[1].Token).To(Equal(expectedAlt1)) + + // Check byte conversion + expectedBytes := []int{52} // byte value of '4' + for i, expected := range expectedBytes { + Expect(content.Bytes[i]).To(Equal(expected)) + } + + // Check deterministic logprobs + expectedLogprob := -1.0 // defaultLogprob - float64(0%3)*0.1 + Expect(content.Logprob).To(Equal(expectedLogprob)) + }) + }) + + Context("calculateLogprob", func() { + It("should calculate main token probabilities correctly", func() { + // Test position cycle behavior (cycle of 3) + // Position 0: -1.0 - (0 % 3) * 0.1 = -1.0 + result0 := calculateLogprob(0, 0) + Expect(result0).To(Equal(-1.0)) + + // Position 1: -1.0 - (1 % 3) * 0.1 = -1.1 + 
result1 := calculateLogprob(1, 0) + Expect(result1).To(Equal(-1.1)) + + // Position 2: -1.0 - (2 % 3) * 0.1 = -1.2 + result2 := calculateLogprob(2, 0) + Expect(result2).To(Equal(-1.2)) + + // Position 3: -1.0 - (3 % 3) * 0.1 = -1.0 (cycle repeats) + result3 := calculateLogprob(3, 0) + Expect(result3).To(Equal(-1.0)) + + // Position 4: -1.0 - (4 % 3) * 0.1 = -1.1 (cycle repeats) + result4 := calculateLogprob(4, 0) + Expect(result4).To(Equal(-1.1)) + }) + + It("should calculate alternative token probabilities correctly", func() { + // Test alternative token decrements (0.5 per alternative index) + tokenPosition := 0 // Start with position 0 (main logprob = -1.0) + + // Alternative 1: -1.0 - 1 * 0.5 = -1.5 + alt1 := calculateLogprob(tokenPosition, 1) + Expect(alt1).To(Equal(-1.5)) + + // Alternative 2: -1.0 - 2 * 0.5 = -2.0 + alt2 := calculateLogprob(tokenPosition, 2) + Expect(alt2).To(Equal(-2.0)) + + // Alternative 3: -1.0 - 3 * 0.5 = -2.5 + alt3 := calculateLogprob(tokenPosition, 3) + Expect(alt3).To(Equal(-2.5)) + }) + + It("should combine position cycle and alternative index correctly", func() { + // Test with position 1 (main logprob = -1.1) + tokenPosition := 1 + + // Main token: -1.0 - (1 % 3) * 0.1 = -1.1 + main := calculateLogprob(tokenPosition, 0) + Expect(main).To(Equal(-1.1)) + + // Alternative 1: -1.1 - 1 * 0.5 = -1.6 + alt1 := calculateLogprob(tokenPosition, 1) + Expect(alt1).To(Equal(-1.6)) + + // Alternative 2: -1.1 - 2 * 0.5 = -2.1 + alt2 := calculateLogprob(tokenPosition, 2) + Expect(alt2).To(Equal(-2.1)) + }) + + It("should handle large position values correctly", func() { + // Test with large position values to ensure cycle works + largePosition := 100 + + // Position 100: -1.0 - (100 % 3) * 0.1 = -1.0 - 1 * 0.1 = -1.1 + result := calculateLogprob(largePosition, 0) + Expect(result).To(Equal(-1.1)) + + // With alternative: -1.1 - 1 * 0.5 = -1.6 + resultAlt := calculateLogprob(largePosition, 1) + Expect(resultAlt).To(Equal(-1.6)) + }) + + It("should handle edge cases correctly", func() { + // Test with zero values + result := calculateLogprob(0, 0) + Expect(result).To(Equal(-1.0)) + + // Test with large alternative index + largeAlt := calculateLogprob(0, 10) + expectedLargeAlt := -1.0 - float64(10)*0.5 // -6.0 + Expect(largeAlt).To(Equal(expectedLargeAlt)) + }) + }) + + Context("Other scenarios", func() { + It("should handle empty tokens for text logprobs", func() { + logprobs := GenerateTextLogprobs([]string{}, 2) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Tokens).To(BeEmpty()) + }) + + It("should handle empty tokens for chat logprobs", func() { + logprobs := GenerateChatLogprobs([]string{}, 2) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Content).To(BeEmpty()) + }) + + It("should verify probability pattern as token position grows", func() { + // Test the cycling pattern of probabilities + + // Test first cycle (positions 0-2) + prob0 := calculateLogprob(0, 0) + prob1 := calculateLogprob(1, 0) + prob2 := calculateLogprob(2, 0) + + Expect(prob0).To(Equal(-1.0)) // defaultLogprob + Expect(prob1).To(Equal(-1.1)) // defaultLogprob - 1*0.1 + Expect(prob2).To(Equal(-1.2)) // defaultLogprob - 2*0.1 + + // Test second cycle (positions 3-5) - should repeat the pattern + prob3 := calculateLogprob(3, 0) + prob4 := calculateLogprob(4, 0) + prob5 := calculateLogprob(5, 0) + + Expect(prob3).To(Equal(prob0)) // Should equal position 0 + Expect(prob4).To(Equal(prob1)) // Should equal position 1 + Expect(prob5).To(Equal(prob2)) // Should equal position 2 + + // Test 
third cycle (positions 6-8) - should repeat again + prob6 := calculateLogprob(6, 0) + prob7 := calculateLogprob(7, 0) + prob8 := calculateLogprob(8, 0) + + Expect(prob6).To(Equal(prob0)) // Should equal position 0 + Expect(prob7).To(Equal(prob1)) // Should equal position 1 + Expect(prob8).To(Equal(prob2)) // Should equal position 2 + + // Verify the cycling pattern continues for larger positions + for i := 0; i < 20; i++ { + expectedProb := defaultLogprob - float64(i%positionCycle)*positionDecrement + actualProb := calculateLogprob(i, 0) + Expect(actualProb).To(Equal(expectedProb), "Position %d should have probability %f", i, expectedProb) + } + }) + }) + + Context("No Limits", func() { + It("should allow unlimited logprobs count", func() { + tokens := []string{"test"} + + // Test text completion (no clamping) + textLogprobs := GenerateTextLogprobs(tokens, 10) + Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(10)) + + // Test chat completion (no clamping) + chatLogprobs := GenerateChatLogprobs(tokens, 25) + Expect(chatLogprobs.Content[0].TopLogprobs).To(HaveLen(25)) + + // Test high count + textLogprobs = GenerateTextLogprobs(tokens, 100) + Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(100)) + + chatLogprobs = GenerateChatLogprobs(tokens, 50) + Expect(chatLogprobs.Content[0].TopLogprobs).To(HaveLen(50)) + + // Test minimum (at least 1) + textLogprobs = GenerateTextLogprobs(tokens, 0) + Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(1)) + }) + }) +}) diff --git a/pkg/common/logprobs_types.go b/pkg/common/logprobs_types.go new file mode 100644 index 0000000..04da401 --- /dev/null +++ b/pkg/common/logprobs_types.go @@ -0,0 +1,47 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package common + +// LogprobsContent represents logprobs for a single token in chat completions +type LogprobsContent struct { + // Token is the token string + Token string `json:"token"` + // Logprob is the log probability of the token + Logprob float64 `json:"logprob"` + // Bytes is the byte representation of the token + Bytes []int `json:"bytes"` + // TopLogprobs is the list of top alternative tokens along their log probabilities + TopLogprobs []LogprobsContent `json:"top_logprobs,omitempty"` +} + +// ChatLogprobs represents logprobs for chat completion responses +type ChatLogprobs struct { + // Content is an array of logprobs for each token in the content + Content []LogprobsContent `json:"content"` +} + +// TextLogprobs represents logprobs for text completion responses +type TextLogprobs struct { + // Tokens is an array of tokens + Tokens []string `json:"tokens"` + // TokenLogprobs is an array of log probabilities for each token + TokenLogprobs []float64 `json:"token_logprobs"` + // TopLogprobs is an array of objects containing the top alternative tokens + TopLogprobs []map[string]float64 `json:"top_logprobs"` + // TextOffset is an array of character offsets + TextOffset []int `json:"text_offset"` +} diff --git a/pkg/common/logprobs_utils.go b/pkg/common/logprobs_utils.go new file mode 100644 index 0000000..a46abd9 --- /dev/null +++ b/pkg/common/logprobs_utils.go @@ -0,0 +1,258 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" +) + +const ( + // Default logprob value + defaultLogprob = -1.0 + // Cycle length for position-based variation + positionCycle = 3 + // Logprob decrement per cycle position + positionDecrement = 0.1 + // Logprob decrement per alternative token + alternativeDecrement = 0.5 +) + +// NOTE: These functions produce synthetic data for API shape compatibility. +// The logprobs are deterministic placeholders and have no semantic meaning. + +// calculateLogprob calculates synthetic log probabilities using a deterministic algorithm. +// For the main token (alternativeIndex = 0), it uses a cycle of 3 positions with decreasing probability. +// For alternative tokens, it decreases probability by 0.5 per alternative index. 
+// +// Algorithm: +// - Main token: defaultLogprob - (tokenPosition % 3) * 0.1 +// - Alternative: mainTokenLogprob - alternativeIndex * 0.5 +func calculateLogprob(tokenPosition int, alternativeIndex int) float64 { + // Calculate main token probability based on position cycle + mainLogprob := defaultLogprob - float64(tokenPosition%positionCycle)*positionDecrement + + // For main token (index 0), return the main probability + if alternativeIndex == 0 { + return mainLogprob + } + + // For alternatives, decrease by alternativeDecrement per index + return mainLogprob - float64(alternativeIndex)*alternativeDecrement +} + +// GenerateSingleTokenChatLogprobs generates logprobs for a single token in chat completion streaming +func GenerateSingleTokenChatLogprobs(token string, tokenPosition int, topLogprobsCount int) *LogprobsContent { + if token == "" { + return nil + } + + // Calculate main token probability + mainLogprob := calculateLogprob(tokenPosition, 0) + tokenBytes := stringToIntBytes(token) + + content := LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Generate top alternatives if requested + if topLogprobsCount > 0 { + // Pre-size alternatives slice + content.TopLogprobs = make([]LogprobsContent, topLogprobsCount) + + // Main token first + content.TopLogprobs[0] = LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Alternative tokens + for j := 1; j < topLogprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(tokenPosition, j) + altBytes := stringToIntBytes(altToken) + + content.TopLogprobs[j] = LogprobsContent{ + Token: altToken, + Logprob: altLogprob, + Bytes: altBytes, + } + } + } + + return &content +} + +// GenerateSingleTokenTextLogprobs generates logprobs for a single token in text completion streaming +func GenerateSingleTokenTextLogprobs(token string, tokenPosition int, logprobsCount int) *TextLogprobs { + if token == "" { + return nil + } + + // Ensure minimum count + if logprobsCount <= 0 { + logprobsCount = 1 // Include the main token, at a minimum + } + + logprobs := &TextLogprobs{ + Tokens: []string{token}, + TokenLogprobs: make([]float64, 1), + TopLogprobs: make([]map[string]float64, 1), + TextOffset: []int{0}, + } + + // Calculate main token probability + mainLogprob := calculateLogprob(tokenPosition, 0) + logprobs.TokenLogprobs[0] = mainLogprob + + topLogprobs := make(map[string]float64, logprobsCount) + topLogprobs[token] = mainLogprob + + // Add alternative tokens + for j := 1; j < logprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(tokenPosition, j) + topLogprobs[altToken] = altLogprob + } + + logprobs.TopLogprobs[0] = topLogprobs + + return logprobs +} + +// GenerateTextLogprobs generates synthetic log probabilities for text completion responses +func GenerateTextLogprobs(tokens []string, logprobsCount int) *TextLogprobs { + // Always return a valid struct, even for empty input + if len(tokens) == 0 { + return &TextLogprobs{ + Tokens: []string{}, + TokenLogprobs: []float64{}, + TopLogprobs: []map[string]float64{}, + TextOffset: []int{}, + } + } + + // Ensure minimum count + if logprobsCount <= 0 { + logprobsCount = 1 // Include the main token, at least + } + + // Avoid reallocations + numTokens := len(tokens) + logprobs := &TextLogprobs{ + Tokens: tokens, + TokenLogprobs: make([]float64, numTokens), + TopLogprobs: make([]map[string]float64, numTokens), + TextOffset: make([]int, numTokens), + } + + 
offset := 0 + for i, token := range tokens { + logprobs.TextOffset[i] = offset + offset += len(token) // Use byte length + + // Calculate main token probability using helper function + mainLogprob := calculateLogprob(i, 0) + logprobs.TokenLogprobs[i] = mainLogprob + + topLogprobs := make(map[string]float64, logprobsCount) + topLogprobs[token] = mainLogprob + + // Add alternative tokens using helper function + for j := 1; j < logprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(i, j) + topLogprobs[altToken] = altLogprob + } + + logprobs.TopLogprobs[i] = topLogprobs + } + + return logprobs +} + +// GenerateChatLogprobs generates synthetic log probabilities for chat completion responses +func GenerateChatLogprobs(tokens []string, topLogprobsCount int) *ChatLogprobs { + // Always return a valid struct, even for empty input + if len(tokens) == 0 { + return &ChatLogprobs{ + Content: []LogprobsContent{}, + } + } + + numTokens := len(tokens) + logprobs := &ChatLogprobs{ + Content: make([]LogprobsContent, numTokens), + } + + for i, token := range tokens { + // Calculate main token probability using helper function + mainLogprob := calculateLogprob(i, 0) + + tokenBytes := stringToIntBytes(token) + + content := LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Generate top alternatives if requested + if topLogprobsCount > 0 { + // Pre-size alternatives slice + content.TopLogprobs = make([]LogprobsContent, topLogprobsCount) + + // Main token first + content.TopLogprobs[0] = LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Alternative tokens using helper function + for j := 1; j < topLogprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(i, j) + altBytes := stringToIntBytes(altToken) + + content.TopLogprobs[j] = LogprobsContent{ + Token: altToken, + Logprob: altLogprob, + Bytes: altBytes, + } + } + } + + logprobs.Content[i] = content + } + + return logprobs +} + +// stringToIntBytes converts a string to []int of byte values inline +func stringToIntBytes(s string) []int { + if s == "" { + return nil + } + out := make([]int, len(s)) + for i := range out { + out[i] = int(s[i]) + } + return out +} diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 3851c51..f936e7b 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -403,6 +403,11 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { doRemotePrefill: req.IsDoRemotePrefill(), nPromptTokens: usageData.PromptTokens, nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(), + requestID: req.GetRequestID(), + // Logprobs configuration + includeLogprobs: req.IncludeLogprobs(), + topLogprobs: req.GetTopLogprobsCount(), + logprobsCount: req.GetLogprobs(), }, responseTokens, toolCalls, finishReason, usageDataToSend, ) @@ -437,15 +442,23 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool } } +// LogprobData has the logprob-related information needed for response generation +type LogprobData struct { + IncludeLogprobs bool + TopLogprobs int + Logprobs *int +} + // createCompletionResponse creates the response for completion requests, supports both completion request types (text and chat) // as defined by isChatCompletion +// logprobData - pointer to logprob configuration data, can be nil if no logprobs needed // respTokens - tokenized content 
to be sent in the response // toolCalls - tool calls to be sent in the response // finishReason - a pointer to string that represents finish reason, can be nil or stop or length, ... // usageData - usage (tokens statistics) for this response // modelName - display name returned to the client and used in metrics. It is either the first alias // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request). -func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, +func (s *VllmSimulator) createCompletionResponse(logprobData *LogprobData, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse { baseResp := openaiserverapi.BaseCompletionResponse{ ID: chatComplIDPrefix + common.GenerateUUIDString(), @@ -477,16 +490,41 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke } else { message.Content = openaiserverapi.Content{Raw: respText} } + + choice := openaiserverapi.ChatRespChoice{ + Message: message, + BaseResponseChoice: baseChoice, + } + + // Generate logprobs if requested + if logprobData != nil && logprobData.IncludeLogprobs && toolCalls == nil { + if logprobs := common.GenerateChatLogprobs(respTokens, logprobData.TopLogprobs); logprobs != nil && len(logprobs.Content) > 0 { + choice.Logprobs = logprobs + } + } + return &openaiserverapi.ChatCompletionResponse{ BaseCompletionResponse: baseResp, - Choices: []openaiserverapi.ChatRespChoice{{Message: message, BaseResponseChoice: baseChoice}}, + Choices: []openaiserverapi.ChatRespChoice{choice}, + } + } + + choice := openaiserverapi.TextRespChoice{ + BaseResponseChoice: baseChoice, + Text: respText, + } + + // Generate logprobs if requested for text completion + if logprobData != nil && logprobData.Logprobs != nil && *logprobData.Logprobs > 0 { + if logprobs := common.GenerateTextLogprobs(respTokens, *logprobData.Logprobs); logprobs != nil && len(logprobs.Tokens) > 0 { + choice.Logprobs = logprobs } } baseResp.Object = textCompletionObject return &openaiserverapi.TextCompletionResponse{ BaseCompletionResponse: baseResp, - Choices: []openaiserverapi.TextRespChoice{{BaseResponseChoice: baseChoice, Text: respText}}, + Choices: []openaiserverapi.TextRespChoice{choice}, } } @@ -500,7 +538,19 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke // usageData - usage (tokens statistics) for this response func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall, modelName string, finishReason string, usageData *openaiserverapi.Usage) { - resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, + // Extract logprob data from request (if needed) + var logprobData *LogprobData + logprobsPtr := reqCtx.CompletionReq.GetLogprobs() + // Only create logprobData if logprobs are requested (value > 0) + if logprobsPtr != nil && *logprobsPtr > 0 && toolCalls == nil { + logprobData = &LogprobData{ + IncludeLogprobs: true, + TopLogprobs: *logprobsPtr, + Logprobs: logprobsPtr, + } + } + + resp := s.createCompletionResponse(logprobData, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, reqCtx.CompletionReq.IsDoRemoteDecode()) // calculate how long to wait before returning 
the response, time is based on number of tokens diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index e8e57e4..de26333 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -507,6 +507,203 @@ var _ = Describe("Simulator", func() { }) }) + Context("logprobs functionality", func() { + DescribeTable("streaming chat completions with logprobs", + func(mode string, logprobs bool, topLogprobs int) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, true) + params.Logprobs = param.NewOpt(logprobs) + if logprobs && topLogprobs > 0 { + params.TopLogprobs = param.NewOpt(int64(topLogprobs)) + } + + stream := openaiclient.Chat.Completions.NewStreaming(ctx, params) + defer func() { + err := stream.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + tokens := []string{} + chunksWithLogprobs := 0 + + for stream.Next() { + chunk := stream.Current() + for _, choice := range chunk.Choices { + if choice.FinishReason == "" && choice.Delta.Content != "" { + tokens = append(tokens, choice.Delta.Content) + + // Check logprobs in streaming chunks + if logprobs && len(choice.Logprobs.Content) > 0 { + chunksWithLogprobs++ + logprobContent := choice.Logprobs.Content[0] + Expect(logprobContent.Token).To(Equal(choice.Delta.Content)) + Expect(logprobContent.Logprob).To(BeNumerically("<=", 0)) + + if topLogprobs > 0 { + Expect(logprobContent.TopLogprobs).To(HaveLen(topLogprobs)) + Expect(logprobContent.TopLogprobs[0].Token).To(Equal(choice.Delta.Content)) + } + } + } + } + } + + msg := strings.Join(tokens, "") + if mode == common.ModeRandom { + Expect(dataset.IsValidText(msg)).To(BeTrue()) + } else { + Expect(msg).Should(Equal(userMessage)) + } + + // Verify logprobs behaviour + if logprobs { + Expect(chunksWithLogprobs).To(BeNumerically(">", 0), "Should have chunks with logprobs") + } else { + Expect(chunksWithLogprobs).To(Equal(0), "Should not have chunks with logprobs when not requested") + } + }, + func(mode string, logprobs bool, topLogprobs int) string { + return fmt.Sprintf("mode: %s logprobs: %t top_logprobs: %d", mode, logprobs, topLogprobs) + }, + Entry(nil, common.ModeEcho, true, 0), // logprobs=true, default top_logprobs + Entry(nil, common.ModeEcho, true, 2), // logprobs=true, top_logprobs=2 + Entry(nil, common.ModeEcho, false, 0), // logprobs=false + ) + + DescribeTable("streaming text completions with logprobs", + func(mode string, logprobsCount int) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, true) + if logprobsCount > 0 { + params.Logprobs = param.NewOpt(int64(logprobsCount)) + } + + stream := openaiclient.Completions.NewStreaming(ctx, params) + defer func() { + err := stream.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + tokens := []string{} + chunksWithLogprobs := 0 + + for stream.Next() { + chunk := stream.Current() + for _, choice := range chunk.Choices { + if choice.FinishReason == "" && choice.Text != "" { + tokens = append(tokens, choice.Text) + + // Check logprobs in streaming chunks + if logprobsCount > 0 && len(choice.Logprobs.Tokens) > 0 { + chunksWithLogprobs++ + Expect(choice.Logprobs.Tokens[0]).To(Equal(choice.Text)) + Expect(choice.Logprobs.TokenLogprobs[0]).To(BeNumerically("<=", 0)) + 
Expect(choice.Logprobs.TopLogprobs[0]).To(HaveLen(logprobsCount)) + Expect(choice.Logprobs.TopLogprobs[0]).To(HaveKey(choice.Text)) + } + } + } + } + + text := strings.Join(tokens, "") + if mode == common.ModeRandom { + Expect(dataset.IsValidText(text)).To(BeTrue()) + } else { + Expect(text).Should(Equal(userMessage)) + } + + // Verify logprobs behaviour + if logprobsCount > 0 { + Expect(chunksWithLogprobs).To(BeNumerically(">", 0), "Should have chunks with logprobs") + } else { + Expect(chunksWithLogprobs).To(Equal(0), "Should not have chunks with logprobs when not requested") + } + }, + func(mode string, logprobsCount int) string { + return fmt.Sprintf("mode: %s logprobs: %d", mode, logprobsCount) + }, + Entry(nil, common.ModeEcho, 0), // No logprobs + Entry(nil, common.ModeEcho, 2), // logprobs=2 + ) + + DescribeTable("non-streaming completions with logprobs", + func(isChat bool, mode string, logprobsParam interface{}) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + var resp interface{} + + if isChat { + openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false) + if logprobsParam != nil { + if logprobs, ok := logprobsParam.(bool); ok && logprobs { + params.Logprobs = param.NewOpt(true) + params.TopLogprobs = param.NewOpt(int64(2)) + } + } + resp, err = openaiclient.Chat.Completions.New(ctx, params) + } else { + openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, false) + if logprobsParam != nil { + if logprobsCount, ok := logprobsParam.(int); ok && logprobsCount > 0 { + params.Logprobs = param.NewOpt(int64(logprobsCount)) + } + } + resp, err = openaiclient.Completions.New(ctx, params) + } + + Expect(err).NotTo(HaveOccurred()) + + // Verify logprobs in non-streaming response + if isChat { + chatResp := resp.(*openai.ChatCompletion) + Expect(chatResp.Choices).ShouldNot(BeEmpty()) + + if logprobsParam != nil { + Expect(chatResp.Choices[0].Logprobs).NotTo(BeNil()) + Expect(chatResp.Choices[0].Logprobs.Content).NotTo(BeEmpty()) + + tokens := common.Tokenize(chatResp.Choices[0].Message.Content) + Expect(chatResp.Choices[0].Logprobs.Content).To(HaveLen(len(tokens))) + } else { + Expect(chatResp.Choices[0].Logprobs).To(BeNil()) + } + } else { + textResp := resp.(*openai.Completion) + Expect(textResp.Choices).ShouldNot(BeEmpty()) + + if logprobsParam != nil { + Expect(textResp.Choices[0].Logprobs).NotTo(BeNil()) + + tokens := common.Tokenize(textResp.Choices[0].Text) + Expect(textResp.Choices[0].Logprobs.Tokens).To(HaveLen(len(tokens))) + } else { + Expect(textResp.Choices[0].Logprobs).To(BeNil()) + } + } + }, + func(isChat bool, mode string, logprobsParam interface{}) string { + apiType := "text" + if isChat { + apiType = "chat" + } + return fmt.Sprintf("%s mode: %s logprobs: %v", apiType, mode, logprobsParam) + }, + Entry(nil, true, common.ModeEcho, true), // Chat with logprobs + Entry(nil, true, common.ModeEcho, nil), // Chat without logprobs + Entry(nil, false, common.ModeEcho, 2), // Text with logprobs=2 + Entry(nil, false, common.ModeEcho, nil), // Text without logprobs + ) + + }) + Context("max-model-len context window validation", func() { It("Should reject requests exceeding context window", func() { ctx := context.TODO() diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index 1bd2525..9411ec5 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -37,6 +37,10 @@ type streamingContext 
struct { nPromptTokens int nCachedPromptTokens int requestID string + // Logprobs configuration + includeLogprobs bool + topLogprobs int + logprobsCount *int // Text completions } // sendStreamingResponse creates and sends a streaming response for completion requests of both types (text and chat) @@ -186,9 +190,9 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o } // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response, -// for text completion +// for text completion. func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk { - return &openaiserverapi.TextCompletionResponse{ + chunk := &openaiserverapi.TextCompletionResponse{ BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, @@ -202,6 +206,18 @@ func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, tok }, }, } + + // Generate logprobs if requested and token is not empty + if context.includeLogprobs && token != "" && context.logprobsCount != nil && *context.logprobsCount > 0 { + // Use token position based on current time + tokenPosition := int(context.creationTime) % 1000 // Simple position simulation + logprobs := common.GenerateSingleTokenTextLogprobs(token, tokenPosition, *context.logprobsCount) + if logprobs != nil { + chunk.Choices[0].Logprobs = logprobs + } + } + + return chunk } // createChatCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion @@ -230,6 +246,18 @@ func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, tok chunk.Choices[0].Delta.ToolCalls = []openaiserverapi.ToolCall{*tool} } else if len(token) > 0 { chunk.Choices[0].Delta.Content.Raw = token + + // Generate logprobs if requested and token is not empty + if context.includeLogprobs { + // Use token position based on current time + tokenPosition := int(context.creationTime) % 1000 // Simple position simulation + logprobs := common.GenerateSingleTokenChatLogprobs(token, tokenPosition, context.topLogprobs) + if logprobs != nil { + chunk.Choices[0].Logprobs = &common.ChatLogprobs{ + Content: []common.LogprobsContent{*logprobs}, + } + } + } } return &chunk diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 34db0ee..67a0e88 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -67,6 +67,14 @@ type CompletionRequest interface { IsDoRemotePrefill() bool // GetFullPrompt returns the full prompt including system and user prompts GetFullPrompt() string + // GetLogprobs returns the logprobs parameter (for text completions) + GetLogprobs() *int + // IncludeLogprobs returns true if logprobs should be included (for chat completions) + IncludeLogprobs() bool + // GetTopLogprobs returns the number of top logprobs to include (for chat completions) + GetTopLogprobs() *int + // GetTopLogprobsCount returns the computed count of top logprobs to include + GetTopLogprobsCount() int } // BaseCompletionRequest contains base completion request related information @@ -178,6 +186,12 @@ type ChatCompletionRequest struct { // possible values: none, auto, required. // Sending an object with a specific tool, is currently not supported. 
ToolChoice string `json:"tool_choice,omitempty"` + + // Logprobs controls whether log probabilities are included in the response + Logprobs bool `json:"logprobs,omitempty"` + + // TopLogprobs controls how many alternative tokens to include in the logprobs + TopLogprobs *int `json:"top_logprobs,omitempty"` } // function defines a tool @@ -253,6 +267,37 @@ func (req *ChatCompletionRequest) GetFullPrompt() string { return prompt } +func (c *ChatCompletionRequest) GetLogprobs() *int { + if !c.Logprobs { + return nil // No logprobs requested + } + if c.TopLogprobs != nil { + return c.TopLogprobs // Return the top_logprobs value + } + // Default to 1 if logprobs=true but no top_logprobs specified + defaultVal := 1 + return &defaultVal +} + +func (c *ChatCompletionRequest) IncludeLogprobs() bool { + return c.Logprobs +} + +func (c *ChatCompletionRequest) GetTopLogprobs() *int { + return c.TopLogprobs +} + +func (c *ChatCompletionRequest) GetTopLogprobsCount() int { + if c.Logprobs { + if c.TopLogprobs != nil { + return *c.TopLogprobs + } + // Default to 1 if logprobs=true but no top_logprobs specified + return 1 + } + return 0 +} + // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { @@ -266,6 +311,12 @@ type TextCompletionRequest struct { // The token count of your prompt plus `max_tokens` cannot exceed the model's // context length. MaxTokens *int64 `json:"max_tokens"` + + // Logprobs includes the log probabilities on the logprobs most likely tokens, + // as well the chosen tokens. For example, if logprobs is 5, the API will return + // a list of the 5 most likely tokens. The API will always return the logprob + // of the sampled token, so there may be up to logprobs+1 elements in the response. 
+ Logprobs *int `json:"logprobs,omitempty"` } func (t *TextCompletionRequest) GetPrompt() string { @@ -291,3 +342,23 @@ func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { func (t *TextCompletionRequest) GetFullPrompt() string { return "### user:\n" + t.Prompt + "\n" } + +func (t *TextCompletionRequest) GetLogprobs() *int { + return t.Logprobs +} + +func (t *TextCompletionRequest) IncludeLogprobs() bool { + return t.Logprobs != nil && *t.Logprobs > 0 +} + +func (t *TextCompletionRequest) GetTopLogprobs() *int { + // For text completions, this method returns the same as GetLogprobs + return t.Logprobs +} + +func (t *TextCompletionRequest) GetTopLogprobsCount() int { + if t.Logprobs != nil && *t.Logprobs > 0 { + return *t.Logprobs + } + return 0 +} diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index d32784e..25d44d2 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -22,6 +22,7 @@ import ( "errors" "strings" + "github.com/llm-d/llm-d-inference-sim/pkg/common" "github.com/valyala/fasthttp" ) @@ -175,6 +176,8 @@ type ChatRespChoice struct { BaseResponseChoice // Message contains choice's Message Message Message `json:"message"` + // Logprobs contains the log probabilities for the response + Logprobs *common.ChatLogprobs `json:"logprobs,omitempty"` } // TextCompletionResponse defines structure of /completion response @@ -189,6 +192,8 @@ type TextRespChoice struct { BaseResponseChoice // Text defines request's content Text string `json:"text"` + // Logprobs contains the log probabilities for the response + Logprobs *common.TextLogprobs `json:"logprobs,omitempty"` } // CompletionRespChunk is an interface that defines a single response chunk @@ -206,6 +211,8 @@ type ChatRespChunkChoice struct { BaseResponseChoice // Delta is a content of the chunk Delta Message `json:"delta"` + // Logprobs contains the log probabilities for the response chunk + Logprobs *common.ChatLogprobs `json:"logprobs,omitempty"` } // CompletionError defines the simulator's response in case of an error
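Usage sketch (not part of the diff above): a minimal, hedged illustration of how the new helpers in pkg/common behave, assuming the module path shown in the patch (github.com/llm-d/llm-d-inference-sim) is available as a dependency. The values are the deterministic placeholders described in logprobs_utils.go: main-token logprob = defaultLogprob - (position % 3) * 0.1, and each alternative is 0.5 lower than the main token.

// logprobs_example.go — illustrative only; exercises GenerateChatLogprobs and
// GenerateTextLogprobs from the patch. The logprobs are synthetic and carry no
// semantic meaning; they only preserve the OpenAI API response shape.
package main

import (
	"fmt"

	"github.com/llm-d/llm-d-inference-sim/pkg/common"
)

func main() {
	tokens := []string{" Paris", ",", " the"}

	// Chat-style logprobs: one LogprobsContent per token, with the main token
	// first in TopLogprobs followed by synthetic "<token>_<n>" alternatives.
	chat := common.GenerateChatLogprobs(tokens, 2)
	for _, c := range chat.Content {
		fmt.Printf("%q logprob=%.1f alternatives=%d\n", c.Token, c.Logprob, len(c.TopLogprobs))
	}
	// Expected logprobs: " Paris" -1.0, "," -1.1, " the" -1.2 (cycle of 3, step 0.1).

	// Text-style logprobs: parallel arrays plus byte-based text offsets.
	text := common.GenerateTextLogprobs(tokens, 2)
	fmt.Println(text.TokenLogprobs) // [-1 -1.1 -1.2]
	fmt.Println(text.TextOffset)    // [0 6 7]
}

The same deterministic pattern is what the unit tests in pkg/common/logprobs_test.go assert on, so the exact numbers here mirror those expectations rather than anything model-derived.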