From f9ae295507f7cb4a879cdc316757de9b2b8cbdbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=BD=E5=8F=B6=E4=B9=8C=E9=BE=9F?= Date: Mon, 1 Jul 2024 15:39:41 +0800 Subject: [PATCH] feat(query): Make `BooleanQuery` supports `minimum_number_should_match` (#2405) * feat(query): Make `BooleanQuery` supports `minimum_number_should_match`. see issue #2398 In this commit, a novel scorer named DisjunctionScorer is introduced, which performs the union of inverted chains with the minimal required elements. BTW, it's implemented via a min-heap. Necessary modifications on `BooleanQuery` and `BooleanWeight` are performed as well. * fixup! fix test * fixup!: refactor code. 1. More meaningful names. 2. Add Cache for `Disjunction`'s scorers, and fix bug. 3. Optimize `BooleanWeight::complex_scorer` Thanks Paul Masurel * squash!: come up with better variable naming. * squash!: fix naming issues. * squash!: fix typo. * squash!: Remove CombinationMethod::FullIntersection --- src/query/boolean_query/boolean_query.rs | 161 ++++++++++- src/query/boolean_query/boolean_weight.rs | 134 +++++++-- src/query/disjunction.rs | 327 ++++++++++++++++++++++ src/query/mod.rs | 1 + src/query/query_parser/query_parser.rs | 9 +- 5 files changed, 590 insertions(+), 42 deletions(-) create mode 100644 src/query/disjunction.rs diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index 94ac5b0c5d..1c39d3d50e 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -66,6 +66,10 @@ use crate::schema::{IndexRecordOption, Term}; /// Term::from_field_text(title, "diary"), /// IndexRecordOption::Basic, /// )); +/// let cow_term_query: Box = Box::new(TermQuery::new( +/// Term::from_field_text(title, "cow"), +/// IndexRecordOption::Basic +/// )); /// // A TermQuery with "found" in the body /// let body_term_query: Box = Box::new(TermQuery::new( /// Term::from_field_text(body, "found"), @@ -74,7 +78,7 @@ use crate::schema::{IndexRecordOption, Term}; /// // TermQuery "diary" must and "girl" must not be present /// let queries_with_occurs1 = vec![ /// (Occur::Must, diary_term_query.box_clone()), -/// (Occur::MustNot, girl_term_query), +/// (Occur::MustNot, girl_term_query.box_clone()), /// ]; /// // Make a BooleanQuery equivalent to /// // title:+diary title:-girl @@ -82,15 +86,10 @@ use crate::schema::{IndexRecordOption, Term}; /// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?; /// assert_eq!(count1, 1); /// -/// // TermQuery for "cow" in the title -/// let cow_term_query: Box = Box::new(TermQuery::new( -/// Term::from_field_text(title, "cow"), -/// IndexRecordOption::Basic, -/// )); /// // "title:diary OR title:cow" /// let title_diary_or_cow = BooleanQuery::new(vec![ /// (Occur::Should, diary_term_query.box_clone()), -/// (Occur::Should, cow_term_query), +/// (Occur::Should, cow_term_query.box_clone()), /// ]); /// let count2 = searcher.search(&title_diary_or_cow, &Count)?; /// assert_eq!(count2, 4); @@ -118,21 +117,39 @@ use crate::schema::{IndexRecordOption, Term}; /// ]); /// let count4 = searcher.search(&nested_query, &Count)?; /// assert_eq!(count4, 1); +/// +/// // You may call `with_minimum_required_clauses` to +/// // specify the number of should clauses the returned documents must match. +/// let minimum_required_query = BooleanQuery::with_minimum_required_clauses(vec![ +/// (Occur::Should, cow_term_query.box_clone()), +/// (Occur::Should, girl_term_query.box_clone()), +/// (Occur::Should, diary_term_query.box_clone()), +/// ], 2); +/// // Return documents contains "Diary Cow", "Diary Girl" or "Cow Girl" +/// // Notice: "Diary" isn't "Dairy". ;-) +/// let count5 = searcher.search(&minimum_required_query, &Count)?; +/// assert_eq!(count5, 1); /// Ok(()) /// } /// ``` #[derive(Debug)] pub struct BooleanQuery { subqueries: Vec<(Occur, Box)>, + minimum_number_should_match: usize, } impl Clone for BooleanQuery { fn clone(&self) -> Self { - self.subqueries + let subqueries = self + .subqueries .iter() .map(|(occur, subquery)| (*occur, subquery.box_clone())) .collect::>() - .into() + .into(); + Self { + subqueries, + minimum_number_should_match: self.minimum_number_should_match, + } } } @@ -149,8 +166,9 @@ impl Query for BooleanQuery { .iter() .map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?))) .collect::>()?; - Ok(Box::new(BooleanWeight::new( + Ok(Box::new(BooleanWeight::with_minimum_number_should_match( sub_weights, + self.minimum_number_should_match, enable_scoring.is_scoring_enabled(), Box::new(SumWithCoordsCombiner::default), ))) @@ -166,7 +184,41 @@ impl Query for BooleanQuery { impl BooleanQuery { /// Creates a new boolean query. pub fn new(subqueries: Vec<(Occur, Box)>) -> BooleanQuery { - BooleanQuery { subqueries } + // If the bool query includes at least one should clause + // and no Must or MustNot clauses, the default value is 1. Otherwise, the default value is + // 0. Keep pace with Elasticsearch. + let mut minimum_required = 0; + for (occur, _) in &subqueries { + match occur { + Occur::Should => minimum_required = 1, + Occur::Must | Occur::MustNot => { + minimum_required = 0; + break; + } + } + } + Self::with_minimum_required_clauses(subqueries, minimum_required) + } + + /// Create a new boolean query with minimum number of required should clauses specified. + pub fn with_minimum_required_clauses( + subqueries: Vec<(Occur, Box)>, + minimum_number_should_match: usize, + ) -> BooleanQuery { + BooleanQuery { + subqueries, + minimum_number_should_match, + } + } + + /// Getter for `minimum_number_should_match` + pub fn get_minimum_number_should_match(&self) -> usize { + self.minimum_number_should_match + } + + /// Setter for `minimum_number_should_match` + pub fn set_minimum_number_should_match(&mut self, minimum_number_should_match: usize) { + self.minimum_number_should_match = minimum_number_should_match; } /// Returns the intersection of the queries. @@ -181,6 +233,18 @@ impl BooleanQuery { BooleanQuery::new(subqueries) } + /// Returns the union of the queries with minimum required clause. + pub fn union_with_minimum_required_clauses( + queries: Vec>, + minimum_required_clauses: usize, + ) -> BooleanQuery { + let subqueries = queries + .into_iter() + .map(|sub_query| (Occur::Should, sub_query)) + .collect(); + BooleanQuery::with_minimum_required_clauses(subqueries, minimum_required_clauses) + } + /// Helper method to create a boolean query matching a given list of terms. /// The resulting query is a disjunction of the terms. pub fn new_multiterms_query(terms: Vec) -> BooleanQuery { @@ -203,11 +267,13 @@ impl BooleanQuery { #[cfg(test)] mod tests { + use std::collections::HashSet; + use super::BooleanQuery; use crate::collector::{Count, DocSetCollector}; - use crate::query::{QueryClone, QueryParser, TermQuery}; - use crate::schema::{IndexRecordOption, Schema, TEXT}; - use crate::{DocAddress, Index, Term}; + use crate::query::{Query, QueryClone, QueryParser, TermQuery}; + use crate::schema::{Field, IndexRecordOption, Schema, TEXT}; + use crate::{DocAddress, DocId, Index, Term}; fn create_test_index() -> crate::Result { let mut schema_builder = Schema::builder(); @@ -223,6 +289,73 @@ mod tests { Ok(index) } + #[test] + fn test_minimum_required() -> crate::Result<()> { + fn create_test_index_with>( + docs: T, + ) -> crate::Result { + let mut schema_builder = Schema::builder(); + let text = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests()?; + for doc in docs { + writer.add_document(doc!(text => doc))?; + } + writer.commit()?; + Ok(index) + } + fn create_boolean_query_with_mr>( + queries: T, + field: Field, + mr: usize, + ) -> BooleanQuery { + let terms = queries + .into_iter() + .map(|t| Term::from_field_text(field, t)) + .map(|t| TermQuery::new(t, IndexRecordOption::Basic)) + .map(|q| -> Box { Box::new(q) }) + .collect(); + BooleanQuery::union_with_minimum_required_clauses(terms, mr) + } + fn check_doc_id>( + expected: T, + actually: HashSet, + seg: u32, + ) { + assert_eq!( + actually, + expected + .into_iter() + .map(|id| DocAddress::new(seg, id)) + .collect() + ); + } + let index = create_test_index_with(["a b c", "a c e", "d f g", "z z z", "c i b"])?; + let searcher = index.reader()?.searcher(); + let text = index.schema().get_field("text").unwrap(); + // Documents contains 'a c' 'a z' 'a i' 'c z' 'c i' or 'z i' shall be return. + let q1 = create_boolean_query_with_mr(["a", "c", "z", "i"], text, 2); + let docs = searcher.search(&q1, &DocSetCollector)?; + check_doc_id([0, 1, 4], docs, 0); + // Documents contains 'a b c', 'a b e', 'a c e' or 'b c e' shall be return. + let q2 = create_boolean_query_with_mr(["a", "b", "c", "e"], text, 3); + let docs = searcher.search(&q2, &DocSetCollector)?; + check_doc_id([0, 1], docs, 0); + // Nothing queried since minimum_required is too large. + let q3 = create_boolean_query_with_mr(["a", "b"], text, 3); + let docs = searcher.search(&q3, &DocSetCollector)?; + assert!(docs.is_empty()); + // When mr is set to zero or one, there are no difference with `Boolean::Union`. + let q4 = create_boolean_query_with_mr(["a", "z"], text, 1); + let docs = searcher.search(&q4, &DocSetCollector)?; + check_doc_id([0, 1, 3], docs, 0); + let q5 = create_boolean_query_with_mr(["a", "b"], text, 0); + let docs = searcher.search(&q5, &DocSetCollector)?; + check_doc_id([0, 1, 4], docs, 0); + Ok(()) + } + #[test] fn test_union() -> crate::Result<()> { let index = create_test_index()?; diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index ece6217d22..77f847063c 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::index::SegmentReader; use crate::postings::FreqReadingOption; +use crate::query::disjunction::Disjunction; use crate::query::explanation::does_not_match; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::term_query::TermScorer; @@ -18,6 +19,26 @@ enum SpecializedScorer { Other(Box), } +fn scorer_disjunction( + scorers: Vec>, + score_combiner: TScoreCombiner, + minimum_match_required: usize, +) -> Box +where + TScoreCombiner: ScoreCombiner, +{ + debug_assert!(!scorers.is_empty()); + debug_assert!(minimum_match_required > 1); + if scorers.len() == 1 { + return scorers.into_iter().next().unwrap(); // Safe unwrap. + } + Box::new(Disjunction::new( + scorers, + score_combiner, + minimum_match_required, + )) +} + fn scorer_union( scorers: Vec>, score_combiner_fn: impl Fn() -> TScoreCombiner, @@ -70,6 +91,7 @@ fn into_box_scorer( /// Weight associated to the `BoolQuery`. pub struct BooleanWeight { weights: Vec<(Occur, Box)>, + minimum_number_should_match: usize, scoring_enabled: bool, score_combiner_fn: Box TScoreCombiner + Sync + Send>, } @@ -85,6 +107,22 @@ impl BooleanWeight { weights, scoring_enabled, score_combiner_fn, + minimum_number_should_match: 1, + } + } + + /// Create a new boolean weight with minimum number of required should clauses specified. + pub fn with_minimum_number_should_match( + weights: Vec<(Occur, Box)>, + minimum_number_should_match: usize, + scoring_enabled: bool, + score_combiner_fn: Box TScoreCombiner + Sync + Send + 'static>, + ) -> BooleanWeight { + BooleanWeight { + weights, + minimum_number_should_match, + scoring_enabled, + score_combiner_fn, } } @@ -111,43 +149,89 @@ impl BooleanWeight { score_combiner_fn: impl Fn() -> TComplexScoreCombiner, ) -> crate::Result { let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; - - let should_scorer_opt: Option = per_occur_scorers - .remove(&Occur::Should) - .map(|scorers| scorer_union(scorers, &score_combiner_fn)); + // Indicate how should clauses are combined with other clauses. + enum CombinationMethod { + Ignored, + // Only contributes to final score. + Optional(SpecializedScorer), + // Must be fitted. + Required(Box), + } + let mut must_scorers = per_occur_scorers.remove(&Occur::Must); + let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should) + { + let num_of_should_scorers = should_scorers.len(); + if self.minimum_number_should_match > num_of_should_scorers { + return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); + } + match self.minimum_number_should_match { + 0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)), + 1 => CombinationMethod::Required(into_box_scorer( + scorer_union(should_scorers, &score_combiner_fn), + &score_combiner_fn, + )), + n @ _ if num_of_should_scorers == n => { + // When num_of_should_scorers equals the number of should clauses, + // they are no different from must clauses. + must_scorers = match must_scorers.take() { + Some(mut must_scorers) => { + must_scorers.append(&mut should_scorers); + Some(must_scorers) + } + None => Some(should_scorers), + }; + CombinationMethod::Ignored + } + _ => CombinationMethod::Required(scorer_disjunction( + should_scorers, + score_combiner_fn(), + self.minimum_number_should_match, + )), + } + } else { + // None of should clauses are provided. + if self.minimum_number_should_match > 0 { + return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); + } else { + CombinationMethod::Ignored + } + }; let exclude_scorer_opt: Option> = per_occur_scorers .remove(&Occur::MustNot) .map(|scorers| scorer_union(scorers, DoNothingCombiner::default)) - .map(|specialized_scorer| { + .map(|specialized_scorer: SpecializedScorer| { into_box_scorer(specialized_scorer, DoNothingCombiner::default) }); - - let must_scorer_opt: Option> = per_occur_scorers - .remove(&Occur::Must) - .map(intersect_scorers); - - let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) { - (Some(should_scorer), Some(must_scorer)) => { + let positive_scorer = match (should_opt, must_scorers) { + (CombinationMethod::Ignored, Some(must_scorers)) => { + SpecializedScorer::Other(intersect_scorers(must_scorers)) + } + (CombinationMethod::Optional(should_scorer), Some(must_scorers)) => { + let must_scorer = intersect_scorers(must_scorers); if self.scoring_enabled { - SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< - Box, - Box, - TComplexScoreCombiner, - >::new( - must_scorer, - into_box_scorer(should_scorer, &score_combiner_fn), - ))) + SpecializedScorer::Other(Box::new( + RequiredOptionalScorer::<_, _, TScoreCombiner>::new( + must_scorer, + into_box_scorer(should_scorer, &score_combiner_fn), + ), + )) } else { SpecializedScorer::Other(must_scorer) } } - (None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer), - (Some(should_scorer), None) => should_scorer, - (None, None) => { - return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); + (CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => { + must_scorers.push(should_scorer); + SpecializedScorer::Other(intersect_scorers(must_scorers)) + } + (CombinationMethod::Ignored, None) => { + return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))) } + (CombinationMethod::Required(should_scorer), None) => { + SpecializedScorer::Other(should_scorer) + } + // Optional options are promoted to required if no must scorers exists. + (CombinationMethod::Optional(should_scorer), None) => should_scorer, }; - if let Some(exclude_scorer) = exclude_scorer_opt { let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn); Ok(SpecializedScorer::Other(Box::new(Exclude::new( diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs new file mode 100644 index 0000000000..2b631e5522 --- /dev/null +++ b/src/query/disjunction.rs @@ -0,0 +1,327 @@ +use std::cmp::Ordering; +use std::collections::BinaryHeap; + +use crate::query::score_combiner::DoNothingCombiner; +use crate::query::{ScoreCombiner, Scorer}; +use crate::{DocId, DocSet, Score, TERMINATED}; + +/// `Disjunction` is responsible for merging `DocSet` from multiple +/// source. Specifically, It takes the union of two or more `DocSet`s +/// then filtering out elements that appear fewer times than a +/// specified threshold. +pub struct Disjunction { + chains: BinaryHeap>, + minimum_matches_required: usize, + score_combiner: TScoreCombiner, + + current_doc: DocId, + current_score: Score, +} + +/// A wrapper around a `Scorer` that caches the current `doc_id` and implements the `DocSet` trait. +/// Also, the `Ord` trait and it's family are implemented reversely. So that we can combine +/// `std::BinaryHeap>` to gain a min-heap with current doc id as key. +struct ScorerWrapper { + scorer: T, + current_doc: DocId, +} + +impl ScorerWrapper { + fn new(scorer: T) -> Self { + let current_doc = scorer.doc(); + Self { + scorer, + current_doc, + } + } +} + +impl PartialEq for ScorerWrapper { + fn eq(&self, other: &Self) -> bool { + self.doc() == other.doc() + } +} + +impl Eq for ScorerWrapper {} + +impl PartialOrd for ScorerWrapper { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ScorerWrapper { + fn cmp(&self, other: &Self) -> Ordering { + self.doc().cmp(&other.doc()).reverse() + } +} + +impl DocSet for ScorerWrapper { + fn advance(&mut self) -> DocId { + let doc_id = self.scorer.advance(); + self.current_doc = doc_id; + doc_id + } + + fn doc(&self) -> DocId { + self.current_doc + } + + fn size_hint(&self) -> u32 { + self.scorer.size_hint() + } +} + +impl Disjunction { + pub fn new>( + docsets: T, + score_combiner: TScoreCombiner, + minimum_matches_required: usize, + ) -> Self { + debug_assert!( + minimum_matches_required > 1, + "union scorer works better if just one matches required" + ); + let chains = docsets + .into_iter() + .map(|doc| ScorerWrapper::new(doc)) + .collect(); + let mut disjunction = Self { + chains, + score_combiner, + current_doc: TERMINATED, + minimum_matches_required, + current_score: 0.0, + }; + if minimum_matches_required > disjunction.chains.len() { + return disjunction; + } + disjunction.advance(); + disjunction + } +} + +impl DocSet + for Disjunction +{ + fn advance(&mut self) -> DocId { + let mut current_num_matches = 0; + while let Some(mut candidate) = self.chains.pop() { + let next = candidate.doc(); + if next != TERMINATED { + // Peek next doc. + if self.current_doc != next { + if current_num_matches >= self.minimum_matches_required { + self.chains.push(candidate); + self.current_score = self.score_combiner.score(); + return self.current_doc; + } + // Reset current_num_matches and scores. + current_num_matches = 0; + self.current_doc = next; + self.score_combiner.clear(); + } + current_num_matches += 1; + self.score_combiner.update(&mut candidate.scorer); + candidate.advance(); + self.chains.push(candidate); + } + } + if current_num_matches < self.minimum_matches_required { + self.current_doc = TERMINATED; + } + self.current_score = self.score_combiner.score(); + return self.current_doc; + } + + #[inline] + fn doc(&self) -> DocId { + self.current_doc + } + + fn size_hint(&self) -> u32 { + self.chains + .iter() + .map(|docset| docset.size_hint()) + .max() + .unwrap_or(0u32) + } +} + +impl Scorer + for Disjunction +{ + fn score(&mut self) -> Score { + self.current_score + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use super::Disjunction; + use crate::query::score_combiner::DoNothingCombiner; + use crate::query::{ConstScorer, Scorer, SumCombiner, VecDocSet}; + use crate::{DocId, DocSet, Score, TERMINATED}; + + fn conjunct(arrays: &[Vec], pass_line: usize) -> Vec { + let mut counts = BTreeMap::new(); + for array in arrays { + for &element in array { + *counts.entry(element).or_insert(0) += 1; + } + } + counts + .iter() + .filter_map(|(&element, &count)| { + if count >= pass_line { + Some(element) + } else { + None + } + }) + .collect() + } + + fn aux_test_conjunction(vals: Vec>, min_match: usize) { + let mut union_expected = VecDocSet::from(conjunct(&vals, min_match)); + let make_scorer = || { + Disjunction::new( + vals.iter() + .cloned() + .map(VecDocSet::from) + .map(|d| ConstScorer::new(d, 1.0)), + DoNothingCombiner::default(), + min_match, + ) + }; + let mut scorer: Disjunction<_, DoNothingCombiner> = make_scorer(); + let mut count = 0; + while scorer.doc() != TERMINATED { + assert_eq!(union_expected.doc(), scorer.doc()); + assert_eq!(union_expected.advance(), scorer.advance()); + count += 1; + } + assert_eq!(union_expected.advance(), TERMINATED); + assert_eq!(count, make_scorer().count_including_deleted()); + } + + #[should_panic] + #[test] + fn test_arg_check1() { + aux_test_conjunction(vec![], 0); + } + + #[should_panic] + #[test] + fn test_arg_check2() { + aux_test_conjunction(vec![], 1); + } + + #[test] + fn test_corner_case() { + aux_test_conjunction(vec![], 2); + aux_test_conjunction(vec![vec![]; 1000], 2); + aux_test_conjunction(vec![vec![]; 100], usize::MAX); + aux_test_conjunction(vec![vec![0xC0FFEE]; 10000], usize::MAX); + aux_test_conjunction((1..10000u32).map(|i| vec![i]).collect::>(), 2); + } + + #[test] + fn test_conjunction() { + aux_test_conjunction( + vec![ + vec![1, 3333, 100000000u32], + vec![1, 2, 100000000u32], + vec![1, 2, 100000000u32], + ], + 2, + ); + aux_test_conjunction( + vec![vec![8], vec![3, 4, 0xC0FFEEu32], vec![1, 2, 100000000u32]], + 2, + ); + aux_test_conjunction( + vec![ + vec![1, 3333, 100000000u32], + vec![1, 2, 100000000u32], + vec![1, 2, 100000000u32], + ], + 3, + ) + } + + // This dummy scorer does nothing but yield doc id increasingly. + // with constant score 1.0 + #[derive(Clone)] + struct DummyScorer { + cursor: usize, + foo: Vec<(DocId, f32)>, + } + + impl DummyScorer { + fn new(doc_score: Vec<(DocId, f32)>) -> Self { + Self { + cursor: 0, + foo: doc_score, + } + } + } + + impl DocSet for DummyScorer { + fn advance(&mut self) -> DocId { + self.cursor += 1; + self.doc() + } + + fn doc(&self) -> DocId { + self.foo.get(self.cursor).map(|x| x.0).unwrap_or(TERMINATED) + } + + fn size_hint(&self) -> u32 { + self.foo.len() as u32 + } + } + + impl Scorer for DummyScorer { + fn score(&mut self) -> Score { + self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0) + } + } + + #[test] + fn test_score_calculate() { + let mut scorer = Disjunction::new( + vec![ + DummyScorer::new(vec![(1, 1f32), (2, 1f32)]), + DummyScorer::new(vec![(1, 1f32), (3, 1f32)]), + DummyScorer::new(vec![(1, 1f32), (4, 1f32)]), + DummyScorer::new(vec![(1, 1f32), (2, 1f32)]), + DummyScorer::new(vec![(1, 1f32), (2, 1f32)]), + ], + SumCombiner::default(), + 3, + ); + assert_eq!(scorer.score(), 5.0); + assert_eq!(scorer.advance(), 2); + assert_eq!(scorer.score(), 3.0); + } + + #[test] + fn test_score_calculate_corner_case() { + let mut scorer = Disjunction::new( + vec![ + DummyScorer::new(vec![(1, 1f32), (2, 1f32)]), + DummyScorer::new(vec![(1, 1f32), (3, 1f32)]), + DummyScorer::new(vec![(1, 1f32), (3, 1f32)]), + ], + SumCombiner::default(), + 2, + ); + assert_eq!(scorer.doc(), 1); + assert_eq!(scorer.score(), 3.0); + assert_eq!(scorer.advance(), 3); + assert_eq!(scorer.score(), 2.0); + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index 75cdd0f5d4..03c6c01d0d 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -5,6 +5,7 @@ mod bm25; mod boolean_query; mod boost_query; mod const_score_query; +mod disjunction; mod disjunction_max_query; mod empty_query; mod exclude; diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index 1a6c3aaf45..8ec50f8d4d 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -1815,7 +1815,8 @@ mod test { \"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \ (Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \ type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \ - Term(field=1, type=Str, \"wo\")), max_expansions: 50 })] }" + Term(field=1, type=Str, \"wo\")), max_expansions: 50 })], \ + minimum_number_should_match: 1 }" ); } @@ -1880,7 +1881,8 @@ mod test { format!("{query:?}"), "BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \ type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \ - (Should, TermQuery(Term(field=1, type=Str, \"abc\")))] }" + (Should, TermQuery(Term(field=1, type=Str, \"abc\")))], \ + minimum_number_should_match: 1 }" ); } @@ -1897,7 +1899,8 @@ mod test { format!("{query:?}"), "BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \ \"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \ - distance: 2, transposition_cost_one: false, prefix: true })] }" + distance: 2, transposition_cost_one: false, prefix: true })], \ + minimum_number_should_match: 1 }" ); } }