Skip to content

Commit

Permalink
Max BM25 score / BM25 normalization (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Feb 8, 2025
1 parent 9b783ec commit 04b70f7
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added new Gemini 2.0 Pro and Flash models to the model registry (aliases `gem20p` and `gem20f`, respectively). Added Gemini 2.0 Flash Lite Preview model (alias `gem20fl`) and Gemini 2.0 Flash Thinking Experimental model (alias `gem20ft`).
- Added BM25 normalization kwarg to `RAGTools.jl` to enable 0-1 query-specific normalization of BM25 scores for easier filtering and comparison. See `?RT.bm25` and `?RT.max_bm25_score` for more information.

### Fixed

Expand Down
91 changes: 89 additions & 2 deletions ext/RAGToolsExperimentalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,26 @@ function RT.document_term_matrix(documents::AbstractVector{<:AbstractString})
end

"""
RT.bm25(dtm::AbstractDocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
RT.bm25(
dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
normalize_min_doc_rel_length::Float32 = 0.5f0)
Scores all documents in `dtm` based on the `query`.
References: https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/
# Arguments
- `dtm`: A `DocumentTermMatrix` object.
- `query`: A vector of query tokens.
- `k1`: The k1 parameter for BM25.
- `b`: The b parameter for BM25.
- `normalize`: Whether to normalize the scores (returns scores between 0 and 1).
Theoretically, if you choose `normalize_max_tf` too low or `normalize_min_doc_rel_length` too high for your corpus, you could get scores greater than 1.
- `normalize_max_tf`: The maximum term frequency to normalize to. 3 is a good default (assumes max 3 hits per document).
- `normalize_min_doc_rel_length`: The minimum document relative length to normalize to. 0.5 is a good default.
Ideally, pick the minimum document relative length of the corpus that is non-zero
`min_doc_rel_length = minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
# Example
```
documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
Expand All @@ -198,14 +212,36 @@ query = ["this"]
scores = bm25(dtm, query)
# Returns array with 3 scores (one for each document)
```
Normalization is done by dividing the score by the maximum possible score (given some assumptions).
It's useful for getting results in the same range as cosine similarity scores and for comparing scores across different queries or documents.
# Example
```
documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm = document_term_matrix(documents)
query = ["this"]
scores = bm25(dtm, query)
scores_norm = bm25(dtm, query; normalize = true)
## Make it more accurate for your dataset/index
normalize_max_tf = 3 # assume max term frequency is 3 (what is likely for your dataset? depends on chunk size, preprocessing, etc.)
normalize_min_doc_rel_length = minimum([x for x in RT.doc_rel_length(dtm) if x > 0]) |> Float32
scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf, normalize_min_doc_rel_length)
```
"""
function RT.bm25(
dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
k1::Float32 = 1.2f0, b::Float32 = 0.75f0)
k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
normalize_min_doc_rel_length::Float32 = 0.5f0)
@assert normalize_max_tf>0 "normalize_max_tf term frequency must be positive (got $normalize_max_tf)"
@assert normalize_min_doc_rel_length>0 "normalize_min_doc_rel_length must be positive (got $normalize_min_doc_rel_length)"

