Skip to content

Commit da4c590

Browse files
committed
refactor group index computation.
1 parent 13296c1 commit da4c590

File tree

6 files changed

+88
-111
lines changed

6 files changed

+88
-111
lines changed

datafusion-examples/examples/advanced_udaf.rs

-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
258258
opt_filter,
259259
total_num_groups,
260260
|_, group_index, new_value| {
261-
let group_index = group_index as usize;
262261
let prod = &mut self.prods[group_index];
263262
*prod = prod.mul_wrapping(new_value);
264263

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

+25-41
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,17 @@ impl<O: GroupIndexOperations> NullState<O> {
118118
mut value_fn: F,
119119
) where
120120
T: ArrowPrimitiveType + Send,
121-
F: FnMut(u32, u64, T::Native) + Send,
121+
F: FnMut(usize, usize, T::Native) + Send,
122122
{
123123
// ensure the seen_values is big enough (start everything at
124124
// "not seen" valid)
125125
let seen_values =
126126
initialize_builder(&mut self.seen_values, total_num_groups, false);
127127
let block_size = self.block_size.unwrap_or_default();
128-
accumulate(group_indices, values, opt_filter, |packed_index, value| {
129-
let packed_index = packed_index as u64;
130-
let block_id = O::get_block_id(packed_index);
131-
let block_offset = O::get_block_offset(packed_index);
132-
let flat_index = O::get_flat_index(block_id, block_offset, block_size);
133-
seen_values.set_bit(flat_index, false);
128+
accumulate(group_indices, values, opt_filter, |group_index, value| {
129+
let block_id = O::get_block_id(group_index, block_size);
130+
let block_offset = O::get_block_offset(group_index, block_size);
131+
seen_values.set_bit(group_index, false);
134132
value_fn(block_id, block_offset, value);
135133
});
136134
}
@@ -153,7 +151,7 @@ impl<O: GroupIndexOperations> NullState<O> {
153151
total_num_groups: usize,
154152
mut value_fn: F,
155153
) where
156-
F: FnMut(u32, u64, bool) + Send,
154+
F: FnMut(usize, usize, bool) + Send,
157155
{
158156
let data = values.values();
159157
assert_eq!(data.len(), group_indices.len());
@@ -170,13 +168,10 @@ impl<O: GroupIndexOperations> NullState<O> {
170168
// if we have previously seen nulls, ensure the null
171169
// buffer is big enough (start everything at valid)
172170
group_indices.iter().zip(data.iter()).for_each(
173-
|(&packed_index, new_value)| {
174-
let packed_index = packed_index as u64;
175-
let block_id = O::get_block_id(packed_index);
176-
let block_offset = O::get_block_offset(packed_index);
177-
let flat_index =
178-
O::get_flat_index(block_id, block_offset, block_size);
179-
seen_values.set_bit(flat_index, true);
171+
|(&group_index, new_value)| {
172+
let block_id = O::get_block_id(group_index, block_size);
173+
let block_offset = O::get_block_offset(group_index, block_size);
174+
seen_values.set_bit(group_index, false);
180175
value_fn(block_id, block_offset, new_value)
181176
},
182177
)
@@ -188,14 +183,11 @@ impl<O: GroupIndexOperations> NullState<O> {
188183
.iter()
189184
.zip(data.iter())
190185
.zip(nulls.iter())
191-
.for_each(|((&packed_index, new_value), is_valid)| {
186+
.for_each(|((&group_index, new_value), is_valid)| {
192187
if is_valid {
193-
let packed_index = packed_index as u64;
194-
let block_id = O::get_block_id(packed_index);
195-
let block_offset = O::get_block_offset(packed_index);
196-
let flat_index =
197-
O::get_flat_index(block_id, block_offset, block_size);
198-
seen_values.set_bit(flat_index, true);
188+
let block_id = O::get_block_id(group_index, block_size);
189+
let block_offset = O::get_block_offset(group_index, block_size);
190+
seen_values.set_bit(group_index, false);
199191
value_fn(block_id, block_offset, new_value);
200192
}
201193
})
@@ -208,14 +200,11 @@ impl<O: GroupIndexOperations> NullState<O> {
208200
.iter()
209201
.zip(data.iter())
210202
.zip(filter.iter())
211-
.for_each(|((&packed_index, new_value), filter_value)| {
203+
.for_each(|((&group_index, new_value), filter_value)| {
212204
if let Some(true) = filter_value {
213-
let packed_index = packed_index as u64;
214-
let block_id = O::get_block_id(packed_index);
215-
let block_offset = O::get_block_offset(packed_index);
216-
let flat_index =
217-
O::get_flat_index(block_id, block_offset, block_size);
218-
seen_values.set_bit(flat_index, true);
205+
let block_id = O::get_block_id(group_index, block_size);
206+
let block_offset = O::get_block_offset(group_index, block_size);
207+
seen_values.set_bit(group_index, false);
219208
value_fn(block_id, block_offset, new_value);
220209
}
221210
})
@@ -227,15 +216,12 @@ impl<O: GroupIndexOperations> NullState<O> {
227216
.iter()
228217
.zip(group_indices.iter())
229218
.zip(values.iter())
230-
.for_each(|((filter_value, &packed_index), new_value)| {
219+
.for_each(|((filter_value, &group_index), new_value)| {
231220
if let Some(true) = filter_value {
232221
if let Some(new_value) = new_value {
233-
let packed_index = packed_index as u64;
234-
let block_id = O::get_block_id(packed_index);
235-
let block_offset = O::get_block_offset(packed_index);
236-
let flat_index =
237-
O::get_flat_index(block_id, block_offset, block_size);
238-
seen_values.set_bit(flat_index, true);
222+
let block_id = O::get_block_id(group_index, block_size);
223+
let block_offset = O::get_block_offset(group_index, block_size);
224+
seen_values.set_bit(group_index, false);
239225
value_fn(block_id, block_offset, new_value);
240226
}
241227
}
@@ -287,7 +273,7 @@ impl NullStateAdapter {
287273
value_fn: F,
288274
) where
289275
T: ArrowPrimitiveType + Send,
290-
F: FnMut(u32, u64, T::Native) + Send,
276+
F: FnMut(usize, usize, T::Native) + Send,
291277
{
292278
match self {
293279
NullStateAdapter::Flat(null_state) => null_state.accumulate(
@@ -315,7 +301,7 @@ impl NullStateAdapter {
315301
total_num_groups: usize,
316302
value_fn: F,
317303
) where
318-
F: FnMut(u32, u64, bool) + Send,
304+
F: FnMut(usize, usize, bool) + Send,
319305
{
320306
match self {
321307
NullStateAdapter::Flat(null_state) => null_state.accumulate_boolean(
@@ -434,9 +420,7 @@ impl Default for FlatNullState {
434420
impl FlatNullState {
435421
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
436422
match emit_to {
437-
EmitTo::All => {
438-
NullBuffer::new(self.seen_values.finish())
439-
}
423+
EmitTo::All => NullBuffer::new(self.seen_values.finish()),
440424
EmitTo::First(n) => {
441425
// split off the first N values in seen_values
442426
//

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs

-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ where
9595
opt_filter,
9696
total_num_groups,
9797
|_, group_index, new_value| {
98-
let group_index = group_index as usize;
9998
let current_value = self.values.get_bit(group_index);
10099
let value = (self.bool_fn)(current_value, new_value);
101100
self.values.set_bit(group_index, value);

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/group_index_operations.rs

+9-29
Original file line numberDiff line numberDiff line change
@@ -49,53 +49,33 @@ use std::fmt::Debug;
4949
/// that is compatible with `flat group index`'s parsing.
5050
///
5151
pub trait GroupIndexOperations: Debug {
52-
fn pack_index(block_id: u32, block_offset: u64) -> u64;
52+
fn get_block_id(group_index: usize, block_size: usize) -> usize;
5353

54-
fn get_block_id(packed_index: u64) -> u32;
55-
56-
fn get_block_offset(packed_index: u64) -> u64;
57-
58-
fn get_flat_index(block_id: u32, block_offset: u64, block_size: usize) -> usize;
54+
fn get_block_offset(group_index: usize, block_size: usize) -> usize;
5955
}
6056

6157
#[derive(Debug)]
6258
pub struct BlockedGroupIndexOperations;
6359

6460
impl GroupIndexOperations for BlockedGroupIndexOperations {
65-
fn pack_index(block_id: u32, block_offset: u64) -> u64 {
66-
((block_id as u64) << 32) | block_offset
67-
}
68-
69-
fn get_block_id(packed_index: u64) -> u32 {
70-
(packed_index >> 32) as u32
71-
}
72-
73-
fn get_block_offset(packed_index: u64) -> u64 {
74-
(packed_index as u32) as u64
61+
fn get_block_id(group_index: usize, block_size: usize) -> usize {
62+
group_index / block_size
7563
}
7664

77-
fn get_flat_index(block_id: u32, block_offset: u64, block_size: usize) -> usize {
78-
block_id as usize * block_size + block_offset as usize
65+
fn get_block_offset(group_index: usize, block_size: usize) -> usize {
66+
group_index % block_size
7967
}
8068
}
8169

8270
#[derive(Debug)]
8371
pub struct FlatGroupIndexOperations;
8472

8573
impl GroupIndexOperations for FlatGroupIndexOperations {
86-
fn pack_index(_block_id: u32, block_offset: u64) -> u64 {
87-
block_offset
88-
}
89-
90-
fn get_block_id(_packed_index: u64) -> u32 {
74+
fn get_block_id(_group_index: usize, _block_size: usize) -> usize {
9175
0
9276
}
9377

94-
fn get_block_offset(packed_index: u64) -> u64 {
95-
packed_index
96-
}
97-
98-
fn get_flat_index(_block_id: u32, block_offset: u64, _block_size: usize) -> usize {
99-
block_offset as usize
78+
fn get_block_offset(group_index: usize, _block_size: usize) -> usize {
79+
group_index
10080
}
10181
}

datafusion/functions-aggregate/src/average.rs

-1
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,6 @@ where
589589
opt_filter,
590590
total_num_groups,
591591
|_, group_index, new_value| {
592-
let group_index = group_index as usize;
593592
let sum = &mut self.sums[group_index];
594593
*sum = sum.add_wrapping(new_value);
595594

0 commit comments

Comments
 (0)