Skip to content

Commit c07c26c

Browse files
authored
Fix incorrect results in BitAnd GroupsAccumulator (#6957)
Fix accumulator
1 parent e0cc8c8 commit c07c26c

File tree

3 files changed

+160
-119
lines changed

3 files changed

+160
-119
lines changed

datafusion/core/tests/sqllogictests/test_files/aggregate.slt

+117-67
Original file line numberDiff line numberDiff line change
@@ -1420,65 +1420,95 @@ select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.c
14201420
2 1 1.414213562373 1
14211421

14221422

1423-
# sum / count for all nulls
1424-
statement ok
1425-
create table the_nulls as values (null::bigint, 1), (null::bigint, 1), (null::bigint, 2);
14261423

1427-
# counts should be zeros (even for nulls)
1428-
query II
1429-
SELECT count(column1), column2 from the_nulls group by column2 order by column2;
1430-
----
1431-
0 1
1432-
0 2
1433-
1434-
# sums should be null
1435-
query II
1436-
SELECT sum(column1), column2 from the_nulls group by column2 order by column2;
1424+
# aggregates on empty tables
1425+
statement ok
1426+
CREATE TABLE empty (column1 bigint, column2 int);
1427+
1428+
# no group by column
1429+
query IIRIIIII
1430+
SELECT
1431+
count(column1), -- counts should be zero, even for nulls
1432+
sum(column1), -- other aggregates should be null
1433+
avg(column1),
1434+
min(column1),
1435+
max(column1),
1436+
bit_and(column1),
1437+
bit_or(column1),
1438+
bit_xor(column1)
1439+
FROM empty
1440+
----
1441+
0 NULL NULL NULL NULL NULL NULL NULL
1442+
1443+
# Same query but with grouping (no groups, so no output)
1444+
query IIRIIIIII
1445+
SELECT
1446+
count(column1),
1447+
sum(column1),
1448+
avg(column1),
1449+
min(column1),
1450+
max(column1),
1451+
bit_and(column1),
1452+
bit_or(column1),
1453+
bit_xor(column1),
1454+
column2
1455+
FROM empty
1456+
GROUP BY column2
1457+
ORDER BY column2;
14371458
----
1438-
NULL 1
1439-
NULL 2
14401459

1441-
# avg should be null
1442-
query RI
1443-
SELECT avg(column1), column2 from the_nulls group by column2 order by column2;
1444-
----
1445-
NULL 1
1446-
NULL 2
14471460

1448-
# bit_and should be null
1449-
query II
1450-
SELECT bit_and(column1), column2 from the_nulls group by column2 order by column2;
1451-
----
1452-
NULL 1
1453-
NULL 2
1461+
statement ok
1462+
drop table empty
14541463

1455-
# bit_or should be null
1456-
query II
1457-
SELECT bit_or(column1), column2 from the_nulls group by column2 order by column2;
1458-
----
1459-
NULL 1
1460-
NULL 2
1464+
# aggregates on all nulls
1465+
statement ok
1466+
CREATE TABLE the_nulls
1467+
AS VALUES
1468+
(null::bigint, 1),
1469+
(null::bigint, 1),
1470+
(null::bigint, 2);
14611471

1462-
# bit_xor should be null
14631472
query II
1464-
SELECT bit_xor(column1), column2 from the_nulls group by column2 order by column2;
1473+
select * from the_nulls
14651474
----
14661475
NULL 1
1467-
NULL 2
1468-
1469-
# min should be null
1470-
query II
1471-
SELECT min(column1), column2 from the_nulls group by column2 order by column2;
1472-
----
14731476
NULL 1
14741477
NULL 2
14751478

1476-
# max should be null
1477-
query II
1478-
SELECT max(column1), column2 from the_nulls group by column2 order by column2;
1479-
----
1480-
NULL 1
1481-
NULL 2
1479+
# no group by column
1480+
query IIRIIIII
1481+
SELECT
1482+
count(column1), -- counts should be zero, even for nulls
1483+
sum(column1), -- other aggregates should be null
1484+
avg(column1),
1485+
min(column1),
1486+
max(column1),
1487+
bit_and(column1),
1488+
bit_or(column1),
1489+
bit_xor(column1)
1490+
FROM the_nulls
1491+
----
1492+
0 NULL NULL NULL NULL NULL NULL NULL
1493+
1494+
# Same query but with grouping
1495+
query IIRIIIIII
1496+
SELECT
1497+
count(column1), -- counts should be zero, even for nulls
1498+
sum(column1), -- other aggregates should be null
1499+
avg(column1),
1500+
min(column1),
1501+
max(column1),
1502+
bit_and(column1),
1503+
bit_or(column1),
1504+
bit_xor(column1),
1505+
column2
1506+
FROM the_nulls
1507+
GROUP BY column2
1508+
ORDER BY column2;
1509+
----
1510+
0 NULL NULL NULL NULL NULL NULL NULL 1
1511+
0 NULL NULL NULL NULL NULL NULL NULL 2
14821512

14831513

14841514
statement ok
@@ -1489,29 +1519,49 @@ create table bit_aggregate_functions (
14891519
c1 SMALLINT NOT NULL,
14901520
c2 SMALLINT NOT NULL,
14911521
c3 SMALLINT,
1522+
tag varchar
14921523
)
14931524
as values
1494-
(5, 10, 11),
1495-
(33, 11, null),
1496-
(9, 12, null);
1497-
1498-
# query_bit_and
1499-
query III
1500-
SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions
1501-
----
1502-
1 8 11
1503-
1504-
# query_bit_or
1505-
query III
1506-
SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions
1507-
----
1508-
45 15 11
1525+
(5, 10, 11, 'A'),
1526+
(33, 11, null, 'B'),
1527+
(9, 12, null, 'A');
1528+
1529+
# query_bit_and, query_bit_or, query_bit_xor
1530+
query IIIIIIIII
1531+
SELECT
1532+
bit_and(c1),
1533+
bit_and(c2),
1534+
bit_and(c3),
1535+
bit_or(c1),
1536+
bit_or(c2),
1537+
bit_or(c3),
1538+
bit_xor(c1),
1539+
bit_xor(c2),
1540+
bit_xor(c3)
1541+
FROM bit_aggregate_functions
1542+
----
1543+
1 8 11 45 15 11 45 13 11
1544+
1545+
# query_bit_and, query_bit_or, query_bit_xor, with group
1546+
query IIIIIIIIIT
1547+
SELECT
1548+
bit_and(c1),
1549+
bit_and(c2),
1550+
bit_and(c3),
1551+
bit_or(c1),
1552+
bit_or(c2),
1553+
bit_or(c3),
1554+
bit_xor(c1),
1555+
bit_xor(c2),
1556+
bit_xor(c3),
1557+
tag
1558+
FROM bit_aggregate_functions
1559+
GROUP BY tag
1560+
ORDER BY tag
1561+
----
1562+
1 8 11 13 14 11 12 6 11 A
1563+
33 11 NULL 33 11 NULL 33 11 NULL B
15091564

1510-
# query_bit_xor
1511-
query III
1512-
SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions
1513-
----
1514-
45 13 11
15151565

15161566
statement ok
15171567
create table bool_aggregate_functions (

datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs

+32-51
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor};
4949
use datafusion_row::accessor::RowAccessor;
5050

5151
/// Creates a [`PrimitiveGroupsAccumulator`] with the specified
52-
/// [`ArrowPrimitiveType`] which applies `$FN` to each element
52+
/// [`ArrowPrimitiveType`] that initailizes each accumulator to $START
53+
/// and applies `$FN` to each element
5354
///
5455
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
55-
macro_rules! instantiate_primitive_accumulator {
56-
($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{
57-
Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
58-
&$SELF.data_type,
59-
$FN,
60-
)))
56+
macro_rules! instantiate_accumulator {
57+
($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
58+
Ok(Box::new(
59+
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, $FN)
60+
.with_starting_value($START),
61+
))
6162
}};
6263
}
6364

