Skip to content

Commit efcf5c6

Browse files
XiangpengHaoalamb
andauthored
Enable GroupValueBytesView for aggregation with StringView types (#11519)
* add functions * Update `string-view` branch to arrow-rs main (#10966) * Pin to arrow main * Fix clippy with latest arrow * Uncomment test that needs new arrow-rs to work * Update datafusion-cli Cargo.lock * Update Cargo.lock * tapelo * merge * update cast * consistent dep * fix ci * avoid unused dep * update dep * update * fix cargo check * better group value view aggregation * update --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent db65772 commit efcf5c6

File tree

7 files changed

+236
-19
lines changed

7 files changed

+236
-19
lines changed

datafusion/functions-aggregate/src/count.rs

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use ahash::RandomState;
19+
use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
1920
use std::collections::HashSet;
2021
use std::ops::BitAnd;
2122
use std::{fmt::Debug, sync::Arc};
@@ -230,6 +231,9 @@ impl AggregateUDFImpl for Count {
230231
DataType::Utf8 => {
231232
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
232233
}
234+
DataType::Utf8View => {
235+
Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8))
236+
}
233237
DataType::LargeUtf8 => {
234238
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
235239
}

datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs

+61
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values
1919
2020
use crate::binary_map::{ArrowBytesSet, OutputType};
21+
use crate::binary_view_map::ArrowBytesViewSet;
2122
use arrow::array::{ArrayRef, OffsetSizeTrait};
2223
use datafusion_common::cast::as_list_array;
2324
use datafusion_common::utils::array_into_list_array_nullable;
@@ -88,3 +89,63 @@ impl<O: OffsetSizeTrait> Accumulator for BytesDistinctCountAccumulator<O> {
8889
std::mem::size_of_val(self) + self.0.size()
8990
}
9091
}
92+
93+
/// Specialized implementation of
94+
/// `COUNT DISTINCT` for [`StringViewArray`] and [`BinaryViewArray`].
95+
///
96+
/// [`StringViewArray`]: arrow::array::StringViewArray
97+
/// [`BinaryViewArray`]: arrow::array::BinaryViewArray
98+
#[derive(Debug)]
99+
pub struct BytesViewDistinctCountAccumulator(ArrowBytesViewSet);
100+
101+
impl BytesViewDistinctCountAccumulator {
102+
pub fn new(output_type: OutputType) -> Self {
103+
Self(ArrowBytesViewSet::new(output_type))
104+
}
105+
}
106+
107+
impl Accumulator for BytesViewDistinctCountAccumulator {
108+
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
109+
let set = self.0.take();
110+
let arr = set.into_state();
111+
let list = Arc::new(array_into_list_array_nullable(arr));
112+
Ok(vec![ScalarValue::List(list)])
113+
}
114+
115+
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
116+
if values.is_empty() {
117+
return Ok(());
118+
}
119+
120+
self.0.insert(&values[0]);
121+
122+
Ok(())
123+
}
124+
125+
fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
126+
if states.is_empty() {
127+
return Ok(());
128+
}
129+
assert_eq!(
130+
states.len(),
131+
1,
132+
"count_distinct states must be single array"
133+
);
134+
135+
let arr = as_list_array(&states[0])?;
136+
arr.iter().try_for_each(|maybe_list| {
137+
if let Some(list) = maybe_list {
138+
self.0.insert(&list);
139+
};
140+
Ok(())
141+
})
142+
}
143+
144+
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
145+
Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64)))
146+
}
147+
148+
fn size(&self) -> usize {
149+
std::mem::size_of_val(self) + self.0.size()
150+
}
151+
}

datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ mod bytes;
1919
mod native;
2020

2121
pub use bytes::BytesDistinctCountAccumulator;
22+
pub use bytes::BytesViewDistinctCountAccumulator;
2223
pub use native::FloatDistinctCountAccumulator;
2324
pub use native::PrimitiveDistinctCountAccumulator;