scores = zeros(Float32, size(tf(dtm), 1))
## Identify non-zero items to leverage the sparsity
nz_rows = rowvals(tf(dtm))
nz_vals = nonzeros(tf(dtm))
max_score = 0.0f0
for i in eachindex(query)
t = query[i]
t_id = get(vocab_lookup(dtm), t, nothing)
Expand All @@ -222,9 +258,60 @@ function RT.bm25(
## @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score"
scores[di] += score
end
## Once per token, calculate max score
## assumes max term frequency is `normalize_max_tf` and min document relative length is `normalize_min_doc_rel_length`
if normalize
max_score += idf_ * (normalize_max_tf * (k1 + 1.0f0)) / (normalize_max_tf +
k1 * (1.0f0 - b + b * normalize_min_doc_rel_length))
end
end
if normalize && !iszero(max_score)
scores ./= max_score
elseif normalize && iszero(max_score)
## happens only with empty queries, so scores is zero anyway
@warn "BM25: `max_score` is zero, so scores are not normalized. Returning unnormalized scores (all zero)."
end

return scores
end

"""
    RT.max_bm25_score(
        dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
        min_doc_rel_length::Float32 = 0.5f0)

Returns the maximum BM25 score that can be achieved for a given query, assuming every query
token occurs `max_tf` times in a document whose relative length is `min_doc_rel_length`.

Good for normalizing BM25 scores.

# Arguments
- `dtm`: The document-term matrix; its vocabulary lookup and IDF values are used.
- `query_tokens`: A vector of query tokens. Tokens not found in the vocabulary are skipped
  (if none is found, the result is `0.0f0`).
- `k1`: The k1 parameter for BM25.
- `b`: The b parameter for BM25.
- `max_tf`: The assumed maximum term frequency of a query token in a document. Must be positive.
- `min_doc_rel_length`: The assumed minimum document relative length. Must be positive.
  Ideally, pick the minimum non-zero relative length in your corpus, eg,
  `minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`

# Example
```
max_score = max_bm25_score(RT.chunkdata(key_index), query_tokens)
```
"""
function RT.max_bm25_score(
        dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
        min_doc_rel_length::Float32 = 0.5f0)
    ## Same sanity checks as `RT.bm25` applies to its normalization kwargs
    @assert max_tf>0 "max_tf term frequency must be positive (got $max_tf)"
    @assert min_doc_rel_length>0 "min_doc_rel_length must be positive (got $min_doc_rel_length)"

    ## The tf component depends only on the kwargs, not on the token, so compute it once.
    ## Both inputs are pre-set in kwargs on purpose:
    ## - the true maximum tf per term, eg, `maximum(@view(RT.tf(dtm)[:, t_id]))`,
    ##   would be a bit extreme and slow
    ## - the true minimum relative length can be computed once by the caller, eg,
    ##   `minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
    tf_top = max_tf * (k1 + 1.0f0)
    tf_bottom = max_tf + k1 * (1.0f0 - b + b * min_doc_rel_length)

    max_score = 0.0f0
    for t in query_tokens
        t_id = get(RT.vocab_lookup(dtm), t, nothing)
        ## Out-of-vocabulary tokens cannot contribute to the score
        t_id === nothing && continue

        ## Maximum contribution of this token assuming a perfect match
        max_score += RT.idf(dtm)[t_id] * tf_top / tf_bottom
    end
    return max_score
end

end # end of module
1 change: 1 addition & 0 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ function find_closest(
return positions[new_positions], scores
end

# Function stub only; the method is implemented in `ext/RAGToolsExperimentalExt.jl`.
function max_bm25_score end
"""
find_closest(
finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,
Expand Down
104 changes: 103 additions & 1 deletion test/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ using PromptingTools.Experimental.RAGTools: find_closest, hamming_distance, find
using PromptingTools.Experimental.RAGTools: NoReranker, CohereReranker
using PromptingTools.Experimental.RAGTools: hamming_distance, BitPackedCosineSimilarity,
pack_bits, unpack_bits
using PromptingTools.Experimental.RAGTools: bm25, document_term_matrix, DocumentTermMatrix
using PromptingTools.Experimental.RAGTools: bm25, max_bm25_score, document_term_matrix,
DocumentTermMatrix

@testset "rephrase" begin
# Test rephrase with NoRephraser, simple passthrough
Expand Down Expand Up @@ -184,6 +185,107 @@ end
@test scores[1] ≈ expected * 4
@test scores[2] ≈ expected * 4
@test scores[3] ≈ 0

## BM25 normalization
# Basic test corpus
documents = [
["this", "is", "a", "test", "document"],
["this", "is", "another", "test"],
["completely", "different", "content"],
["test", "test", "test", "test"], # document with repeated terms
["single"] # shortest document
]
dtm = document_term_matrix(documents)

# Test 1: Basic normalization - scores should be between 0 and 1
query = ["test"]
rel_len = RT.doc_rel_length(dtm)
scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf = 3,
normalize_min_doc_rel_length = minimum(rel_len))
@test all(0 .≤ scores_norm .≤ 1)

# Test that document with most "test" occurrences gets highest score
@test argmax(scores_norm) == 4

# Test 2: Compare with manual normalization
scores_raw = bm25(dtm, query; normalize = false)
max_score = max_bm25_score(
dtm, query; max_tf = 3, min_doc_rel_length = minimum(rel_len))
scores_manual_norm = scores_raw ./ max_score
@test scores_norm ≈ scores_manual_norm

# Test 3: Parameter variations
params = [
(k1 = 1.2f0, b = 0.75f0, max_tf = 3, min_doc_len = 0.5f0),
(k1 = 2.0f0, b = 0.5f0, max_tf = 10, min_doc_len = 1.0f0)
]

for p in params
scores = bm25(dtm, query;
normalize = true,
k1 = p.k1,
b = p.b,
normalize_max_tf = p.max_tf,
normalize_min_doc_rel_length = p.min_doc_len
)
@test all(0 .≤ scores .≤ 1)

# Verify against max_bm25_score
max_theoretical = max_bm25_score(dtm, query;
k1 = p.k1,
b = p.b,
max_tf = p.max_tf,
min_doc_rel_length = p.min_doc_len
)
scores_raw = bm25(dtm, query;
normalize = false,
k1 = p.k1,
b = p.b
)
@test maximum(scores_raw) ≤ max_theoretical
end

# Test 4: Edge cases
# Empty query
@test all(bm25(dtm, String[]; normalize = true) .== 0)

# Query with non-existent words
@test all(bm25(dtm, ["nonexistent"]; normalize = true) .== 0)

# Multiple query terms
multi_query = ["test", "document"]
multi_scores = bm25(dtm, multi_query; normalize = true)
@test all(0 .≤ multi_scores .≤ 1)
# Document 1 should have highest score as it contains both terms
@test argmax(multi_scores) == 1

# Test 5: Repeated terms in query
repeated_query = ["test", "test", "test"]
rep_scores = bm25(dtm, repeated_query; normalize = true)
@test all(0 .≤ rep_scores .≤ 1)

# Test 6: Special cases - uniform document length
uniform_docs = [["word", "test"] for _ in 1:3]
uniform_dtm = document_term_matrix(uniform_docs)
uniform_scores = bm25(uniform_dtm, ["test"]; normalize = true,
normalize_max_tf = 1, normalize_min_doc_rel_length = 1.0f0)
@test all(uniform_scores .≈ 1.0)

# Test 7: Verify normalization with different max_tf values
high_tf_docs = [
["test", "test", "test"], # tf = 3
["test"], # tf = 1
["other", "words"] # tf = 0
]
high_tf_dtm = document_term_matrix(high_tf_docs)

# With max_tf = 1 (matching actual tf in your dataset)
scores_max1 = bm25(high_tf_dtm, ["test"]; normalize = true, normalize_max_tf = 1)
# With max_tf = 3 (default)
scores_max3 = bm25(high_tf_dtm, ["test"]; normalize = true, normalize_max_tf = 3)

# The first document should get a lower relative score with max_tf=3 (max will be higher!)
@test scores_max3[1] < scores_max1[1]
end

@testset "find_closest" begin
Expand Down

0 comments on commit 04b70f7

Please sign in to comment.