From 04b70f754173f2f29d3e576e88c2140fae6e2765 Mon Sep 17 00:00:00 2001
From: J S <49557684+svilupp@users.noreply.github.com>
Date: Sat, 8 Feb 2025 11:14:11 +0000
Subject: [PATCH] Max BM25 score / BM25 normalization (#267)

---
 CHANGELOG.md                            |   1 +
 ext/RAGToolsExperimentalExt.jl          |  91 ++++++++++++++++++++-
 src/Experimental/RAGTools/retrieval.jl  |   1 +
 test/Experimental/RAGTools/retrieval.jl | 104 +++++++++++++++++++++++-
 4 files changed, 194 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ce8e3af85..1279cced2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added new Gemini 2.0 Pro and Flash models to the model registry (aliases `gem20p` and `gem20f`, respectively). Added Gemini 2.0 Flash Lite Preview model (alias `gem20fl`) and Gemini 2.0 Flash Thinking Experimental model (alias `gem20ft`).
+- Added BM25 normalization kwarg to `RAGTools.jl` to enable 0-1 query-specific normalization of BM25 scores for easier filtering and comparison. See `?RT.bm25` and `?RT.max_bm25_score` for more information.
 
 ### Fixed

diff --git a/ext/RAGToolsExperimentalExt.jl b/ext/RAGToolsExperimentalExt.jl
index 681e879c8..4a9249e66 100644
--- a/ext/RAGToolsExperimentalExt.jl
+++ b/ext/RAGToolsExperimentalExt.jl
@@ -184,12 +184,26 @@ function RT.document_term_matrix(documents::AbstractVector{<:AbstractString})
 end
 
 """
-    RT.bm25(dtm::AbstractDocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
+    RT.bm25(
+        dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
+        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
+        normalize_min_doc_rel_length::Float32 = 0.5f0)
 
 Scores all documents in `dtm` based on the `query`.
 
 References: https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/
 
+# Arguments
+- `dtm`: A `DocumentTermMatrix` object.
+- `query`: A vector of query tokens.
+- `k1`: The k1 parameter for BM25.
+- `b`: The b parameter for BM25.
+- `normalize`: Whether to normalize the scores (returns scores between 0 and 1).
+  Theoretically, if you choose `normalize_max_tf` or `normalize_min_doc_rel_length` too low, you could still get scores greater than 1.
+- `normalize_max_tf`: The maximum term frequency assumed for normalization. 3 is a good default (assumes at most 3 hits of a term per document).
+- `normalize_min_doc_rel_length`: The minimum relative document length assumed for normalization. 0.5 is a good default.
+  Ideally, pick the smallest non-zero relative document length in the corpus:
+  `min_doc_rel_length = minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
+
 # Example
 ```
 documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
 dtm = document_term_matrix(documents)
@@ -198,14 +212,36 @@ query = ["this"]
 scores = bm25(dtm, query)
 # Returns array with 3 scores (one for each document)
 ```
+
+Normalization is done by dividing each score by the maximum possible score for the query (given some assumptions).
+It's useful for getting results in the same range as cosine similarity scores and for comparing scores across different queries or documents.
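+Specifically, each matched query token can contribute at most
+`idf * (normalize_max_tf * (k1 + 1)) / (normalize_max_tf + k1 * (1 - b + b * normalize_min_doc_rel_length))`
+to a document's score, so the raw scores are divided by the sum of these per-token maxima (computed by `RT.max_bm25_score`).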
+
+# Example
+```
+documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
+dtm = document_term_matrix(documents)
+query = ["this"]
+scores = bm25(dtm, query)
+scores_norm = bm25(dtm, query; normalize = true)
+
+## Make the normalization more accurate for your dataset/index
+normalize_max_tf = 3 # the likely maximum term frequency for your dataset (depends on chunk size, preprocessing, etc.)
+normalize_min_doc_rel_length = minimum([x for x in RT.doc_rel_length(dtm) if x > 0]) |> Float32
+scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf, normalize_min_doc_rel_length)
+```
 """
 function RT.bm25(
         dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
-        k1::Float32 = 1.2f0, b::Float32 = 0.75f0)
+        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
+        normalize_min_doc_rel_length::Float32 = 0.5f0)
+    @assert normalize_max_tf>0 "normalize_max_tf must be positive (got $normalize_max_tf)"
+    @assert normalize_min_doc_rel_length>0 "normalize_min_doc_rel_length must be positive (got $normalize_min_doc_rel_length)"
     scores = zeros(Float32, size(tf(dtm), 1))
     ## Identify non-zero items to leverage the sparsity
     nz_rows = rowvals(tf(dtm))
    nz_vals = nonzeros(tf(dtm))
+    max_score = 0.0f0
     for i in eachindex(query)
         t = query[i]
         t_id = get(vocab_lookup(dtm), t, nothing)
@@ -222,9 +258,60 @@ function RT.bm25(
             ## @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score"
             scores[di] += score
         end
+        ## Once per token, calculate the token's maximum possible contribution to the score;
+        ## assumes max term frequency is `normalize_max_tf` and min document relative length is `normalize_min_doc_rel_length`
+        if normalize
+            max_score += idf_ * (normalize_max_tf * (k1 + 1.0f0)) / (normalize_max_tf +
+                          k1 * (1.0f0 - b + b * normalize_min_doc_rel_length))
+        end
     end
+    if normalize && !iszero(max_score)
+        scores ./= max_score
+    elseif normalize && iszero(max_score)
+        ## happens only with empty queries, so scores are all zero anyway
+        @warn "BM25: `max_score` is zero, so scores are not normalized. Returning unnormalized scores (all zero)."
+    end
     return scores
 end
 
+"""
+    RT.max_bm25_score(
+        dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
+        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
+        min_doc_rel_length::Float32 = 0.5f0)
+
+Returns the maximum BM25 score achievable for a given query, assuming each query token occurs `max_tf` times in a document with the smallest relative length `min_doc_rel_length`.
+Good for normalizing BM25 scores.
+
+# Example
+```
+max_score = max_bm25_score(RT.chunkdata(key_index), query_tokens)
+```
+"""
+function RT.max_bm25_score(
+        dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
+        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
+        min_doc_rel_length::Float32 = 0.5f0)
+    max_score = 0.0f0
+    @inbounds for t in query_tokens
+        t_id = get(RT.vocab_lookup(dtm), t, nothing)
+        t_id === nothing && continue
+
+        idf_ = RT.idf(dtm)[t_id]
+
+        # Maximum tf (term frequency) for this term in any document -- pre-set in kwargs!
+        # eg, `max_tf = maximum(@view(RT.tf(dtm)[:, t_id]))` but that would be a bit extreme and slow
+
+        # Smallest non-zero relative document length -- pre-set in kwargs!
+        # eg, `min_doc_rel_length = minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
+
+        # Maximum tf component assuming a perfect match
+        tf_top = (max_tf * (k1 + 1.0f0))
+        tf_bottom = (max_tf + k1 * (1.0f0 - b + b * min_doc_rel_length))
+        max_score += idf_ * tf_top / tf_bottom
+    end
+    return max_score
+end
+
 end # end of module
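To see how the pieces fit together, here is a minimal usage sketch (not part of the patch). It assumes the RAGTools extension above is active, i.e. that its trigger packages are loaded (e.g. `LinearAlgebra`, `SparseArrays`, `Unicode`):

```julia
using LinearAlgebra, SparseArrays, Unicode # assumed extension triggers
using PromptingTools
const RT = PromptingTools.Experimental.RAGTools

documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm = RT.document_term_matrix(documents)
query = ["this"]

scores_raw = RT.bm25(dtm, query)                    # unbounded BM25 scores
scores_norm = RT.bm25(dtm, query; normalize = true) # 0-1 range under the default assumptions

## The normalization constant is exactly the theoretical maximum:
max_score = RT.max_bm25_score(dtm, query) # defaults: max_tf = 3, min_doc_rel_length = 0.5f0
scores_norm ≈ scores_raw ./ max_score     # true
```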
diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl
index 343b67aa3..f50fc3ed4 100644
--- a/src/Experimental/RAGTools/retrieval.jl
+++ b/src/Experimental/RAGTools/retrieval.jl
@@ -452,6 +452,7 @@ function find_closest(
     return positions[new_positions], scores
 end
 
+function max_bm25_score end
 """
     find_closest(
         finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,
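Aside: `function max_bm25_score end` declares a generic function with zero methods, so the package extension above can attach the actual implementation once its weak dependencies are loaded. A hypothetical, self-contained illustration of this stub pattern (all names made up):

```julia
module MyPkg
function my_score end          # method-less stub, mirroring `function max_bm25_score end`
end

module MyPkgExt
import ..MyPkg                 # a real package extension would `import MyPkg`
MyPkg.my_score(x::Real) = 2x   # the extension supplies the method
end

MyPkg.my_score(3)              # 6 -- callable once the "extension" is loaded
```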
diff --git a/test/Experimental/RAGTools/retrieval.jl b/test/Experimental/RAGTools/retrieval.jl
index 45bbb2a04..5202c2ff6 100644
--- a/test/Experimental/RAGTools/retrieval.jl
+++ b/test/Experimental/RAGTools/retrieval.jl
@@ -16,7 +16,8 @@ using PromptingTools.Experimental.RAGTools: find_closest, hamming_distance, find
 using PromptingTools.Experimental.RAGTools: NoReranker, CohereReranker
 using PromptingTools.Experimental.RAGTools: hamming_distance, BitPackedCosineSimilarity,
                                             pack_bits, unpack_bits
-using PromptingTools.Experimental.RAGTools: bm25, document_term_matrix, DocumentTermMatrix
+using PromptingTools.Experimental.RAGTools: bm25, max_bm25_score, document_term_matrix,
+                                            DocumentTermMatrix
 
 @testset "rephrase" begin
     # Test rephrase with NoRephraser, simple passthrough
@@ -184,6 +185,107 @@ end
     @test scores[1] ≈ expected * 4
     @test scores[2] ≈ expected * 4
     @test scores[3] ≈ 0
+
+    ## BM25 normalization
+    # Basic test corpus
+    documents = [
+        ["this", "is", "a", "test", "document"],
+        ["this", "is", "another", "test"],
+        ["completely", "different", "content"],
+        ["test", "test", "test", "test"], # document with repeated terms
+        ["single"] # shortest document
+    ]
+    dtm = document_term_matrix(documents)
+
+    # Test 1: Basic normalization - scores should be between 0 and 1
+    query = ["test"]
+    rel_len = RT.doc_rel_length(dtm)
+    scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf = 3,
+        normalize_min_doc_rel_length = minimum(rel_len))
+    @test all(0 .≤ scores_norm .≤ 1)
+
+    # The document with the most "test" occurrences should get the highest score
+    @test argmax(scores_norm) == 4
+
+    # Test 2: Compare with manual normalization
+    scores_raw = bm25(dtm, query; normalize = false)
+    max_score = max_bm25_score(
+        dtm, query; max_tf = 3, min_doc_rel_length = minimum(rel_len))
+    scores_manual_norm = scores_raw ./ max_score
+    @test scores_norm ≈ scores_manual_norm
+
+    # Test 3: Parameter variations
+    params = [
+        (k1 = 1.2f0, b = 0.75f0, max_tf = 3, min_doc_len = 0.5f0),
+        (k1 = 2.0f0, b = 0.5f0, max_tf = 10, min_doc_len = 1.0f0)
+    ]
+
+    for p in params
+        scores = bm25(dtm, query;
+            normalize = true,
+            k1 = p.k1,
+            b = p.b,
+            normalize_max_tf = p.max_tf,
+            normalize_min_doc_rel_length = p.min_doc_len
+        )
+        @test all(0 .≤ scores .≤ 1)
+
+        # Verify against max_bm25_score
+        max_theoretical = max_bm25_score(dtm, query;
+            k1 = p.k1,
+            b = p.b,
+            max_tf = p.max_tf,
+            min_doc_rel_length = p.min_doc_len
+        )
+        scores_raw = bm25(dtm, query;
+            normalize = false,
+            k1 = p.k1,
+            b = p.b
+        )
+        @test maximum(scores_raw) ≤ max_theoretical
+    end
+
+    # Test 4: Edge cases
+    # Empty query
+    @test all(bm25(dtm, String[]; normalize = true) .== 0)
+
+    # Query with non-existent words
+    @test all(bm25(dtm, ["nonexistent"]; normalize = true) .== 0)
+
+    # Multiple query terms
+    multi_query = ["test", "document"]
+    multi_scores = bm25(dtm, multi_query; normalize = true)
+    @test all(0 .≤ multi_scores .≤ 1)
+    # Document 1 should have the highest score as it contains both terms
+    @test argmax(multi_scores) == 1
+
+    # Test 5: Repeated terms in query
+    repeated_query = ["test", "test", "test"]
+    rep_scores = bm25(dtm, repeated_query; normalize = true)
+    @test all(0 .≤ rep_scores .≤ 1)
+
+    # Test 6: Special cases - uniform document length
+    uniform_docs = [["word", "test"] for _ in 1:3]
+    uniform_dtm = document_term_matrix(uniform_docs)
+    uniform_scores = bm25(uniform_dtm, ["test"]; normalize = true,
+        normalize_max_tf = 1, normalize_min_doc_rel_length = 1.0f0)
+    @test all(uniform_scores .≈ 1.0)
+
+    # Test 7: Verify normalization with different max_tf values
+    high_tf_docs = [
+        ["test", "test", "test"], # tf = 3
+        ["test"], # tf = 1
+        ["other", "words"] # tf = 0
+    ]
+    high_tf_dtm = document_term_matrix(high_tf_docs)
+
+    # With max_tf = 1 (matching the actual max term frequency in this corpus)
+    scores_max1 = bm25(high_tf_dtm, ["test"]; normalize = true, normalize_max_tf = 1)
+    # With max_tf = 3 (the default)
+    scores_max3 = bm25(high_tf_dtm, ["test"]; normalize = true, normalize_max_tf = 3)
+
+    # The first document should get a lower normalized score with max_tf = 3 (the normalization constant is larger)
+    @test scores_max3[1] < scores_max1[1]
 end
 
 @testset "find_closest" begin