From 0d4e3199653c57c769c61f04046f3cb3182a5d2f Mon Sep 17 00:00:00 2001 From: PSeitz Date: Wed, 31 Jul 2024 21:29:32 +0900 Subject: [PATCH] add Key::I64 and Key::U64 variants in aggregation (#2468) * add Key::I64 and Key::U64 variants in aggregation Currently all `Key` numerical values are returned as f64. This causes problems in some cases with the precision and the way f64 is serialized. This PR adds `Key::I64` and `Key::U64` variants and uses them in the term aggregation. * add clarification comment --- columnar/src/value.rs | 25 +++++++ src/aggregation/agg_req_with_accessor.rs | 37 +++++++--- src/aggregation/agg_tests.rs | 61 +++++++++++++++- src/aggregation/bucket/term_agg.rs | 82 ++++++++++++++++++++-- src/aggregation/bucket/term_missing_agg.rs | 1 - src/aggregation/intermediate_agg_result.rs | 12 +++- src/aggregation/metric/cardinality.rs | 9 ++- src/aggregation/mod.rs | 10 +++ 8 files changed, 220 insertions(+), 17 deletions(-) diff --git a/columnar/src/value.rs b/columnar/src/value.rs index f8a83bbeb8..81b5367ab9 100644 --- a/columnar/src/value.rs +++ b/columnar/src/value.rs @@ -17,6 +17,31 @@ impl NumericalValue { NumericalValue::F64(_) => NumericalType::F64, } } + + /// Tries to normalize the numerical value in the following priorities: + /// i64, i64, f64 + pub fn normalize(self) -> Self { + match self { + NumericalValue::U64(val) => { + if val <= i64::MAX as u64 { + NumericalValue::I64(val as i64) + } else { + NumericalValue::F64(val as f64) + } + } + NumericalValue::I64(val) => NumericalValue::I64(val), + NumericalValue::F64(val) => { + let fract = val.fract(); + if fract == 0.0 && val >= i64::MIN as f64 && val <= i64::MAX as f64 { + NumericalValue::I64(val as i64) + } else if fract == 0.0 && val >= u64::MIN as f64 && val <= u64::MAX as f64 { + NumericalValue::U64(val as u64) + } else { + NumericalValue::F64(val) + } + } + } + } } impl From for NumericalValue { diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 6ac6f7591f..bd7528d023 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -186,6 +186,8 @@ impl AggregationWithAccessor { .map(|missing| match missing { Key::Str(_) => ColumnType::Str, Key::F64(_) => ColumnType::F64, + Key::I64(_) => ColumnType::I64, + Key::U64(_) => ColumnType::U64, }) .unwrap_or(ColumnType::U64); let column_and_types = get_all_ff_reader_or_empty( @@ -232,13 +234,16 @@ impl AggregationWithAccessor { missing.clone() }; - let missing_value_for_accessor = if let Some(missing) = - missing_value_term_agg.as_ref() - { - get_missing_val(column_type, missing, agg.agg.get_fast_field_names()[0])? - } else { - None - }; + let missing_value_for_accessor = + if let Some(missing) = missing_value_term_agg.as_ref() { + get_missing_val_as_u64_lenient( + column_type, + missing, + agg.agg.get_fast_field_names()[0], + )? + } else { + None + }; let agg = AggregationWithAccessor { segment_ordinal, @@ -330,7 +335,14 @@ impl AggregationWithAccessor { } } -fn get_missing_val( +/// Get the missing value as internal u64 representation +/// +/// For terms we use u64::MAX as sentinel value +/// For numerical data we convert the value into the representation +/// we would get from the fast field, when we open it as u64_lenient_for_type. +/// +/// That way we can use it the same way as if it would come from the fastfield. +fn get_missing_val_as_u64_lenient( column_type: ColumnType, missing: &Key, field_name: &str, @@ -339,9 +351,18 @@ fn get_missing_val( Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX), // Allow fallback to number on text fields Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX), + Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX), + Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX), Key::F64(val) if column_type.numerical_type().is_some() => { f64_to_fastfield_u64(*val, &column_type) } + // NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64. + Key::I64(val) if column_type.numerical_type().is_some() => { + f64_to_fastfield_u64(*val as f64, &column_type) + } + Key::U64(val) if column_type.numerical_type().is_some() => { + f64_to_fastfield_u64(*val as f64, &column_type) + } _ => { return Err(crate::TantivyError::InvalidArgument(format!( "Missing value {missing:?} for field {field_name} is not supported for column \ diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 6a32fc57d7..a4ac827a08 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -939,11 +939,11 @@ fn test_aggregation_on_json_object_mixed_types() { }, "termagg": { "buckets": [ - { "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } }, + { "doc_count": 1, "key": 10, "min_price": { "value": 10.0 } }, { "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } }, { "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } }, { "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } }, - { "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } }, + { "doc_count": 2, "key": 1, "key_as_string": "true", "min_price": { "value": null } }, ], "sum_other_doc_count": 0 } @@ -951,3 +951,60 @@ fn test_aggregation_on_json_object_mixed_types() { ) ); } + +#[test] +fn test_aggregation_on_json_object_mixed_numerical_segments() { + let mut schema_builder = Schema::builder(); + let json = schema_builder.add_json_field("json", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); + // => Segment with all values f64 numeric + index_writer + .add_document(doc!(json => json!({"mixed_price": 10.5}))) + .unwrap(); + // Gets converted to f64! + index_writer + .add_document(doc!(json => json!({"mixed_price": 10}))) + .unwrap(); + index_writer.commit().unwrap(); + // => Segment with all values i64 numeric + index_writer + .add_document(doc!(json => json!({"mixed_price": 10}))) + .unwrap(); + index_writer.commit().unwrap(); + + index_writer.commit().unwrap(); + + // All bucket types + let agg_req_str = r#" + { + "termagg": { + "terms": { + "field": "json.mixed_price" + } + } + } "#; + let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap(); + let aggregation_collector = get_collector(agg); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap(); + let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap(); + use pretty_assertions::assert_eq; + assert_eq!( + &aggregation_res_json, + &serde_json::json!({ + "termagg": { + "buckets": [ + { "doc_count": 2, "key": 10}, + { "doc_count": 1, "key": 10.5}, + ], + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 0 + } + } + ) + ); +} diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index d98c33c0bd..75a23761b2 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -3,7 +3,9 @@ use std::io; use std::net::Ipv6Addr; use columnar::column_values::CompactSpaceU64Accessor; -use columnar::{ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64}; +use columnar::{ + ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, NumericalValue, +}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -19,7 +21,7 @@ use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, SegmentAggregationCollector, }; -use crate::aggregation::{f64_from_fastfield_u64, format_date, Key}; +use crate::aggregation::{format_date, Key}; use crate::error::DataCorruption; use crate::TantivyError; @@ -497,6 +499,12 @@ impl SegmentTermCollector { Key::F64(val) => { dict.insert(IntermediateKey::F64(*val), intermediate_entry); } + Key::U64(val) => { + dict.insert(IntermediateKey::U64(*val), intermediate_entry); + } + Key::I64(val) => { + dict.insert(IntermediateKey::I64(*val), intermediate_entry); + } } entries.swap_remove(index); @@ -583,8 +591,26 @@ impl SegmentTermCollector { } else { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - let val = f64_from_fastfield_u64(val, &self.column_type); - dict.insert(IntermediateKey::F64(val), intermediate_entry); + if self.column_type == ColumnType::U64 { + dict.insert(IntermediateKey::U64(val), intermediate_entry); + } else if self.column_type == ColumnType::I64 { + dict.insert(IntermediateKey::I64(i64::from_u64(val)), intermediate_entry); + } else { + let val = f64::from_u64(val); + let val: NumericalValue = val.into(); + + match val.normalize() { + NumericalValue::U64(val) => { + dict.insert(IntermediateKey::U64(val), intermediate_entry); + } + NumericalValue::I64(val) => { + dict.insert(IntermediateKey::I64(val), intermediate_entry); + } + NumericalValue::F64(val) => { + dict.insert(IntermediateKey::F64(val), intermediate_entry); + } + } + }; } }; @@ -1719,6 +1745,54 @@ mod tests { Ok(()) } + #[test] + fn terms_aggregation_u64_value() -> crate::Result<()> { + // Make sure that large u64 are not truncated + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.set_merge_policy(Box::new(NoMergePolicy)); + index_writer.add_document(doc!( + id_field => 9_223_372_036_854_775_807u64, + ))?; + index_writer.add_document(doc!( + id_field => 1_769_070_189_829_214_202u64, + ))?; + index_writer.add_document(doc!( + id_field => 1_769_070_189_829_214_202u64, + ))?; + index_writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_ids": { + "terms": { + "field": "id" + }, + } + })) + .unwrap(); + + let res = exec_request_with_query(agg_req, &index, None)?; + + // id field + assert_eq!( + res["my_ids"]["buckets"][0]["key"], + 1_769_070_189_829_214_202u64 + ); + assert_eq!(res["my_ids"]["buckets"][0]["doc_count"], 2); + assert_eq!( + res["my_ids"]["buckets"][1]["key"], + 9_223_372_036_854_775_807u64 + ); + assert_eq!(res["my_ids"]["buckets"][1]["doc_count"], 1); + assert_eq!(res["my_ids"]["buckets"][2]["key"], serde_json::Value::Null); + + Ok(()) + } + #[test] fn terms_aggregation_missing1() -> crate::Result<()> { let mut schema_builder = Schema::builder(); diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index bb8b295b49..df24eee123 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -70,7 +70,6 @@ impl SegmentAggregationCollector for TermMissingAgg { )?; missing_entry.sub_aggregation = res; } - entries.insert(missing.into(), missing_entry); let bucket = IntermediateBucketResult::Terms { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 8162a7eb6d..f4cef1a51a 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -51,12 +51,18 @@ pub enum IntermediateKey { Str(String), /// `f64` key F64(f64), + /// `i64` key + I64(i64), + /// `u64` key + U64(u64), } impl From for IntermediateKey { fn from(value: Key) -> Self { match value { Key::Str(s) => Self::Str(s), Key::F64(f) => Self::F64(f), + Key::U64(f) => Self::U64(f), + Key::I64(f) => Self::I64(f), } } } @@ -73,7 +79,9 @@ impl From for Key { } } IntermediateKey::F64(f) => Self::F64(f), - IntermediateKey::Bool(f) => Self::F64(f as u64 as f64), + IntermediateKey::Bool(f) => Self::U64(f as u64), + IntermediateKey::U64(f) => Self::U64(f), + IntermediateKey::I64(f) => Self::I64(f), } } } @@ -86,6 +94,8 @@ impl std::hash::Hash for IntermediateKey { match self { IntermediateKey::Str(text) => text.hash(state), IntermediateKey::F64(val) => val.to_bits().hash(state), + IntermediateKey::U64(val) => val.hash(state), + IntermediateKey::I64(val) => val.hash(state), IntermediateKey::Bool(val) => val.hash(state), IntermediateKey::IpAddr(val) => val.hash(state), } diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 63b8605900..4f494b486f 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -179,10 +179,11 @@ impl SegmentCardinalityCollector { Ok(()) })?; if has_missing { + // Replace missing with the actual value provided let missing_key = self .missing .as_ref() - .expect("Found placeholder term_ord but `missing` is None"); + .expect("Found sentinel value u64::MAX for term_ord but `missing` is not set"); match missing_key { Key::Str(missing) => { self.cardinality.sketch.insert_any(&missing); @@ -191,6 +192,12 @@ impl SegmentCardinalityCollector { let val = f64_to_u64(*val); self.cardinality.sketch.insert_any(&val); } + Key::U64(val) => { + self.cardinality.sketch.insert_any(&val); + } + Key::I64(val) => { + self.cardinality.sketch.insert_any(&val); + } } } } diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 72a37703c2..f5f7d3142a 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -336,10 +336,16 @@ pub type SerializedKey = String; #[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)] /// The key to identify a bucket. +/// +/// The order is important, with serde untagged, that we try to deserialize into i64 first. #[serde(untagged)] pub enum Key { /// String key Str(String), + /// `i64` key + I64(i64), + /// `u64` key + U64(u64), /// `f64` key F64(f64), } @@ -350,6 +356,8 @@ impl std::hash::Hash for Key { match self { Key::Str(text) => text.hash(state), Key::F64(val) => val.to_bits().hash(state), + Key::U64(val) => val.hash(state), + Key::I64(val) => val.hash(state), } } } @@ -369,6 +377,8 @@ impl Display for Key { match self { Key::Str(val) => f.write_str(val), Key::F64(val) => f.write_str(&val.to_string()), + Key::U64(val) => f.write_str(&val.to_string()), + Key::I64(val) => f.write_str(&val.to_string()), } } }