@@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd {
279280
use std::ops::BitAndAssign;
280281
match self.data_type {
281282
DataType::Int8 => {
282-
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
283-
.bitand_assign(y))
283+
instantiate_accumulator!(self, -1, Int8Type, |x, y| x.bitand_assign(y))
284284
}
285285
DataType::Int16 => {
286-
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
287-
.bitand_assign(y))
286+
instantiate_accumulator!(self, -1, Int16Type, |x, y| x.bitand_assign(y))
288287
}
289288
DataType::Int32 => {
290-
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
291-
.bitand_assign(y))
289+
instantiate_accumulator!(self, -1, Int32Type, |x, y| x.bitand_assign(y))
292290
}
293291
DataType::Int64 => {
294-
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
295-
.bitand_assign(y))
292+
instantiate_accumulator!(self, -1, Int64Type, |x, y| x.bitand_assign(y))
296293
}
297294
DataType::UInt8 => {
298-
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
295+
instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
299296
.bitand_assign(y))
300297
}
301298
DataType::UInt16 => {
302-
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
299+
instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
303300
.bitand_assign(y))
304301
}
305302
DataType::UInt32 => {
306-
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
303+
instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
307304
.bitand_assign(y))
308305
}
309306
DataType::UInt64 => {
310-
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
307+
instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
311308
.bitand_assign(y))
312309
}
313310

