Skip to content

Commit 4a35fad

Browse files
committed
resolve review comments
1 parent 5bfc141 commit 4a35fad

File tree

1 file changed

+6
-221
lines changed
  • datafusion/core/src/physical_plan/aggregates

1 file changed

+6
-221
lines changed

datafusion/core/src/physical_plan/aggregates/row_hash.rs

+6-221
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,9 @@ use arrow::array::*;
4343
use arrow::compute::{cast, filter};
4444
use arrow::datatypes::{DataType, Schema, UInt32Type};
4545
use arrow::{compute, datatypes::SchemaRef, record_batch::RecordBatch};
46-
use arrow_array::types::{
47-
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt64Type, UInt8Type,
48-
};
49-
use arrow_schema::{IntervalUnit, TimeUnit};
50-
use datafusion_common::cast::{
51-
as_boolean_array, as_decimal128_array, as_fixed_size_binary_array,
52-
as_fixed_size_list_array, as_list_array, as_struct_array,
53-
};
54-
use datafusion_common::scalar::get_dict_value;
46+
use datafusion_common::cast::as_boolean_array;
5547
use datafusion_common::utils::get_arrayref_at_indices;
56-
use datafusion_common::{DataFusionError, Result, ScalarValue};
48+
use datafusion_common::{Result, ScalarValue};
5749
use datafusion_expr::Accumulator;
5850
use datafusion_row::accessor::RowAccessor;
5951
use datafusion_row::layout::RowLayout;
@@ -866,23 +858,6 @@ fn slice_and_maybe_filter(
866858
Ok(filtered_arrays)
867859
}
868860

