diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs
index 92dd91bd86bc..5c5d66aef444 100644
--- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs
@@ -22,6 +22,7 @@ pub mod accumulate;
 pub mod bool_op;
 pub mod nulls;
 pub mod prim_op;
+pub mod string_op;
 
 use arrow::{
     array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray},
diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs
new file mode 100644
index 000000000000..2b95ed106513
--- /dev/null
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`StringGroupsAccumulator`]: per-group selection of a single binary
+//! value, used to compute `MIN` / `MAX` over `Binary` and `BinaryView`.
+
+use arrow::array::{
+    Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray,
+};
+use datafusion_common::Result;
+use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
+use std::mem::size_of;
+use std::sync::Arc;
+
+/// A [`GroupsAccumulator`] that keeps one binary value per group.
+///
+/// A new input value replaces the value currently stored for its group
+/// whenever `fun(new, current)` returns `true`, so `|a, b| a < b` computes
+/// `MIN` and `|a, b| a > b` computes `MAX`.
+///
+/// `VIEW` selects the array layout read and produced by the accumulator:
+/// `BinaryView` when `true`, `Binary` when `false`.
+///
+/// Values are kept as raw bytes, so inputs need not be valid UTF-8, and an
+/// empty value is distinct from `NULL` (a group that has seen no value).
+pub struct StringGroupsAccumulator<F, const VIEW: bool> {
+    /// Value selected so far for each group; `None` until one is seen
+    states: Vec<Option<Vec<u8>>>,
+    /// Returns `true` if its first argument should replace its second
+    fun: F,
+}
+
+impl<F, const VIEW: bool> StringGroupsAccumulator<F, VIEW>
+where
+    F: Fn(&[u8], &[u8]) -> bool + Send + Sync,
+{
+    /// Creates an accumulator with no groups.
+    ///
+    /// `s_fn(new, current)` must return `true` when `new` should replace
+    /// `current` as the value of a group.
+    pub fn new(s_fn: F) -> Self {
+        Self {
+            states: Vec::new(),
+            fun: s_fn,
+        }
+    }
+
+    /// Folds one batch into the per-group states.
+    ///
+    /// Used by both `update_batch` and `merge_batch`: the intermediate
+    /// state has the same layout as the input, so the work is identical.
+    fn accumulate(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        if self.states.len() < total_num_groups {
+            self.states.resize(total_num_groups, None);
+        }
+
+        let input_array = &values[0];
+
+        // Downcast once per batch rather than once per row
+        if VIEW {
+            let rows = input_array.as_binary_view().iter();
+            self.accumulate_rows(rows, group_indices, opt_filter);
+        } else {
+            let rows = input_array.as_binary::<i32>().iter();
+            self.accumulate_rows(rows, group_indices, opt_filter);
+        }
+    }
+
+    /// Offers every selected, non-null row to the group it belongs to.
+    fn accumulate_rows<'a>(
+        &mut self,
+        rows: impl Iterator<Item = Option<&'a [u8]>>,
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+    ) {
+        for (row, (value, &group_index)) in rows.zip(group_indices).enumerate() {
+            if let Some(value) = value {
+                if is_selected(opt_filter, row) {
+                    self.update_group(group_index, value);
+                }
+            }
+        }
+    }
+
+    /// Stores `value` for `group_index` if the group is still empty or if
+    /// `fun` prefers `value` over the value currently stored.
+    fn update_group(&mut self, group_index: usize, value: &[u8]) {
+        let state = &mut self.states[group_index];
+        match state.as_mut() {
+            Some(current) => {
+                if (self.fun)(value, current.as_slice()) {
+                    // Reuse the existing allocation
+                    current.clear();
+                    current.extend_from_slice(value);
+                }
+            }
+            None => *state = Some(value.to_vec()),
+        }
+    }
+
+    /// Removes the groups selected by `emit_to` and returns their values,
+    /// with `NULL` for groups that have seen no value.
+    fn emit(&mut self, emit_to: EmitTo) -> ArrayRef {
+        let states = emit_to.take_needed(&mut self.states);
+        Self::build_array(states.iter().map(|state| state.as_deref()))
+    }
+
+    /// Builds a `BinaryView` (`VIEW == true`) or `Binary` (`VIEW == false`)
+    /// array from `values`.
+    fn build_array<'a>(values: impl Iterator<Item = Option<&'a [u8]>>) -> ArrayRef {
+        if VIEW {
+            let mut builder = BinaryViewBuilder::new();
+            for value in values {
+                match value {
+                    Some(bytes) => builder.append_value(bytes),
+                    None => builder.append_null(),
+                }
+            }
+            Arc::new(builder.finish()) as ArrayRef
+        } else {
+            let mut builder = BinaryBuilder::new();
+            for value in values {
+                match value {
+                    Some(bytes) => builder.append_value(bytes),
+                    None => builder.append_null(),
+                }
+            }
+            Arc::new(builder.finish()) as ArrayRef
+        }
+    }
+}
+
+impl<F, const VIEW: bool> GroupsAccumulator for StringGroupsAccumulator<F, VIEW>
+where
+    F: Fn(&[u8], &[u8]) -> bool + Send + Sync,
+{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.accumulate(values, group_indices, opt_filter, total_num_groups);
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        Ok(self.emit(emit_to))
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        Ok(vec![self.emit(emit_to)])
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.accumulate(values, group_indices, opt_filter, total_num_groups);
+        Ok(())
+    }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        let input_array = &values[0];
+
+        // Without a filter the input already is the state
+        if opt_filter.is_none() {
+            return Ok(vec![Arc::clone(input_array)]);
+        }
+
+        // Rows rejected by the filter must not take part in a later merge,
+        // so they are replaced by NULL. Iterate over the rows of the input
+        // array, not over `values` (the list of argument arrays).
+        let array = if VIEW {
+            let rows = input_array.as_binary_view().iter().enumerate();
+            Self::build_array(
+                rows.map(|(row, v)| v.filter(|_| is_selected(opt_filter, row))),
+            )
+        } else {
+            let rows = input_array.as_binary::<i32>().iter().enumerate();
+            Self::build_array(
+                rows.map(|(row, v)| v.filter(|_| is_selected(opt_filter, row))),
+            )
+        };
+
+        Ok(vec![array])
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        // One slot per group plus the heap buffer of every stored value
+        let slots = self.states.capacity() * size_of::<Option<Vec<u8>>>();
+        let buffers: usize = self
+            .states
+            .iter()
+            .flatten()
+            .map(|bytes| bytes.capacity())
+            .sum();
+        slots + buffers
+    }
+}
+
+/// Returns `true` if `row` passes `opt_filter`.
+///
+/// A missing filter selects every row; a `false` or NULL filter value
+/// rejects the row.
+fn is_selected(opt_filter: Option<&BooleanArray>, row: usize) -> bool {
+    match opt_filter {
+        Some(filter) => filter.is_valid(row) && filter.value(row),
+        None => true,
+    }
+}
diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs
index 961e8639604c..605c17f9327e 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -53,6 +53,7 @@ use datafusion_common::{
     downcast_value, exec_err, internal_err, DataFusionError, Result,
 };
 use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::string_op::StringGroupsAccumulator;
 use std::fmt::Debug;
 
 use arrow::datatypes::i256;