@@ -517,36 +514,28 @@ impl AggregateExpr for BitOr {
517514
use std::ops::BitOrAssign;
518515
match self.data_type {
519516
DataType::Int8 => {
520-
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
521-
.bitor_assign(y))
517+
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitor_assign(y))
522518
}
523519
DataType::Int16 => {
524-
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
525-
.bitor_assign(y))
520+
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitor_assign(y))
526521
}
527522
DataType::Int32 => {
528-
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
529-
.bitor_assign(y))
523+
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitor_assign(y))
530524
}
531525
DataType::Int64 => {
532-
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
533-
.bitor_assign(y))
526+
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitor_assign(y))
534527
}
535528
DataType::UInt8 => {
536-
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
537-
.bitor_assign(y))
529+
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitor_assign(y))
538530
}
539531
DataType::UInt16 => {
540-
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
541-
.bitor_assign(y))
532+
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitor_assign(y))
542533
}
543534
DataType::UInt32 => {
544-
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
545-
.bitor_assign(y))
535+
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitor_assign(y))
546536
}
547537
DataType::UInt64 => {
548-
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
549-
.bitor_assign(y))
538+
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitor_assign(y))
550539
}
551540

552541
_ => Err(DataFusionError::NotImplemented(format!(
@@ -756,36 +745,28 @@ impl AggregateExpr for BitXor {
756745
use std::ops::BitXorAssign;
757746
match self.data_type {
758747
DataType::Int8 => {
759-
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
760-
.bitxor_assign(y))
748+
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitxor_assign(y))
761749
}
762750
DataType::Int16 => {
763-
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
764-
.bitxor_assign(y))
751+
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitxor_assign(y))
765752
}
766753
DataType::Int32 => {
767-
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
768-
.bitxor_assign(y))
754+
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitxor_assign(y))
769755
}
770756
DataType::Int64 => {
771-
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
772-
.bitxor_assign(y))
757+
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitxor_assign(y))
773758
}
774759
DataType::UInt8 => {
775-
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
776-
.bitxor_assign(y))
760+
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitxor_assign(y))
777761
}
778762
DataType::UInt16 => {
779-
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
780-
.bitxor_assign(y))
763+
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitxor_assign(y))
781764
}
782765
DataType::UInt32 => {
783-
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
784-
.bitxor_assign(y))
766+
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitxor_assign(y))
785767
}
786768
DataType::UInt64 => {
787-
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
788-
.bitxor_assign(y))
769+
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitxor_assign(y))
789770
}
790771

791772
_ => Err(DataFusionError::NotImplemented(format!(

0 commit comments

Comments
 (0)