diff --git a/Cargo.toml b/Cargo.toml index 29b920d9e7..ff1a05feaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ tantivy-bitpacker = { version = "0.6", path = "./bitpacker" } common = { version = "0.7", path = "./common/", package = "tantivy-common" } tokenizer-api = { version = "0.3", path = "./tokenizer-api", package = "tantivy-tokenizer-api" } sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] } +hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } futures-util = { version = "0.3.28", optional = true } fnv = "1.0.7" diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index be1a5e51df..f2a7a6aed7 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -34,8 +34,9 @@ use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, }; use super::metric::{ - AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, - PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation, + AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, + MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, + TopHitsAggregation, }; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user @@ -160,6 +161,9 @@ pub enum AggregationVariants { /// Finds the top k values matching some order #[serde(rename = "top_hits")] TopHits(TopHitsAggregation), + /// Computes an estimate of the number of unique values + #[serde(rename = "cardinality")] + Cardinality(CardinalityAggregationReq), } impl AggregationVariants { @@ -179,6 +183,7 @@ impl AggregationVariants { AggregationVariants::Sum(sum) => vec![sum.field_name()], AggregationVariants::Percentiles(per) => vec![per.field_name()], AggregationVariants::TopHits(top_hits) => top_hits.field_names(), + AggregationVariants::Cardinality(per) => vec![per.field_name()], } } diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index a1fb15ec97..6ac6f7591f 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -11,8 +11,8 @@ use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, }; use super::metric::{ - AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, - StatsAggregation, SumAggregation, + AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, + MaxAggregation, MinAggregation, StatsAggregation, SumAggregation, }; use super::segment_agg_result::AggregationLimits; use super::VecWithNames; @@ -162,6 +162,11 @@ impl AggregationWithAccessor { field: ref field_name, ref missing, .. + }) + | Cardinality(CardinalityAggregationReq { + field: ref field_name, + ref missing, + .. }) => { let str_dict_column = reader.fast_fields().str(field_name)?; let allowed_column_types = [ diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index aac1efac4d..82d40651d4 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -98,6 +98,8 @@ pub enum MetricResult { Percentiles(PercentilesMetricResult), /// Top hits metric result TopHits(TopHitsMetricResult), + /// Cardinality metric result + Cardinality(SingleMetricResult), } impl MetricResult { @@ -116,6 +118,7 @@ impl MetricResult { MetricResult::TopHits(_) => Err(TantivyError::AggregationError( AggregationError::InvalidRequest("top_hits can't be used to order".to_string()), )), + MetricResult::Cardinality(card) => Ok(card.value), } } } diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 126c2240e9..6a32fc57d7 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -110,6 +110,16 @@ fn test_aggregation_flushing( } } } + }, + "cardinality_string_id":{ + "cardinality": { + "field": "string_id" + } + }, + "cardinality_score":{ + "cardinality": { + "field": "score" + } } }); @@ -212,6 +222,9 @@ fn test_aggregation_flushing( ) ); + assert_eq!(res["cardinality_string_id"]["value"], 2.0); + assert_eq!(res["cardinality_score"]["value"], 80.0); + Ok(()) } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 1df3266599..8162a7eb6d 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -26,6 +26,7 @@ use super::segment_agg_result::AggregationLimits; use super::{format_date, AggregationError, Key, SerializedKey}; use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; use crate::aggregation::bucket::TermsAggregationInternal; +use crate::aggregation::metric::CardinalityCollector; use crate::TantivyError; /// Contains the intermediate aggregation result, which is optimized to be merged with other @@ -227,6 +228,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult TopHits(ref req) => IntermediateAggregationResult::Metric( IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req)), ), + Cardinality(_) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::Cardinality(CardinalityCollector::default()), + ), } } @@ -291,6 +295,8 @@ pub enum IntermediateMetricResult { Sum(IntermediateSum), /// Intermediate top_hits result TopHits(TopHitsTopNComputer), + /// Intermediate cardinality result + Cardinality(CardinalityCollector), } impl IntermediateMetricResult { @@ -324,6 +330,9 @@ impl IntermediateMetricResult { IntermediateMetricResult::TopHits(top_hits) => { MetricResult::TopHits(top_hits.into_final_result()) } + IntermediateMetricResult::Cardinality(cardinality) => { + MetricResult::Cardinality(cardinality.finalize().into()) + } } } @@ -372,6 +381,12 @@ impl IntermediateMetricResult { (IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => { left.merge_fruits(right)?; } + ( + IntermediateMetricResult::Cardinality(left), + IntermediateMetricResult::Cardinality(right), + ) => { + left.merge_fruits(right)?; + } _ => { panic!("incompatible fruit types in tree or missing merge_fruits handler"); } diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs new file mode 100644 index 0000000000..7e0dfc2ab1 --- /dev/null +++ b/src/aggregation/metric/cardinality.rs @@ -0,0 +1,417 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{BuildHasher, Hasher}; + +use columnar::column_values::CompactSpaceU64Accessor; +use columnar::{BytesColumn, StrColumn}; +use common::f64_to_u64; +use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; +use rustc_hash::FxHashSet; +use serde::{Deserialize, Serialize}; + +use crate::aggregation::agg_req_with_accessor::{ + AggregationWithAccessor, AggregationsWithAccessor, +}; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::*; +use crate::TantivyError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +struct BuildSaltedHasher { + salt: u8, +} + +impl BuildHasher for BuildSaltedHasher { + type Hasher = DefaultHasher; + + fn build_hasher(&self) -> Self::Hasher { + let mut hasher = DefaultHasher::new(); + hasher.write_u8(self.salt); + + hasher + } +} + +/// # Cardinality +/// +/// The cardinality aggregation allows for computing an estimate +/// of the number of different values in a data set based on the +/// HyperLogLog++ alogrithm. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CardinalityAggregationReq { + /// The field name to compute the percentiles on. + pub field: String, + /// The missing parameter defines how documents that are missing a value should be treated. + /// By default they will be ignored but it is also possible to treat them as if they had a + /// value. Examples in JSON format: + /// { "field": "my_numbers", "missing": "10.0" } + #[serde(skip_serializing_if = "Option::is_none", default)] + pub missing: Option, +} + +impl CardinalityAggregationReq { + /// Creates a new [`CardinalityAggregationReq`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + Self { + field: field_name, + missing: None, + } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct SegmentCardinalityCollector { + cardinality: CardinalityCollector, + entries: FxHashSet, + column_type: ColumnType, + accessor_idx: usize, + missing: Option, +} + +impl SegmentCardinalityCollector { + pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option) -> Self { + Self { + cardinality: CardinalityCollector::new(column_type as u8), + entries: Default::default(), + column_type, + accessor_idx, + missing: missing.clone(), + } + } + + fn fetch_block_with_field( + &mut self, + docs: &[crate::DocId], + agg_accessor: &mut AggregationWithAccessor, + ) { + if let Some(missing) = agg_accessor.missing_value_for_accessor { + agg_accessor.column_block_accessor.fetch_block_with_missing( + docs, + &agg_accessor.accessor, + missing, + ); + } else { + agg_accessor + .column_block_accessor + .fetch_block(docs, &agg_accessor.accessor); + } + } + + fn into_intermediate_metric_result( + mut self, + agg_with_accessor: &AggregationWithAccessor, + ) -> crate::Result { + if self.column_type == ColumnType::Str { + let mut buffer = String::new(); + let term_dict = agg_with_accessor + .str_dict_column + .as_ref() + .cloned() + .unwrap_or_else(|| { + StrColumn::wrap(BytesColumn::empty(agg_with_accessor.accessor.num_docs())) + }); + let mut has_missing = false; + for term_ord in self.entries.into_iter() { + if term_ord == u64::MAX { + has_missing = true; + } else { + if !term_dict.ord_to_str(term_ord, &mut buffer)? { + return Err(TantivyError::InternalError(format!( + "Couldn't find term_ord {term_ord} in dict" + ))); + } + self.cardinality.sketch.insert_any(&buffer); + } + } + if has_missing { + let missing_key = self + .missing + .as_ref() + .expect("Found placeholder term_ord but `missing` is None"); + match missing_key { + Key::Str(missing) => { + self.cardinality.sketch.insert_any(&missing); + } + Key::F64(val) => { + let val = f64_to_u64(*val); + self.cardinality.sketch.insert_any(&val); + } + } + } + } + + Ok(IntermediateMetricResult::Cardinality(self.cardinality)) + } +} + +impl SegmentAggregationCollector for SegmentCardinalityCollector { + fn add_intermediate_aggregation_result( + self: Box, + agg_with_accessor: &AggregationsWithAccessor, + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; + + let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?; + results.push( + name, + IntermediateAggregationResult::Metric(intermediate_result), + )?; + + Ok(()) + } + + fn collect( + &mut self, + doc: crate::DocId, + agg_with_accessor: &mut AggregationsWithAccessor, + ) -> crate::Result<()> { + self.collect_block(&[doc], agg_with_accessor) + } + + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_with_accessor: &mut AggregationsWithAccessor, + ) -> crate::Result<()> { + let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; + self.fetch_block_with_field(docs, bucket_agg_accessor); + + let col_block_accessor = &bucket_agg_accessor.column_block_accessor; + if self.column_type == ColumnType::Str { + for term_ord in col_block_accessor.iter_vals() { + self.entries.insert(term_ord); + } + } else if self.column_type == ColumnType::IpAddr { + let compact_space_accessor = bucket_agg_accessor + .accessor + .values + .clone() + .downcast_arc::() + .map_err(|_| { + TantivyError::AggregationError( + crate::aggregation::AggregationError::InternalError( + "Type mismatch: Could not downcast to CompactSpaceU64Accessor" + .to_string(), + ), + ) + })?; + for val in col_block_accessor.iter_vals() { + let val: u128 = compact_space_accessor.compact_to_u128(val as u32); + self.cardinality.sketch.insert_any(&val); + } + } else { + for val in col_block_accessor.iter_vals() { + self.cardinality.sketch.insert_any(&val); + } + } + + Ok(()) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +/// The percentiles collector used during segment collection and for merging results. +pub struct CardinalityCollector { + sketch: HyperLogLogPlus, +} +impl Default for CardinalityCollector { + fn default() -> Self { + Self::new(0) + } +} + +impl PartialEq for CardinalityCollector { + fn eq(&self, _other: &Self) -> bool { + false + } +} + +impl CardinalityCollector { + /// Compute the final cardinality estimate. + pub fn finalize(self) -> Option { + Some(self.sketch.clone().count().trunc()) + } + + fn new(salt: u8) -> Self { + Self { + sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(), + } + } + + pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> { + self.sketch.merge(&right.sketch).map_err(|err| { + TantivyError::AggregationError(AggregationError::InternalError(format!( + "Error while merging cardinality {err:?}" + ))) + })?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use std::net::IpAddr; + use std::str::FromStr; + + use columnar::MonotonicallyMappableToU64; + + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::tests::{exec_request, get_test_index_from_terms}; + use crate::schema::{IntoIpv6Addr, Schema, FAST}; + use crate::Index; + + #[test] + fn cardinality_aggregation_test_empty_index() -> crate::Result<()> { + let values = vec![]; + let index = get_test_index_from_terms(false, &values)?; + let agg_req: Aggregations = serde_json::from_value(json!({ + "cardinality": { + "cardinality": { + "field": "string_id", + } + }, + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["cardinality"]["value"], 0.0); + + Ok(()) + } + + #[test] + fn cardinality_aggregation_test_single_segment() -> crate::Result<()> { + cardinality_aggregation_test_merge_segment(true) + } + #[test] + fn cardinality_aggregation_test() -> crate::Result<()> { + cardinality_aggregation_test_merge_segment(false) + } + fn cardinality_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> { + let segment_and_terms = vec![ + vec!["terma"], + vec!["termb"], + vec!["termc"], + vec!["terma"], + vec!["terma"], + vec!["terma"], + vec!["termb"], + vec!["terma"], + ]; + let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?; + let agg_req: Aggregations = serde_json::from_value(json!({ + "cardinality": { + "cardinality": { + "field": "string_id", + } + }, + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["cardinality"]["value"], 3.0); + + Ok(()) + } + + #[test] + fn cardinality_aggregation_u64() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut writer = index.writer_for_tests()?; + writer.add_document(doc!(id_field => 1u64))?; + writer.add_document(doc!(id_field => 2u64))?; + writer.add_document(doc!(id_field => 3u64))?; + writer.add_document(doc!())?; + writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "cardinality": { + "cardinality": { + "field": "id", + "missing": 0u64 + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["cardinality"]["value"], 4.0); + + Ok(()) + } + + #[test] + fn cardinality_aggregation_ip_addr() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let field = schema_builder.add_ip_addr_field("ip_field", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut writer = index.writer_for_tests()?; + // IpV6 loopback + writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?; + writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?; + // IpV4 + writer.add_document( + doc!(field=>IpAddr::from_str("127.0.0.1").unwrap().into_ipv6_addr()), + )?; + writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "cardinality": { + "cardinality": { + "field": "ip_field" + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["cardinality"]["value"], 2.0); + + Ok(()) + } + + #[test] + fn cardinality_aggregation_json() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let field = schema_builder.add_json_field("json", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut writer = index.writer_for_tests()?; + writer.add_document(doc!(field => json!({"value": false})))?; + writer.add_document(doc!(field => json!({"value": true})))?; + writer.add_document(doc!(field => json!({"value": i64::from_u64(0u64)})))?; + writer.add_document(doc!(field => json!({"value": i64::from_u64(1u64)})))?; + writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "cardinality": { + "cardinality": { + "field": "json.value" + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["cardinality"]["value"], 4.0); + + Ok(()) + } +} diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 4c64737b5f..9d470bc228 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -17,6 +17,7 @@ //! - [Percentiles](PercentilesAggregationReq) mod average; +mod cardinality; mod count; mod extended_stats; mod max; @@ -29,6 +30,7 @@ mod top_hits; use std::collections::HashMap; pub use average::*; +pub use cardinality::*; pub use count::*; pub use extended_stats::*; pub use max::*; diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index cc0aab5a67..5023b943db 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -16,7 +16,10 @@ use super::metric::{ SumAggregation, }; use crate::aggregation::bucket::TermMissingAgg; -use crate::aggregation::metric::{SegmentExtendedStatsCollector, TopHitsSegmentCollector}; +use crate::aggregation::metric::{ + CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector, + TopHitsSegmentCollector, +}; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( @@ -169,6 +172,9 @@ pub(crate) fn build_single_agg_segment_collector( accessor_idx, req.segment_ordinal, ))), + Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new( + SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing), + )), } }