Skip to content

Commit 3967e23

Browse files
authored
feat: Add comprehensive log probabilities support (#221)
Add support for log probabilities in both chat completions and text completions APIs: - Add logprobs parameter to chat completions with top_logprobs support - Add logprobs parameter to text completions with configurable count - Implement streaming and non-streaming logprobs functionality - Add comprehensive test coverage for all logprobs scenarios - Add utility functions for logprobs calculation and validation - Support both echo and random modes for logprobs generation - Include proper token-level probability information in responses - Fix undefined variables in tests after upstream merge Signed-off-by: Rui Vieira <[email protected]>
1 parent 9a57299 commit 3967e23

File tree

8 files changed

+857
-13
lines changed

8 files changed

+857
-13
lines changed

pkg/common/logprobs_test.go

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
/*
2+
Copyright 2025 The llm-d-inference-sim Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package common
18+
19+
import (
20+
. "github.com/onsi/ginkgo/v2"
21+
. "github.com/onsi/gomega"
22+
)
23+
24+
var _ = Describe("Logprobs", func() {
25+
26+
Context("GenerateTextLogprobs", func() {
27+
It("should generate correct text logprobs structure", func() {
28+
tokens := []string{" Paris", ",", " the", " capital"}
29+
logprobsCount := 2
30+
31+
logprobs := GenerateTextLogprobs(tokens, logprobsCount)
32+
33+
Expect(logprobs).NotTo(BeNil())
34+
Expect(logprobs.Tokens).To(HaveLen(len(tokens)))
35+
Expect(logprobs.TokenLogprobs).To(HaveLen(len(tokens)))
36+
Expect(logprobs.TopLogprobs).To(HaveLen(len(tokens)))
37+
Expect(logprobs.TextOffset).To(HaveLen(len(tokens)))
38+
39+
// Check that each top logprobs entry has the expected number of alternatives
40+
for i, topLogprob := range logprobs.TopLogprobs {
41+
Expect(topLogprob).To(HaveLen(logprobsCount))
42+
// Check that the main token is included in the alternatives
43+
Expect(topLogprob).To(HaveKey(tokens[i]))
44+
}
45+
46+
// Check text offsets are calculated correctly (byte-based)
47+
expectedOffsets := []int{0, 6, 7, 11} // " Paris" - 6, "," - 1, " the" -4, " capital" - 11
48+
for i, expected := range expectedOffsets {
49+
Expect(logprobs.TextOffset[i]).To(Equal(expected))
50+
}
51+
52+
// Check deterministic logprobs
53+
expectedLogprob0 := -1.0 // defaultLogprob - float64(0%3)*0.1
54+
Expect(logprobs.TokenLogprobs[0]).To(Equal(expectedLogprob0))
55+
})
56+
})
57+
58+
Context("GenerateChatLogprobs", func() {
59+
It("should generate correct chat logprobs structure", func() {
60+
tokens := []string{"4"}
61+
topLogprobsCount := 3
62+
63+
logprobs := GenerateChatLogprobs(tokens, topLogprobsCount)
64+
65+
Expect(logprobs).NotTo(BeNil())
66+
Expect(logprobs.Content).To(HaveLen(len(tokens)))
67+
68+
content := logprobs.Content[0]
69+
Expect(content.Token).To(Equal(tokens[0]))
70+
Expect(content.Bytes).To(HaveLen(len(tokens[0])))
71+
Expect(content.TopLogprobs).To(HaveLen(topLogprobsCount))
72+
73+
// Check that the main token is the first in top logprobs
74+
Expect(content.TopLogprobs[0].Token).To(Equal(tokens[0]))
75+
76+
// Check alternative tokens follow the pattern
77+
expectedAlt1 := "4_1"
78+
Expect(content.TopLogprobs[1].Token).To(Equal(expectedAlt1))
79+
80+
// Check byte conversion
81+
expectedBytes := []int{52} // byte value of '4'
82+
for i, expected := range expectedBytes {
83+
Expect(content.Bytes[i]).To(Equal(expected))
84+
}
85+
86+
// Check deterministic logprobs
87+
expectedLogprob := -1.0 // defaultLogprob - float64(0%3)*0.1
88+
Expect(content.Logprob).To(Equal(expectedLogprob))
89+
})
90+
})
91+
92+
Context("calculateLogprob", func() {
93+
It("should calculate main token probabilities correctly", func() {
94+
// Test position cycle behavior (cycle of 3)
95+
// Position 0: -1.0 - (0 % 3) * 0.1 = -1.0
96+
result0 := calculateLogprob(0, 0)
97+
Expect(result0).To(Equal(-1.0))
98+
99+
// Position 1: -1.0 - (1 % 3) * 0.1 = -1.1
100+
result1 := calculateLogprob(1, 0)
101+
Expect(result1).To(Equal(-1.1))
102+
103+
// Position 2: -1.0 - (2 % 3) * 0.1 = -1.2
104+
result2 := calculateLogprob(2, 0)
105+
Expect(result2).To(Equal(-1.2))
106+
107+
// Position 3: -1.0 - (3 % 3) * 0.1 = -1.0 (cycle repeats)
108+
result3 := calculateLogprob(3, 0)
109+
Expect(result3).To(Equal(-1.0))
110+
111+
// Position 4: -1.0 - (4 % 3) * 0.1 = -1.1 (cycle repeats)
112+
result4 := calculateLogprob(4, 0)
113+
Expect(result4).To(Equal(-1.1))
114+
})
115+
116+
It("should calculate alternative token probabilities correctly", func() {
117+
// Test alternative token decrements (0.5 per alternative index)
118+
tokenPosition := 0 // Start with position 0 (main logprob = -1.0)
119+
120+
// Alternative 1: -1.0 - 1 * 0.5 = -1.5
121+
alt1 := calculateLogprob(tokenPosition, 1)
122+
Expect(alt1).To(Equal(-1.5))
123+
124+
// Alternative 2: -1.0 - 2 * 0.5 = -2.0
125+
alt2 := calculateLogprob(tokenPosition, 2)
126+
Expect(alt2).To(Equal(-2.0))
127+
128+
// Alternative 3: -1.0 - 3 * 0.5 = -2.5
129+
alt3 := calculateLogprob(tokenPosition, 3)
130+
Expect(alt3).To(Equal(-2.5))
131+
})
132+
133+
It("should combine position cycle and alternative index correctly", func() {
134+
// Test with position 1 (main logprob = -1.1)
135+
tokenPosition := 1
136+
137+
// Main token: -1.0 - (1 % 3) * 0.1 = -1.1
138+
main := calculateLogprob(tokenPosition, 0)
139+
Expect(main).To(Equal(-1.1))
140+
141+
// Alternative 1: -1.1 - 1 * 0.5 = -1.6
142+
alt1 := calculateLogprob(tokenPosition, 1)
143+
Expect(alt1).To(Equal(-1.6))
144+
145+
// Alternative 2: -1.1 - 2 * 0.5 = -2.1
146+
alt2 := calculateLogprob(tokenPosition, 2)
147+
Expect(alt2).To(Equal(-2.1))
148+
})
149+
150+
It("should handle large position values correctly", func() {
151+
// Test with large position values to ensure cycle works
152+
largePosition := 100
153+
154+
// Position 100: -1.0 - (100 % 3) * 0.1 = -1.0 - 1 * 0.1 = -1.1
155+
result := calculateLogprob(largePosition, 0)
156+
Expect(result).To(Equal(-1.1))
157+
158+
// With alternative: -1.1 - 1 * 0.5 = -1.6
159+
resultAlt := calculateLogprob(largePosition, 1)
160+
Expect(resultAlt).To(Equal(-1.6))
161+
})
162+
163+
It("should handle edge cases correctly", func() {
164+
// Test with zero values
165+
result := calculateLogprob(0, 0)
166+
Expect(result).To(Equal(-1.0))
167+
168+
// Test with large alternative index
169+
largeAlt := calculateLogprob(0, 10)
170+
expectedLargeAlt := -1.0 - float64(10)*0.5 // -6.0
171+
Expect(largeAlt).To(Equal(expectedLargeAlt))
172+
})
173+
})
174+
175+
Context("Other scenarios", func() {
176+
It("should handle empty tokens for text logprobs", func() {
177+
logprobs := GenerateTextLogprobs([]string{}, 2)
178+
179+
Expect(logprobs).NotTo(BeNil())
180+
Expect(logprobs.Tokens).To(BeEmpty())
181+
})
182+
183+
It("should handle empty tokens for chat logprobs", func() {
184+
logprobs := GenerateChatLogprobs([]string{}, 2)
185+
186+
Expect(logprobs).NotTo(BeNil())
187+
Expect(logprobs.Content).To(BeEmpty())
188+
})
189+
190+
It("should verify probability pattern as token position grows", func() {
191+
// Test the cycling pattern of probabilities
192+
193+
// Test first cycle (positions 0-2)
194+
prob0 := calculateLogprob(0, 0)
195+
prob1 := calculateLogprob(1, 0)
196+
prob2 := calculateLogprob(2, 0)
197+
198+
Expect(prob0).To(Equal(-1.0)) // defaultLogprob
199+
Expect(prob1).To(Equal(-1.1)) // defaultLogprob - 1*0.1
200+
Expect(prob2).To(Equal(-1.2)) // defaultLogprob - 2*0.1
201+
202+
// Test second cycle (positions 3-5) - should repeat the pattern
203+
prob3 := calculateLogprob(3, 0)
204+
prob4 := calculateLogprob(4, 0)
205+
prob5 := calculateLogprob(5, 0)
206+
207+
Expect(prob3).To(Equal(prob0)) // Should equal position 0
208+
Expect(prob4).To(Equal(prob1)) // Should equal position 1
209+
Expect(prob5).To(Equal(prob2)) // Should equal position 2
210+
211+
// Test third cycle (positions 6-8) - should repeat again
212+
prob6 := calculateLogprob(6, 0)
213+
prob7 := calculateLogprob(7, 0)
214+
prob8 := calculateLogprob(8, 0)
215+
216+
Expect(prob6).To(Equal(prob0)) // Should equal position 0
217+
Expect(prob7).To(Equal(prob1)) // Should equal position 1
218+
Expect(prob8).To(Equal(prob2)) // Should equal position 2
219+
220+
// Verify the cycling pattern continues for larger positions
221+
for i := 0; i < 20; i++ {
222+
expectedProb := defaultLogprob - float64(i%positionCycle)*positionDecrement
223+
actualProb := calculateLogprob(i, 0)
224+
Expect(actualProb).To(Equal(expectedProb), "Position %d should have probability %f", i, expectedProb)
225+
}
226+
})
227+
})
228+
229+
Context("No Limits", func() {
230+
It("should allow unlimited logprobs count", func() {
231+
tokens := []string{"test"}
232+
233+
// Test text completion (no clamping)
234+
textLogprobs := GenerateTextLogprobs(tokens, 10)
235+
Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(10))
236+
237+
// Test chat completion (no clamping)
238+
chatLogprobs := GenerateChatLogprobs(tokens, 25)
239+
Expect(chatLogprobs.Content[0].TopLogprobs).To(HaveLen(25))
240+
241+
// Test high count
242+
textLogprobs = GenerateTextLogprobs(tokens, 100)
243+
Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(100))
244+
245+
chatLogprobs = GenerateChatLogprobs(tokens, 50)
246+
Expect(chatLogprobs.Content[0].TopLogprobs).To(HaveLen(50))
247+
248+
// Test minimum (at least 1)
249+
textLogprobs = GenerateTextLogprobs(tokens, 0)
250+
Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(1))
251+
})
252+
})
253+
})

0 commit comments

Comments
 (0)