datafusion/physical-expr-common/src/binary_map.rs

+6
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@ use std::sync::Arc;
4040
pub enum OutputType {
4141
/// `StringArray` or `LargeStringArray`
4242
Utf8,
43+
/// `StringViewArray`
44+
Utf8View,
4345
/// `BinaryArray` or `LargeBinaryArray`
4446
Binary,
47+
/// `BinaryViewArray`
48+
BinaryView,
4549
}
4650

4751
/// HashSet optimized for storing string or binary values that can produce that
@@ -318,6 +322,7 @@ where
318322
observe_payload_fn,
319323
)
320324
}
325+
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
321326
};
322327
}
323328

@@ -516,6 +521,7 @@ where
516521
GenericStringArray::new_unchecked(offsets, values, nulls)
517522
})
518523
}
524+
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
519525
}
520526
}
521527

datafusion/physical-expr-common/src/binary_view_map.rs

+13-8
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,7 @@ use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
2828
use std::fmt::Debug;
2929
use std::sync::Arc;
3030

31-
/// Should the output be a String or Binary?
32-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33-
pub enum OutputType {
34-
/// `StringViewArray`
35-
Utf8View,
36-
/// `BinaryViewArray`
37-
BinaryView,
38-
}
31+
use crate::binary_map::OutputType;
3932

4033
/// HashSet optimized for storing string or binary values that can produce that
4134
/// the final set as a `GenericBinaryViewArray` with minimal copies.
@@ -55,6 +48,14 @@ impl ArrowBytesViewSet {
5548
.insert_if_new(values, make_payload_fn, observe_payload_fn);
5649
}
5750

51+
/// Return the contents of this map and replace it with a new empty map with
52+
/// the same output type
53+
pub fn take(&mut self) -> Self {
54+
let mut new_self = Self::new(self.0.output_type);
55+
std::mem::swap(self, &mut new_self);
56+
new_self
57+
}
58+
5859
/// Converts this set into a `StringViewArray` or `BinaryViewArray`
5960
/// containing each distinct value that was interned.
6061
/// This is done without copying the values.
@@ -216,6 +217,7 @@ where
216217
observe_payload_fn,
217218
)
218219
}
220+
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
219221
};
220222
}
221223

