From faac90b94272b09db41eead577d3e95ee953e6df Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 24 Nov 2024 15:10:57 +0100 Subject: [PATCH] Make list.rs non generic & simplify the code --- native/spark-expr/src/list.rs | 88 ++++++++++++----------------------- 1 file changed, 29 insertions(+), 59 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 7dc17b5688..e6680318e4 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::as_list_array; use arrow::{ array::{as_primitive_array, Capacities, MutableArrayData}, buffer::{NullBuffer, OffsetBuffer}, @@ -22,14 +23,13 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_array::{ - make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray, + make_array, Array, ArrayRef, GenericListArray, Int32Array, ListArray, StructArray, }; use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ - cast::{as_int32_array, as_large_list_array, as_list_array}, - internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, + cast::as_int32_array, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, }; use datafusion_physical_expr::PhysicalExpr; use std::{ @@ -72,7 +72,7 @@ impl ListExtract { fn child_field(&self, input_schema: &Schema) -> DataFusionResult { match self.child.data_type(input_schema)? { - DataType::List(field) | DataType::LargeList(field) => Ok(field), + DataType::List(field) => Ok(field), data_type => Err(DataFusionError::Internal(format!( "Unexpected data type in ListExtract: {:?}", data_type @@ -127,19 +127,7 @@ impl PhysicalExpr for ListExtract { match child_value.data_type() { DataType::List(_) => { - let list_array = as_list_array(&child_value)?; - let index_array = as_int32_array(&ordinal_value)?; - - list_extract( - list_array, - index_array, - &default_value, - self.fail_on_error, - adjust_index, - ) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&child_value)?; + let list_array = as_list_array(&child_value); let index_array = as_int32_array(&ordinal_value)?; list_extract( @@ -220,8 +208,8 @@ fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { } } -fn list_extract( - list_array: &GenericListArray, +fn list_extract( + list_array: &ListArray, index_array: &Int32Array, default_value: &ScalarValue, fail_on_error: bool, @@ -329,7 +317,6 @@ impl PhysicalExpr for GetArrayStructFields { let struct_field = self.child_field(input_schema)?; match self.child.data_type(input_schema)? { DataType::List(_) => Ok(DataType::List(struct_field)), - DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)), data_type => Err(DataFusionError::Internal(format!( "Unexpected data type in GetArrayStructFields: {:?}", data_type @@ -347,12 +334,7 @@ impl PhysicalExpr for GetArrayStructFields { match child_value.data_type() { DataType::List(_) => { - let list_array = as_list_array(&child_value)?; - - get_array_struct_fields(list_array, self.ordinal) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&child_value)?; + let list_array = as_list_array(&child_value); get_array_struct_fields(list_array, self.ordinal) } @@ -388,8 +370,8 @@ impl PhysicalExpr for GetArrayStructFields { } } -fn get_array_struct_fields( - list_array: &GenericListArray, +fn get_array_struct_fields( + list_array: &ListArray, ordinal: usize, ) -> DataFusionResult { let values = list_array @@ -452,7 +434,6 @@ impl ArrayInsert { pub fn array_type(&self, data_type: &DataType) -> DataFusionResult { match data_type { DataType::List(field) => Ok(DataType::List(Arc::clone(field))), - DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))), data_type => Err(DataFusionError::Internal(format!( "Unexpected src array type in ArrayInsert: {:?}", data_type @@ -497,7 +478,6 @@ impl PhysicalExpr for ArrayInsert { let src_element_type = match self.array_type(src_value.data_type())? { DataType::List(field) => &field.data_type().clone(), - DataType::LargeList(field) => &field.data_type().clone(), _ => unreachable!(), }; @@ -514,27 +494,13 @@ impl PhysicalExpr for ArrayInsert { ))); } - match src_value.data_type() { - DataType::List(_) => { - let list_array = as_list_array(&src_value)?; - array_insert( - list_array, - &item_value, - &pos_value, - self.legacy_negative_index, - ) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&src_value)?; - array_insert( - list_array, - &item_value, - &pos_value, - self.legacy_negative_index, - ) - } - _ => unreachable!(), // This case is checked already - } + let list_array = as_list_array(&src_value); + array_insert( + list_array, + &item_value, + &pos_value, + self.legacy_negative_index, + ) } fn children(&self) -> Vec<&Arc> { @@ -566,8 +532,8 @@ impl PhysicalExpr for ArrayInsert { } } -fn array_insert( - list_array: &GenericListArray, +fn array_insert( + list_array: &ListArray, items_array: &ArrayRef, pos_array: &ArrayRef, legacy_mode: bool, @@ -587,7 +553,7 @@ fn array_insert( let mut mutable_values = MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity); - let mut new_offsets = vec![O::usize_as(0)]; + let mut new_offsets: Vec = vec![0]; let mut new_nulls = Vec::::with_capacity(list_array.len()); let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions @@ -601,7 +567,7 @@ fn array_insert( if list_array.is_null(row_index) { // In Spark if value of the array is NULL than nothing happens mutable_values.extend_nulls(1); - new_offsets.push(new_offsets[row_index] + O::one()); + new_offsets.push(new_offsets[row_index] + 1); new_nulls.push(false); continue; } @@ -630,14 +596,17 @@ fn array_insert( mutable_values.extend(0, start, start + corrected_pos); mutable_values.extend(1, row_index, row_index + 1); mutable_values.extend(0, start + corrected_pos, end); - new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); + // new_array_len is less than MAX_ROUNDED_ARRAY_LENGTH that is less than i32 max value + new_offsets.push(new_offsets[row_index] + i32::from_usize(new_array_len).unwrap()); } else { mutable_values.extend(0, start, end); mutable_values.extend_nulls(new_array_len - (end - start)); mutable_values.extend(1, row_index, row_index + 1); // In that case spark actualy makes array longer than expected; // For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5 - new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one()); + new_offsets + .push(new_offsets[row_index] + i32::from_usize(new_array_len).unwrap() + 1); + // new_array_len is less than MAX_ROUNDED_ARRAY_LENGTH that is less than i32 max value } } else { // This comment is takes from the Apache Spark source code as is: @@ -655,7 +624,8 @@ fn array_insert( mutable_values.extend(1, row_index, row_index + 1); mutable_values.extend_nulls(new_array_len - (end - start + 1)); mutable_values.extend(0, start, end); - new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); + // new_array_len is less than MAX_ROUNDED_ARRAY_LENGTH that is less than i32 max value + new_offsets.push(new_offsets[row_index] + i32::from_usize(new_array_len).unwrap()); } if is_item_null { if (start == end) || (values.is_null(row_index)) { @@ -674,7 +644,7 @@ fn array_insert( DataType::LargeList(field) => field.data_type(), _ => unreachable!(), }; - let new_array = GenericListArray::::try_new( + let new_array = ListArray::try_new( Arc::new(Field::new("item", data_type.clone(), true)), OffsetBuffer::new(new_offsets.into()), data,