Skip to content

Commit 6ba1c00

Browse files
committed
new poc2.
1 parent 971d6d3 commit 6ba1c00

File tree

2 files changed

+111
-13
lines changed

2 files changed

+111
-13
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 107 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ pub struct NullState<O: GroupIndexOperations> {
7676

7777
row_offset_buffer: Vec<usize>,
7878

79-
values_buffer: MutableBuffer,
79+
row_idxs_buffer: Vec<usize>,
8080

8181
/// phantom data for required type `<O>`
8282
_phantom: PhantomData<O>,
@@ -139,7 +139,7 @@ impl<O: GroupIndexOperations> NullState<O> {
139139
mut value_fn: F,
140140
) where
141141
T: ArrowPrimitiveType + Send,
142-
F: FnMut(&[u32], &[u64], &[usize], &[T::Native]) + Send,
142+
F: FnMut(&[u32], &[u64], &[usize], &[usize]) + Send,
143143
{
144144
// ensure the seen_values is big enough (start everything at
145145
// "not seen" valid)
@@ -149,14 +149,14 @@ impl<O: GroupIndexOperations> NullState<O> {
149149
let block_ids_buffer = &mut self.block_ids_buffer;
150150
let block_offsets_buffer = &mut self.block_offsets_buffer;
151151
let row_offset_buffer = &mut self.row_offset_buffer;
152-
let values_buffer = &mut self.values_buffer;
152+
let row_idxs_buffer = &mut self.row_idxs_buffer;
153153

154154
block_ids_buffer.clear();
155155
block_offsets_buffer.clear();
156156
row_offset_buffer.clear();
157-
values_buffer.clear();
157+
row_idxs_buffer.clear();
158158

159-
accumulate(group_indices, values, opt_filter, |packed_index, value| {
159+
accumulate_blocks(group_indices, values, opt_filter, |packed_index, row_idx| {
160160
let packed_index = packed_index as u64;
161161
let block_id = O::get_block_id(packed_index);
162162
let block_offset = O::get_block_offset(packed_index);
@@ -168,7 +168,7 @@ impl<O: GroupIndexOperations> NullState<O> {
168168
row_offset_buffer.push(block_offsets_buffer.len());
169169
}
170170

171-
values_buffer.push(value);
171+
row_idxs_buffer.push(row_idx);
172172
block_offsets_buffer.push(block_offset);
173173
});
174174
row_offset_buffer.push(block_offsets_buffer.len());
@@ -182,7 +182,7 @@ impl<O: GroupIndexOperations> NullState<O> {
182182
&block_ids_buffer,
183183
&block_offsets_buffer,
184184
&row_offset_buffer,
185-
values_buffer.typed_data(),
185+
&row_idxs_buffer,
186186
);
187187
}
188188

@@ -351,7 +351,7 @@ impl NullStateAdapter {
351351
value_fn: F,
352352
) where
353353
T: ArrowPrimitiveType + Send,
354-
F: FnMut(&[u32], &[u64], &[usize], &[T::Native]) + Send,
354+
F: FnMut(&[u32], &[u64], &[usize], &[usize]) + Send,
355355
{
356356
match self {
357357
NullStateAdapter::Flat(null_state) => null_state.accumulate_blocks(
@@ -493,7 +493,7 @@ impl Default for FlatNullState {
493493
block_ids_buffer: Vec::new(),
494494
block_offsets_buffer: Vec::new(),
495495
row_offset_buffer: Vec::new(),
496-
values_buffer: MutableBuffer::new(0),
496+
row_idxs_buffer: Vec::new(),
497497
_phantom: PhantomData {},
498498
}
499499
}
@@ -529,7 +529,7 @@ impl BlockedNullState {
529529
block_ids_buffer: Vec::new(),
530530
block_offsets_buffer: Vec::new(),
531531
row_offset_buffer: Vec::new(),
532-
values_buffer: MutableBuffer::new(0),
532+
row_idxs_buffer: Vec::new(),
533533
_phantom: PhantomData {},
534534
}
535535
}
@@ -767,6 +767,103 @@ pub fn accumulate<T, F>(
767767
}
768768
}
769769

770+
pub fn accumulate_blocks<T, F>(
771+
group_indices: &[usize],
772+
values: &PrimitiveArray<T>,
773+
opt_filter: Option<&BooleanArray>,
774+
mut value_fn: F,
775+
) where
776+
T: ArrowPrimitiveType + Send,
777+
F: FnMut(usize, usize) + Send,
778+
{
779+
let data: &[T::Native] = values.values();
780+
assert_eq!(data.len(), group_indices.len());
781+
782+
match (values.null_count() > 0, opt_filter) {
783+
// no nulls, no filter,
784+
(false, None) => {
785+
for (row_idx, &group_index) in group_indices.iter().enumerate() {
786+
value_fn(group_index, row_idx);
787+
}
788+
}
789+
// nulls, no filter
790+
(true, None) => {
791+
let nulls = values.nulls().unwrap();
792+
// This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
793+
// iterate over in chunks of 64 bits for more efficient null checking
794+
let group_indices_chunks = group_indices.chunks_exact(64);
795+
let bit_chunks = nulls.inner().bit_chunks();
796+
let group_indices_remainder = group_indices_chunks.remainder();
797+
798+
let mut row_idx = 0;
799+
group_indices_chunks.zip(bit_chunks.iter()).for_each(
800+
|(group_index_chunk, mask)| {
801+
// index_mask has value 1 << i in the loop
802+
let mut index_mask = 1;
803+
group_index_chunk.iter().for_each(|&group_index| {
804+
// valid bit was set, real value
805+
let is_valid = (mask & index_mask) != 0;
806+
if is_valid {
807+
value_fn(group_index, row_idx);
808+
}
809+
index_mask <<= 1;
810+
row_idx += 1;
811+
})
812+
},
813+
);
814+
815+
// handle any remaining bits (after the initial 64)
816+
let remainder_bits = bit_chunks.remainder_bits();
817+
group_indices_remainder
818+
.iter()
819+
.enumerate()
820+
.for_each(|(i, &group_index)| {
821+
let is_valid = remainder_bits & (1 << i) != 0;
822+
if is_valid {
823+
value_fn(group_index, row_idx);
824+
}
825+
row_idx += 1;
826+
});
827+
}
828+
829+
// no nulls, but a filter
830+
(false, Some(filter)) => {
831+
assert_eq!(filter.len(), group_indices.len());
832+
// The performance with a filter could be improved by
833+
// iterating over the filter in chunks, rather than a single
834+
// iterator. TODO file a ticket
835+
group_indices
836+
.iter()
837+
.zip(filter.iter())
838+
.enumerate()
839+
.for_each(|(row_idx, (&group_index, filter_value))| {
840+
if let Some(true) = filter_value {
841+
value_fn(group_index, row_idx);
842+
}
843+
})
844+
}
845+
846+
// both null values and filters
847+
(true, Some(filter)) => {
848+
assert_eq!(filter.len(), group_indices.len());
849+
// The performance with a filter could be improved by
850+
// iterating over the filter in chunks, rather than using
851+
// iterators. TODO file a ticket
852+
filter
853+
.iter()
854+
.zip(group_indices.iter())
855+
.enumerate()
856+
.for_each(|(row_idx, (filter_value, &group_index))| {
857+
if let Some(true) = filter_value {
858+
if values.is_null(row_idx) {
859+
value_fn(group_index, row_idx)
860+
}
861+
}
862+
})
863+
}
864+
}
865+
}
866+
770867
/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
771868
///
772869
/// This method assumes that for any input record index, if any of the value column

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,21 @@ where
112112
// (self.prim_fn)(value, new_value);
113113
// },
114114
// );
115+
let data = values.values();
115116
self.null_state.accumulate_blocks(
116117
group_indices,
117118
values,
118119
opt_filter,
119120
total_num_groups,
120-
|block_ids, block_offsets, row_offsets, new_values| {
121+
|block_ids, block_offsets, row_offsets, row_idxs| {
121122
let iter = block_ids.iter().zip(row_offsets.windows(2));
122123
for (&block_id, row_bound) in iter {
123124
let block = &mut self.values[block_id as usize];
124125
(row_bound[0]..row_bound[1]).for_each(|idx| {
125126
let block_offset = block_offsets[idx];
126127
let value = &mut block[block_offset as usize];
127-
let new_value = new_values[idx];
128-
(self.prim_fn)(value, new_value);
128+
let value_row_idx = row_idxs[idx];
129+
(self.prim_fn)(value, data[value_row_idx]);
129130
});
130131
}
131132
},

0 commit comments

Comments
 (0)