Skip to content

Commit 04b70f7

Browse files
authored
Max BM25 score / BM25 normalization (#267)
1 parent 9b783ec commit 04b70f7

File tree

4 files changed

+194
-3
lines changed

4 files changed

+194
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010
- Added new Gemini 2.0 Pro and Flash models to the model registry (aliases `gem20p` and `gem20f`, respectively). Added Gemini 2.0 Flash Lite Preview model (alias `gem20fl`) and Gemini 2.0 Flash Thinking Experimental model (alias `gem20ft`).
11+
- Added BM25 normalization kwarg to `RAGTools.jl` to enable 0-1 query-specific normalization of BM25 scores for easier filtering and comparison. See `?RT.bm25` and `?RT.max_bm25_score` for more information.
1112

1213
### Fixed
1314

ext/RAGToolsExperimentalExt.jl

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,26 @@ function RT.document_term_matrix(documents::AbstractVector{<:AbstractString})
184184
end
185185

186186
"""
187-
RT.bm25(dtm::AbstractDocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
187+
RT.bm25(
188+
dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
189+
k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
190+
    normalize_min_doc_rel_length::Float32 = 0.5f0)
188191
189192
Scores all documents in `dtm` based on the `query`.
190193
191194
References: https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/
192195
196+
# Arguments
197+
- `dtm`: A `DocumentTermMatrix` object.
198+
- `query`: A vector of query tokens.
199+
- `k1`: The k1 parameter for BM25.
200+
- `b`: The b parameter for BM25.
201+
- `normalize`: Whether to normalize the scores (returns scores between 0 and 1).
202+
Theoretically, if you choose `normalize_max_tf` and `normalize_min_doc_rel_length` to be too low, you could get scores greater than 1.
203+
- `normalize_max_tf`: The maximum term frequency to normalize to. 3 is a good default (assumes max 3 hits per document).
204+
- `normalize_min_doc_rel_length`: The minimum document relative length to normalize to. 0.5 is a good default.
205+
Ideally, pick the minimum document relative length of the corpus that is non-zero
206+
`min_doc_rel_length = minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
193207
# Example
194208
```
195209
documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
@@ -198,14 +212,36 @@ query = ["this"]
198212
scores = bm25(dtm, query)
199213
# Returns array with 3 scores (one for each document)
200214
```
215+
216+
Normalization is done by dividing the score by the maximum possible score (given some assumptions).
217+
It's useful to get results in the same range as cosine similarity scores and when comparing different queries or documents.
218+
219+
# Example
220+
```
221+
documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
222+
dtm = document_term_matrix(documents)
223+
query = ["this"]
224+
scores = bm25(dtm, query)
225+
scores_norm = bm25(dtm, query; normalize = true)
226+
227+
## Make it more accurate for your dataset/index
228+
normalize_max_tf = 3 # assume max term frequency is 3 (what is likely for your dataset? depends on chunk size, preprocessing, etc.)
229+
normalize_min_doc_rel_length = minimum([x for x in RT.doc_rel_length(dtm) if x > 0]) |> Float32
230+
scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf, normalize_min_doc_rel_length)
231+
```
201232
"""
202233
function RT.bm25(
203234
dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
204-
k1::Float32 = 1.2f0, b::Float32 = 0.75f0)
235+
k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
236+
normalize_min_doc_rel_length::Float32 = 0.5f0)
237+
@assert normalize_max_tf>0 "normalize_max_tf term frequency must be positive (got $normalize_max_tf)"
238+
@assert normalize_min_doc_rel_length>0 "normalize_min_doc_rel_length must be positive (got $normalize_min_doc_rel_length)"
239+
205240
scores = zeros(Float32, size(tf(dtm), 1))
206241
## Identify non-zero items to leverage the sparsity
207242
nz_rows = rowvals(tf(dtm))
208243
nz_vals = nonzeros(tf(dtm))
244+
max_score = 0.0f0
209245
for i in eachindex(query)
210246
t = query[i]
211247
t_id = get(vocab_lookup(dtm), t, nothing)
@@ -222,9 +258,60 @@ function RT.bm25(
222258
## @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score"
223259
scores[di] += score
224260
end
261+
## Once per token, calculate max score
262+
## assumes max term frequency is `normalize_max_tf` and min document relative length is `normalize_min_doc_rel_length`
263+
if normalize
264+
max_score += idf_ * (normalize_max_tf * (k1 + 1.0f0)) / (normalize_max_tf +
265+
k1 * (1.0f0 - b + b * normalize_min_doc_rel_length))
266+
end
267+
end
268+
if normalize && !iszero(max_score)
269+
scores ./= max_score
270+
elseif normalize && iszero(max_score)
271+
## happens only with empty queries, so scores is zero anyway
272+
@warn "BM25: `max_score` is zero, so scores are not normalized. Returning unnormalized scores (all zero)."
225273
end
226274

227275
return scores
228276
end
229277

278+
"""
279+
RT.max_bm25_score(
280+
dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
281+
k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
282+
min_doc_rel_length::Float32 = 0.5f0)
283+
284+
Returns the maximum BM25 score that can be achieved for a given query (assuming `max_tf` is the highest term frequency of any query term in a document and `min_doc_rel_length` is the smallest document relative length).
285+
Good for normalizing BM25 scores.
286+
287+
# Example
288+
```
289+
max_score = max_bm25_score(RT.chunkdata(key_index), query_tokens)
290+
```
291+
"""
292+
function RT.max_bm25_score(
293+
dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
294+
k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
295+
min_doc_rel_length::Float32 = 0.5f0)
296+
max_score = 0.0f0
297+
@inbounds for t in query_tokens
298+
t_id = get(RT.vocab_lookup(dtm), t, nothing)
299+
t_id === nothing && continue
300+
301+
idf_ = RT.idf(dtm)[t_id]
302+
303+
# Find maximum tf (term frequency) for this term in any document - pre-set in kwargs!
304+
# eg, `max_tf = maximum(@view(RT.tf(dtm)[:, t_id]))` but that would be a bit extreme and slow
305+
306+
# Find first non-zero element in doc lengths -- pre-set in kwargs!
307+
# eg, `min_doc_rel_length = minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
308+
309+
# Maximum tf component assuming perfect match
310+
tf_top = (max_tf * (k1 + 1.0f0))
311+
tf_bottom = (max_tf + k1 * (1.0f0 - b + b * min_doc_rel_length))
312+
max_score += idf_ * tf_top / tf_bottom
313+
end
314+
return max_score
315+
end
316+
230317
end # end of module

src/Experimental/RAGTools/retrieval.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ function find_closest(
452452
return positions[new_positions], scores
453453
end
454454

455+
function max_bm25_score end
455456
"""
456457
find_closest(
457458
finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,

test/Experimental/RAGTools/retrieval.jl

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ using PromptingTools.Experimental.RAGTools: find_closest, hamming_distance, find
1616
using PromptingTools.Experimental.RAGTools: NoReranker, CohereReranker
1717
using PromptingTools.Experimental.RAGTools: hamming_distance, BitPackedCosineSimilarity,
1818
pack_bits, unpack_bits
19-
using PromptingTools.Experimental.RAGTools: bm25, document_term_matrix, DocumentTermMatrix
19+
using PromptingTools.Experimental.RAGTools: bm25, max_bm25_score, document_term_matrix,
20+
DocumentTermMatrix
2021

2122
@testset "rephrase" begin
2223
# Test rephrase with NoRephraser, simple passthrough
@@ -184,6 +185,107 @@ end
184185
@test scores[1] expected * 4
185186
@test scores[2] expected * 4
186187
@test scores[3] 0
188+
189+
## BM25 normalization
190+
# Basic test corpus
191+
documents = [
192+
["this", "is", "a", "test", "document"],
193+
["this", "is", "another", "test"],
194+
["completely", "different", "content"],
195+
["test", "test", "test", "test"], # document with repeated terms
196+
["single"] # shortest document
197+
]
198+
dtm = document_term_matrix(documents)
199+
200+
# Test 1: Basic normalization - scores should be between 0 and 1
201+
query = ["test"]
202+
rel_len = RT.doc_rel_length(dtm)
203+
scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf = 3,
204+
normalize_min_doc_rel_length = minimum(rel_len))
205+
@test all(0 .≤ scores_norm .≤ 1)
206+
207+
# Test that document with most "test" occurrences gets highest score
208+
@test argmax(scores_norm) == 4
209+
210+
# Test 2: Compare with manual normalization
211+
scores_raw = bm25(dtm, query; normalize = false)
212+
max_score = max_bm25_score(
213+
dtm, query; max_tf = 3, min_doc_rel_length = minimum(rel_len))
214+
scores_manual_norm = scores_raw ./ max_score
215+
@test scores_norm scores_manual_norm
216+
217+
# Test 3: Parameter variations
218+
params = [
219+
(k1 = 1.2f0, b = 0.75f0, max_tf = 3, min_doc_len = 0.5f0),
220+
(k1 = 2.0f0, b = 0.5f0, max_tf = 10, min_doc_len = 1.0f0)
221+
]
222+
223+
for p in params
224+
scores = bm25(dtm, query;
225+
normalize = true,
226+
k1 = p.k1,
227+
b = p.b,
228+
normalize_max_tf = p.max_tf,
229+
normalize_min_doc_rel_length = p.min_doc_len
230+
)
231+
@test all(0 .≤ scores .≤ 1)
232+
233+
# Verify against max_bm25_score
234+
max_theoretical = max_bm25_score(dtm, query;
235+
k1 = p.k1,
236+
b = p.b,
237+
max_tf = p.max_tf,
238+
min_doc_rel_length = p.min_doc_len
239+
)
240+
scores_raw = bm25(dtm, query;
241+
normalize = false,
242+
k1 = p.k1,
243+
b = p.b
244+
)
245+
@test maximum(scores_raw) max_theoretical
246+
end
247+
248+
# Test 4: Edge cases
249+
# Empty query
250+
@test all(bm25(dtm, String[]; normalize = true) .== 0)
251+
252+
# Query with non-existent words
253+
@test all(bm25(dtm, ["nonexistent"]; normalize = true) .== 0)
254+
255+
# Multiple query terms
256+
multi_query = ["test", "document"]
257+
multi_scores = bm25(dtm, multi_query; normalize = true)
258+
@test all(0 .≤ multi_scores .≤ 1)
259+
# Document 1 should have highest score as it contains both terms
260+
@test argmax(multi_scores) == 1
261+
262+
# Test 5: Repeated terms in query
263+
repeated_query = ["test", "test", "test"]
264+
rep_scores = bm25(dtm, repeated_query; normalize = true)
265+
@test all(0 .≤ rep_scores .≤ 1)
266+
267+
# Test 6: Special cases - uniform document length
268+
uniform_docs = [["word", "test"] for _ in 1:3]
269+
uniform_dtm = document_term_matrix(uniform_docs)
270+
uniform_scores = bm25(uniform_dtm, ["test"]; normalize = true,
271+
normalize_max_tf = 1, normalize_min_doc_rel_length = 1.0f0)
272+
@test all(uniform_scores .≈ 1.0)
273+
274+
# Test 7: Verify normalization with different max_tf values
275+
high_tf_docs = [
276+
["test", "test", "test"], # tf = 3
277+
["test"], # tf = 1
278+
["other", "words"] # tf = 0
279+
]
280+
high_tf_dtm = document_term_matrix(high_tf_docs)
281+
282+
# With max_tf = 1 (matching actual tf in your dataset)
283+
scores_max1 = bm25(high_tf_dtm, ["test"]; normalize = true, normalize_max_tf = 1)
284+
# With max_tf = 3 (default)
285+
scores_max3 = bm25(high_tf_dtm, ["test"]; normalize = true, normalize_max_tf = 3)
286+
287+
# The first document should get a lower relative score with max_tf=3 (max will be higher!)
288+
@test scores_max3[1] < scores_max1[1]
187289
end
188290

189291
@testset "find_closest" begin

0 commit comments

Comments
 (0)