@@ -128,6 +129,17 @@ macro_rules! instantiate_max_accumulator {
     }};
 }
 
+/// Creates a [`StringGroupsAccumulator`] for computing `MAX` of binary
+/// values. `$VIEW` is `true` for `BinaryView` input and `false` for
+/// `Binary` input.
+macro_rules! instantiate_max_string_accumulator {
+    ($VIEW:expr) => {{
+        Ok(Box::new(StringGroupsAccumulator::<_, $VIEW>::new(
+            |a, b| a > b,
+        )))
+    }};
+}
+
 /// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN`
 /// the specified [`ArrowPrimitiveType`].
 ///
@@ -147,6 +159,17 @@ macro_rules! instantiate_min_accumulator {
     }};
 }
 
+/// Creates a [`StringGroupsAccumulator`] for computing `MIN` of binary
+/// values. `$VIEW` is `true` for `BinaryView` input and `false` for
+/// `Binary` input.
+macro_rules! instantiate_min_string_accumulator {
+    ($VIEW:expr) => {{
+        Ok(Box::new(StringGroupsAccumulator::<_, $VIEW>::new(
+            |a, b| a < b,
+        )))
+    }};
+}
+
 impl AggregateUDFImpl for Max {
     fn as_any(&self) -> &dyn std::any::Any {
         self
@@ -193,6 +216,8 @@ impl AggregateUDFImpl for Max {
                 | Time32(_)
                 | Time64(_)
                 | Timestamp(_, _)
+                | BinaryView
+                | Binary
         )
     }
 
@@ -253,6 +278,12 @@ impl AggregateUDFImpl for Max {
             Decimal256(_, _) => {
                 instantiate_max_accumulator!(data_type, i256, Decimal256Type)
             }
+            BinaryView => {
+                instantiate_max_string_accumulator!(true)
+            }
+            Binary => {
+                instantiate_max_string_accumulator!(false)
+            }
 
             // It would be nice to have a fast implementation for Strings as well
             // https://github.com/apache/datafusion/issues/6906
@@ -972,6 +1003,8 @@ impl AggregateUDFImpl for Min {
                 | Time32(_)
                 | Time64(_)
                 | Timestamp(_, _)
+                | BinaryView
+                | Binary
         )
     }
 
@@ -1032,6 +1065,12 @@ impl AggregateUDFImpl for Min {
             Decimal256(_, _) => {
                 instantiate_min_accumulator!(data_type, i256, Decimal256Type)
             }
+            BinaryView => {
+                instantiate_min_string_accumulator!(true)
+            }
+            Binary => {
+                instantiate_min_string_accumulator!(false)
+            }
 
             // It would be nice to have a fast implementation for Strings as well
             // https://github.com/apache/datafusion/issues/6906
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 60efc7711216..87b9aa69459a 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -48,14 +48,13 @@ use datafusion_expr::{EmitTo, GroupsAccumulator};
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
 
+use super::order::GroupOrdering;
+use super::AggregateExec;
 use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
 use futures::ready;
 use futures::stream::{Stream, StreamExt};
 use log::debug;
 
-use super::order::GroupOrdering;
-use super::AggregateExec;
-
 #[derive(Debug, Clone)]
 /// This object tracks the aggregation phase (input/output)
 pub(crate) enum ExecutionState {