@@ -327,6 +329,9 @@ where
327329
let array = unsafe { array.to_string_view_unchecked() };
328330
Arc::new(array)
329331
}
332+
_ => {
333+
unreachable!("Utf8/Binary should use `ArrowBytesMap`")
334+
}
330335
}
331336
}
332337

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::aggregates::group_values::GroupValues;
19+
use arrow_array::{Array, ArrayRef, RecordBatch};
20+
use datafusion_expr::EmitTo;
21+
use datafusion_physical_expr::binary_map::OutputType;
22+
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;
23+
24+
/// A [`GroupValues`] storing single column of Utf8View/BinaryView values
25+
///
26+
/// This specialization is significantly faster than using the more general
27+
/// purpose `Row`s format
28+
pub struct GroupValuesBytesView {
29+
/// Map string/binary values to group index
30+
map: ArrowBytesViewMap<usize>,
31+
/// The total number of groups so far (used to assign group_index)
32+
num_groups: usize,
33+
}
34+
35+
impl GroupValuesBytesView {
36+
pub fn new(output_type: OutputType) -> Self {
37+
Self {
38+
map: ArrowBytesViewMap::new(output_type),
39+
num_groups: 0,
40+
}
41+
}
42+
}
43+
44+
impl GroupValues for GroupValuesBytesView {
45+
fn intern(
46+
&mut self,
47+
cols: &[ArrayRef],
48+
groups: &mut Vec<usize>,
49+
) -> datafusion_common::Result<()> {
50+
assert_eq!(cols.len(), 1);
51+
52+
// look up / add entries in the table
53+
let arr = &cols[0];
54+
55+
groups.clear();
56+
self.map.insert_if_new(
57+
arr,
58+
// called for each new group
59+
|_value| {
60+
// assign new group index on each insert
61+
let group_idx = self.num_groups;
62+
self.num_groups += 1;
63+
group_idx
64+
},
65+
// called for each group
66+
|group_idx| {
67+
groups.push(group_idx);
68+
},
69+
);
70+
71+
// ensure we assigned a group to for each row
72+
assert_eq!(groups.len(), arr.len());
73+
Ok(())
74+
}
75+
76+
fn size(&self) -> usize {
77+
self.map.size() + std::mem::size_of::<Self>()
78+
}
79+
80+
fn is_empty(&self) -> bool {
81+
self.num_groups == 0
82+
}
83+
84+
fn len(&self) -> usize {
85+
self.num_groups
86+
}
87+
88+
fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
89+
// Reset the map to default, and convert it into a single array
90+
let map_contents = self.map.take().into_state();
91+
92+
let group_values = match emit_to {
93+
EmitTo::All => {
94+
self.num_groups -= map_contents.len();
95+
map_contents
96+
}
97+
EmitTo::First(n) if n == self.len() => {
98+
self.num_groups -= map_contents.len();
99+
map_contents
100+
}
101+
EmitTo::First(n) => {
102+
// if we only wanted to take the first n, insert the rest back
103+
// into the map we could potentially avoid this reallocation, at
104+
// the expense of much more complex code.
105+
// see https://github.com/apache/datafusion/issues/9195
106+
let emit_group_values = map_contents.slice(0, n);
107+
let remaining_group_values =
108+
map_contents.slice(n, map_contents.len() - n);
109+
110+
self.num_groups = 0;
111+
let mut group_indexes = vec![];
112+
self.intern(&[remaining_group_values], &mut group_indexes)?;
113+
114+
// Verify that the group indexes were assigned in the correct order
115+
assert_eq!(0, group_indexes[0]);
116+
117+
emit_group_values
118+
}
119+
};
120+
121+
Ok(vec![group_values])
122+
}
123+
124+
fn clear_shrink(&mut self, _batch: &RecordBatch) {
125+
// in theory we could potentially avoid this reallocation and clear the
126+
// contents of the maps, but for now we just reset the map from the beginning
127+
self.map.take();
128+
}
129+
}

datafusion/physical-plan/src/aggregates/group_values/mod.rs

+22-11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use arrow::record_batch::RecordBatch;
1919
use arrow_array::{downcast_primitive, ArrayRef};
2020
use arrow_schema::{DataType, SchemaRef};
21+
use bytes_view::GroupValuesBytesView;
2122
use datafusion_common::Result;
2223

2324
pub(crate) mod primitive;
@@ -28,6 +29,7 @@ mod row;
2829
use row::GroupValuesRows;
2930

3031
mod bytes;
32+
mod bytes_view;
3133
use bytes::GroupValuesByes;
3234
use datafusion_physical_expr::binary_map::OutputType;
3335

@@ -67,17 +69,26 @@ pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
6769
_ => {}
6870
}
6971

70-
if let DataType::Utf8 = d {
71-
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
72-
}
73-
if let DataType::LargeUtf8 = d {
74-
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Utf8)));
75-
}
76-
if let DataType::Binary = d {
77-
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Binary)));
78-
}
79-
if let DataType::LargeBinary = d {
80-
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Binary)));
72+
match d {
73+
DataType::Utf8 => {
74+
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
75+
}
76+
DataType::LargeUtf8 => {
77+
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Utf8)));
78+
}
79+
DataType::Utf8View => {
80+
return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View)));
81+
}
82+
DataType::Binary => {
83+
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Binary)));
84+
}
85+
DataType::LargeBinary => {
86+
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Binary)));
87+
}
88+
DataType::BinaryView => {
89+
return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView)));
90+
}
91+
_ => {}
8192
}
8293
}
8394

0 commit comments

Comments
 (0)