869-
macro_rules! typed_cast_to_scalar {
870-
($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
871-
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
872-
Ok(ScalarValue::$SCALAR(Some(array.value($index).into())))
873-
}};
874-
}
875-
876-
macro_rules! typed_cast_tz_to_scalar {
877-
($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{
878-
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
879-
Ok(ScalarValue::$SCALAR(
880-
Some(array.value($index).into()),
881-
$TZ.clone(),
882-
))
883-
}};
884-
}
885-
886861
/// This method is similar to Scalar::try_from_array except for the Null handling.
887862
/// This method returns [ScalarValue::Null] instead of [ScalarValue::Type(None)]
888863
fn col_to_scalar(
@@ -898,199 +873,9 @@ fn col_to_scalar(
898873
return Ok(ScalarValue::Null);
899874
}
900875
}
901-
match array.data_type() {
902-
DataType::Null => Ok(ScalarValue::Null),
903-
DataType::Boolean => {
904-
typed_cast_to_scalar!(array, row_index, BooleanArray, Boolean)
905-
}
906-
DataType::Int8 => typed_cast_to_scalar!(array, row_index, Int8Array, Int8),
907-
DataType::Int16 => typed_cast_to_scalar!(array, row_index, Int16Array, Int16),
908-
DataType::Int32 => typed_cast_to_scalar!(array, row_index, Int32Array, Int32),
909-
DataType::Int64 => typed_cast_to_scalar!(array, row_index, Int64Array, Int64),
910-
DataType::UInt8 => typed_cast_to_scalar!(array, row_index, UInt8Array, UInt8),
911-
DataType::UInt16 => {
912-
typed_cast_to_scalar!(array, row_index, UInt16Array, UInt16)
913-
}
914-
DataType::UInt32 => {
915-
typed_cast_to_scalar!(array, row_index, UInt32Array, UInt32)
916-
}
917-
DataType::UInt64 => {
918-
typed_cast_to_scalar!(array, row_index, UInt64Array, UInt64)
919-
}
920-
DataType::Float32 => {
921-
typed_cast_to_scalar!(array, row_index, Float32Array, Float32)
922-
}
923-
DataType::Float64 => {
924-
typed_cast_to_scalar!(array, row_index, Float64Array, Float64)
925-
}
926-
DataType::Decimal128(p, s) => {
927-
let array = as_decimal128_array(array)?;
928-
Ok(ScalarValue::Decimal128(
929-
Some(array.value(row_index)),
930-
*p,
931-
*s,
932-
))
933-
}
934-
DataType::Binary => {
935-
typed_cast_to_scalar!(array, row_index, BinaryArray, Binary)
936-
}
937-
DataType::LargeBinary => {
938-
typed_cast_to_scalar!(array, row_index, LargeBinaryArray, LargeBinary)
939-
}
940-
DataType::Utf8 => typed_cast_to_scalar!(array, row_index, StringArray, Utf8),
941-
DataType::LargeUtf8 => {
942-
typed_cast_to_scalar!(array, row_index, LargeStringArray, LargeUtf8)
943-
}
944-
DataType::List(nested_type) => {
945-
let list_array = as_list_array(array)?;
946-
947-
let nested_array = list_array.value(row_index);
948-
let scalar_vec = (0..nested_array.len())
949-
.map(|i| ScalarValue::try_from_array(&nested_array, i))
950-
.collect::<Result<Vec<_>>>()?;
951-
let value = Some(scalar_vec);
952-
Ok(ScalarValue::new_list(
953-
value,
954-
nested_type.data_type().clone(),
955-
))
956-
}
957-
DataType::Date32 => {
958-
typed_cast_to_scalar!(array, row_index, Date32Array, Date32)
959-
}
960-
DataType::Date64 => {
961-
typed_cast_to_scalar!(array, row_index, Date64Array, Date64)
962-
}
963-
DataType::Time32(TimeUnit::Second) => {
964-
typed_cast_to_scalar!(array, row_index, Time32SecondArray, Time32Second)
965-
}
966-
DataType::Time32(TimeUnit::Millisecond) => typed_cast_to_scalar!(
967-
array,
968-
row_index,
969-
Time32MillisecondArray,
970-
Time32Millisecond
971-
),
972-
DataType::Time64(TimeUnit::Microsecond) => typed_cast_to_scalar!(
973-
array,
974-
row_index,
975-
Time64MicrosecondArray,
976-
Time64Microsecond
977-
),
978-
DataType::Time64(TimeUnit::Nanosecond) => typed_cast_to_scalar!(
979-
array,
980-
row_index,
981-
Time64NanosecondArray,
982-
Time64Nanosecond
983-
),
984-
DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz_to_scalar!(
985-
array,
986-
row_index,
987-
TimestampSecondArray,
988-
TimestampSecond,
989-
tz_opt
990-
),
991-
DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
992-
typed_cast_tz_to_scalar!(
993-
array,
994-
row_index,
995-
TimestampMillisecondArray,
996-
TimestampMillisecond,
997-
tz_opt
998-
)
999-
}
1000-
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
1001-
typed_cast_tz_to_scalar!(
1002-
array,
1003-
row_index,
1004-
TimestampMicrosecondArray,
1005-
TimestampMicrosecond,
1006-
tz_opt
1007-
)
1008-
}
1009-
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
1010-
typed_cast_tz_to_scalar!(
1011-
array,
1012-
row_index,
1013-
TimestampNanosecondArray,
1014-
TimestampNanosecond,
1015-
tz_opt
1016-
)
1017-
}
1018-
DataType::Dictionary(key_type, _) => {
1019-
let (values_array, values_index) = match key_type.as_ref() {
1020-
DataType::Int8 => get_dict_value::<Int8Type>(array, row_index),
1021-
DataType::Int16 => get_dict_value::<Int16Type>(array, row_index),
1022-
DataType::Int32 => get_dict_value::<Int32Type>(array, row_index),
1023-
DataType::Int64 => get_dict_value::<Int64Type>(array, row_index),
1024-
DataType::UInt8 => get_dict_value::<UInt8Type>(array, row_index),
1025-
DataType::UInt16 => get_dict_value::<UInt16Type>(array, row_index),
1026-
DataType::UInt32 => get_dict_value::<UInt32Type>(array, row_index),
1027-
DataType::UInt64 => get_dict_value::<UInt64Type>(array, row_index),
1028-
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
1029-
};
1030-
// look up the index in the values dictionary
1031-
match values_index {
1032-
Some(values_index) => {
1033-
let value = ScalarValue::try_from_array(values_array, values_index)?;
1034-
Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(value)))
1035-
}
1036-
// else entry was null, so return null
1037-
None => Ok(ScalarValue::Null),
1038-
}
1039-
}
1040-
DataType::Struct(fields) => {
1041-
let array = as_struct_array(array)?;
1042-
let mut field_values: Vec<ScalarValue> = Vec::new();
1043-
for col_index in 0..array.num_columns() {
1044-
let col_array = array.column(col_index);
1045-
let col_scalar = ScalarValue::try_from_array(col_array, row_index)?;
1046-
field_values.push(col_scalar);
1047-
}
1048-
Ok(ScalarValue::Struct(Some(field_values), fields.clone()))
1049-
}
1050-
DataType::FixedSizeList(nested_type, _len) => {
1051-
let list_array = as_fixed_size_list_array(array)?;
1052-
match list_array.is_null(row_index) {
1053-
true => Ok(ScalarValue::Null),
1054-
false => {
1055-
let nested_array = list_array.value(row_index);
1056-
let scalar_vec = (0..nested_array.len())
1057-
.map(|i| ScalarValue::try_from_array(&nested_array, i))
1058-
.collect::<Result<Vec<_>>>()?;
1059-
Ok(ScalarValue::new_list(
1060-
Some(scalar_vec),
1061-
nested_type.data_type().clone(),
1062-
))
1063-
}
1064-
}
1065-
}
1066-
DataType::FixedSizeBinary(_) => {
1067-
let array = as_fixed_size_binary_array(array)?;
1068-
let size = match array.data_type() {
1069-
DataType::FixedSizeBinary(size) => *size,
1070-
_ => unreachable!(),
1071-
};
1072-
Ok(ScalarValue::FixedSizeBinary(
1073-
size,
1074-
Some(array.value(row_index).into()),
1075-
))
1076-
}
1077-
DataType::Interval(IntervalUnit::DayTime) => {
1078-
typed_cast_to_scalar!(array, row_index, IntervalDayTimeArray, IntervalDayTime)
1079-
}
1080-
DataType::Interval(IntervalUnit::YearMonth) => typed_cast_to_scalar!(
1081-
array,
1082-
row_index,
1083-
IntervalYearMonthArray,
1084-
IntervalYearMonth
1085-
),
1086-
DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast_to_scalar!(
1087-
array,
1088-
row_index,
1089-
IntervalMonthDayNanoArray,
1090-
IntervalMonthDayNano
1091-
),
1092-
other => Err(DataFusionError::NotImplemented(format!(
1093-
"GroupedHashAggregate: can't create a scalar from array of type \"{other:?}\""
1094-
))),
876+
let mut res = ScalarValue::try_from_array(array, row_index)?;
877+
if res.is_null() {
878+
res = ScalarValue::Null;
1095879
}
880+
Ok(res)
1096881
}

0 commit comments

Comments
 (0)