@@ -184,12 +184,26 @@ function RT.document_term_matrix(documents::AbstractVector{<:AbstractString})
184184end
185185
186186"""
187- RT.bm25(dtm::AbstractDocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
187+ RT.bm25(
188+ dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
189+ k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
190+ normalize_min_doc_rel_length::Float32 = 0.5f0)
188191
189192Scores all documents in `dtm` based on the `query`.
190193
191194References: https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/
192195
196+ # Arguments
197+ - `dtm`: A `DocumentTermMatrix` object.
198+ - `query`: A vector of query tokens.
199+ - `k1`: The k1 parameter for BM25.
200+ - `b`: The b parameter for BM25.
201+ - `normalize`: Whether to normalize the scores (returns scores between 0 and 1).
202+ Theoretically, if you choose `normalize_max_tf` and `normalize_min_doc_rel_length` to be too low, you could get scores greater than 1.
203+ - `normalize_max_tf`: The maximum term frequency to normalize to. 3 is a good default (assumes max 3 hits per document).
204+ - `normalize_min_doc_rel_length`: The minimum document relative length to normalize to. 0.5 is a good default.
205+ Ideally, pick the minimum document relative length of the corpus that is non-zero
206+ `min_doc_rel_length = minimum(x for x in RT.doc_rel_length(RT.chunkdata(key_index)) if x > 0) |> Float32`
193207# Example
194208```
195209documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
@@ -198,14 +212,36 @@ query = ["this"]
198212scores = bm25(dtm, query)
199213# Returns array with 3 scores (one for each document)
200214```
215+
216+ Normalization is done by dividing the score by the maximum possible score (given some assumptions).
218+ It's useful to get results in the same range as cosine similarity scores and when comparing different queries or documents.
218+
219+ # Example
220+ ```
221+ documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
222+ dtm = document_term_matrix(documents)
223+ query = ["this"]
224+ scores = bm25(dtm, query)
225+ scores_norm = bm25(dtm, query; normalize = true)
226+
227+ ## Make it more accurate for your dataset/index
228+ normalize_max_tf = 3 # assume max term frequency is 3 (what is likely for your dataset? depends on chunk size, preprocessing, etc.)
229+ normalize_min_doc_rel_length = minimum([x for x in RT.doc_rel_length(dtm) if x > 0]) |> Float32
230+ scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf, normalize_min_doc_rel_length)
231+ ```
201232"""
202233function RT. bm25 (
203234 dtm:: RT.AbstractDocumentTermMatrix , query:: AbstractVector{<:AbstractString} ;
204- k1:: Float32 = 1.2f0 , b:: Float32 = 0.75f0 )
235+ k1:: Float32 = 1.2f0 , b:: Float32 = 0.75f0 , normalize:: Bool = false , normalize_max_tf:: Real = 3 ,
236+ normalize_min_doc_rel_length:: Float32 = 0.5f0 )
237+ @assert normalize_max_tf> 0 " normalize_max_tf term frequency must be positive (got $normalize_max_tf )"
238+ @assert normalize_min_doc_rel_length> 0 " normalize_min_doc_rel_length must be positive (got $normalize_min_doc_rel_length )"
239+
205240 scores = zeros (Float32, size (tf (dtm), 1 ))
206241 # # Identify non-zero items to leverage the sparsity
207242 nz_rows = rowvals (tf (dtm))
208243 nz_vals = nonzeros (tf (dtm))
244+ max_score = 0.0f0
209245 for i in eachindex (query)
210246 t = query[i]
211247 t_id = get (vocab_lookup (dtm), t, nothing )
@@ -222,9 +258,60 @@ function RT.bm25(
222258 # # @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score"
223259 scores[di] += score
224260 end
261+ # # Once per token, calculate max score
262+ # # assumes max term frequency is `normalize_max_tf` and min document relative length is `normalize_min_doc_rel_length`
263+ if normalize
264+ max_score += idf_ * (normalize_max_tf * (k1 + 1.0f0 )) / (normalize_max_tf +
265+ k1 * (1.0f0 - b + b * normalize_min_doc_rel_length))
266+ end
267+ end
268+ if normalize && ! iszero (max_score)
269+ scores ./= max_score
270+ elseif normalize && iszero (max_score)
271+ # # happens only with empty queries, so scores is zero anyway
272+ @warn " BM25: `max_score` is zero, so scores are not normalized. Returning unnormalized scores (all zero)."
225273 end
226274
227275 return scores
228276end
229277
"""
    RT.max_bm25_score(
        dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
        min_doc_rel_length::Float32 = 0.5f0)

Returns the maximum BM25 score that can be achieved for a given query, assuming every query token
found in the vocabulary occurs `max_tf` times in a document whose relative length is `min_doc_rel_length`.
Good for normalizing BM25 scores.

# Arguments
- `dtm`: The document-term matrix; its vocabulary lookup and IDF values are read.
- `query_tokens`: A vector of query tokens. Tokens missing from the vocabulary are skipped;
  a token that appears several times contributes once per occurrence.
- `k1`: The k1 parameter for BM25.
- `b`: The b parameter for BM25.
- `max_tf`: The assumed maximum term frequency of a token in a single document. Must be positive.
- `min_doc_rel_length`: The assumed minimum relative document length. Must be positive.
  Ideally, pick the smallest non-zero relative document length of your corpus, eg,
  `minimum(x for x in RT.doc_rel_length(dtm) if x > 0) |> Float32`

# Returns
The sum of the per-token upper bounds. It is `0.0f0` when no query token is in the vocabulary
(including an empty query).

# Throws
- `ArgumentError` if `max_tf` or `min_doc_rel_length` is not positive.

# Example
```
max_score = max_bm25_score(RT.chunkdata(key_index), query_tokens)
```
"""
function RT.max_bm25_score(
        dtm::RT.AbstractDocumentTermMatrix, query_tokens::AbstractVector{<:AbstractString};
        k1::Float32 = 1.2f0, b::Float32 = 0.75f0, max_tf::Real = 3,
        min_doc_rel_length::Float32 = 0.5f0)
    ## Validate explicitly (not via `@assert`, which is meant for internal invariants):
    ## non-positive values can make the denominator zero/negative and the "maximum" meaningless
    max_tf > 0 ||
        throw(ArgumentError("max_tf must be positive (got $max_tf)"))
    min_doc_rel_length > 0 ||
        throw(ArgumentError("min_doc_rel_length must be positive (got $min_doc_rel_length)"))

    ## The tf saturation term does not depend on the token, so compute it once.
    ## `max_tf` and `min_doc_rel_length` are deliberately taken from kwargs: scanning the matrix
    ## for the true per-term maximum (eg, `maximum(@view(RT.tf(dtm)[:, t_id]))`) would be slow
    tf_top = max_tf * (k1 + 1.0f0)
    tf_bottom = max_tf + k1 * (1.0f0 - b + b * min_doc_rel_length)

    lookup = RT.vocab_lookup(dtm)
    idf_values = RT.idf(dtm)

    max_score = 0.0f0
    for t in query_tokens
        t_id = get(lookup, t, nothing)
        ## Out-of-vocabulary tokens cannot match any document, so they add nothing
        t_id === nothing && continue
        ## Same evaluation order as the per-document score: (idf * numerator) / denominator
        max_score += idf_values[t_id] * tf_top / tf_bottom
    end
    return max_score
end
316+
230317end # end of module
0 commit comments