From bcf7151f06f2995e74cd2a2ed167c14f77bf2c2b Mon Sep 17 00:00:00 2001
From: comphead
Date: Wed, 23 Apr 2025 16:25:14 -0700
Subject: [PATCH 1/4] feat: support `array_repeat`

---
 .../apache/comet/serde/QueryPlanSerde.scala   |  3 ++-
 .../scala/org/apache/comet/serde/arrays.scala | 14 ++++++++++++
 .../comet/CometArrayExpressionSuite.scala     | 22 +++++++++++++++++++
 3 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 325cf15a1d..6d2faa246c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1979,6 +1979,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case _: ArrayIntersect => convert(CometArrayIntersect)
       case _: ArrayJoin => convert(CometArrayJoin)
       case _: ArraysOverlap => convert(CometArraysOverlap)
+      case _: ArrayRepeat => convert(CometArrayRepeat)
       case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
         convert(CometArrayCompact)
       case _: ArrayExcept =>
@@ -3068,7 +3069,7 @@ trait CometAggregateExpressionSerde {
    * Convert a Spark expression into a protocol buffer representation that can be passed into
    * native code.
    *
-   * @param expr
+   * @param aggExpr
    *   The aggregate expression.
    * @param expr
    *   The aggregate function.
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 0e96b543d2..89b3b605b0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -179,6 +179,20 @@ object CometArraysOverlap extends CometExpressionSerde with IncompatExpr {
   }
 }
 
+object CometArrayRepeat extends CometExpressionSerde with IncompatExpr {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // The first child is the element to repeat; the second is the repeat count.
+    val elementExprProto = exprToProto(expr.children.head, inputs, binding)
+    val countExprProto = exprToProto(expr.children(1), inputs, binding)
+
+    val arrayRepeatScalarExpr =
+      scalarExprToProto("array_repeat", elementExprProto, countExprProto)
+    optExprWithInfo(arrayRepeatScalarExpr, expr, expr.children: _*)
+  }
+}
+
 object CometArrayCompact extends CometExpressionSerde with IncompatExpr {
   override def convert(
       expr: Expression,
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 6a787345d7..ada7803ed8 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -410,4 +410,26 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_repeat") {
+    withSQLConf(
+      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          spark.sql("select * from t1").printSchema()
+
+          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from t1"))
+//          checkSparkAnswerAndOperator(
+//            sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
+//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null"))
+//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, _3) from t1 where _3 is null"))
+//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1"))
+//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1"))
+        }
+      }
+    }
+  }
 }

From 455b9a4e1837d8175b1b3ae08b9372f13b69fd6f Mon Sep 17 00:00:00 2001
From: comphead
Date: Sun, 27 Apr 2025 10:11:47 -0700
Subject: [PATCH 2/4] feat: support `array_repeat`

---
 native/core/src/execution/operators/scan.rs   |   6 +-
 native/core/src/execution/planner.rs          | 132 ++++++++++-
 .../src/array_funcs/array_repeat.rs           | 216 ++++++++++++++++++
 native/spark-expr/src/array_funcs/mod.rs      |   2 +
 native/spark-expr/src/comet_scalar_funcs.rs   |  11 +-
 .../comet/CometArrayExpressionSuite.scala     |  15 +-
 .../comet/exec/CometNativeReaderSuite.scala   |  65 ++++++
 7 files changed, 432 insertions(+), 15 deletions(-)
 create mode 100644 native/spark-expr/src/array_funcs/array_repeat.rs

diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index 47d25901cf..c94c2be37b 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -497,14 +497,14 @@ pub enum InputBatch {
     /// The end of input batches.
     EOF,
 
-    /// A normal batch with columns and number of rows.
-    /// It is possible to have zero-column batch with non-zero number of rows,
+    /// A normal batch with columns and a number of rows.
+    /// It is possible to have a zero-column batch with a non-zero number of rows,
     /// i.e. reading empty schema from scan.
     Batch(Vec<ArrayRef>, usize),
 }
 
 impl InputBatch {
-    /// Constructs a `InputBatch` from columns and optional number of rows.
+    /// Constructs an `InputBatch` from columns and an optional number of rows.
     /// If `num_rows` is none, this function will calculate it from given
     /// columns.
     pub fn new(columns: Vec<ArrayRef>, num_rows: Option<usize>) -> Self {
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index d3646fc761..02a5fb0abe 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -2505,7 +2505,7 @@ mod tests {
 
     use futures::{poll, StreamExt};
 
-    use arrow::array::{DictionaryArray, Int32Array, StringArray};
+    use arrow::array::{Array, DictionaryArray, Int32Array, StringArray};
     use arrow::datatypes::DataType;
     use datafusion::logical_expr::ScalarUDF;
     use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
@@ -2912,7 +2912,6 @@ mod tests {
         // Separate thread to send the EOF signal once we've processed the only input batch
         runtime.spawn(async move {
-            // Create a dictionary array with 100 values, and use it as input to the execution.
            let a = Int32Array::from(vec![0, 3]);
            let b = Int32Array::from(vec![1, 4]);
            let c = Int32Array::from(vec![2, 5]);
@@ -2953,4 +2952,133 @@ mod tests {
             }
         });
     }
+
+    #[test]
+    fn test_array_repeat() {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let planner = PhysicalPlanner::new(Arc::from(session_ctx));
+
+        // Mock scan operator with 3 INT32 columns
+        let op_scan = Operator {
+            plan_id: 0,
+            children: vec![],
+            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
+                fields: vec![
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                ],
+                source: "".to_string(),
+            })),
+        };
+
+        // Mock expression to read an INT32 column at position 0
+        let array_col = spark_expression::Expr {
+            expr_struct: Some(Bound(spark_expression::BoundReference {
+                index: 0,
+                datatype: Some(spark_expression::DataType {
+                    type_id: 3,
+                    type_info: None,
+                }),
+            })),
+        };
+
+        // Mock expression to read an INT32 column at position 1
+        let array_col_1 = spark_expression::Expr {
+            expr_struct: Some(Bound(spark_expression::BoundReference {
+                index: 1,
+                datatype: Some(spark_expression::DataType {
+                    type_id: 3,
+                    type_info: None,
+                }),
+            })),
+        };
+
+        // Make a projection operator with array_repeat(array_col, array_col_1)
+        let projection = Operator {
+            children: vec![op_scan],
+            plan_id: 0,
+            op_struct: Some(OpStruct::Projection(spark_operator::Projection {
+                project_list: vec![spark_expression::Expr {
+                    expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
+                        func: "array_repeat".to_string(),
+                        args: vec![array_col, array_col_1],
+                        return_type: None,
+                    })),
+                }],
+            })),
+        };
+
+        // Create a physical plan
+        let (mut scans, datafusion_plan) =
+            planner.create_plan(&projection, &mut vec![], 1).unwrap();
+
+        // Start executing the plan in a separate thread.
+        // The plan waits for incoming batches and emits results as the input arrives.
+        let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
+
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        // Create an async channel
+        let (tx, mut rx) = mpsc::channel(1);
+
+        // Send data as input to the plan being executed in a separate thread
+        runtime.spawn(async move {
+            // Create a data batch:
+            // 0, 1, 2
+            // 3, 4, 5
+            // 6, null, null
+            let a = Int32Array::from(vec![Some(0), Some(3), Some(6)]);
+            let b = Int32Array::from(vec![Some(1), Some(4), None]);
+            let c = Int32Array::from(vec![Some(2), Some(5), None]);
+            let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 3);
+            let input_batch2 = InputBatch::EOF;
+
+            let batches = vec![input_batch1, input_batch2];
+
+            for batch in batches.into_iter() {
+                tx.send(batch).await.unwrap();
+            }
+        });
+
+        // Wait for the plan to finish executing and assert the result
+        runtime.block_on(async move {
+            loop {
+                let batch = rx.recv().await.unwrap();
+                scans[0].set_input_batch(batch);
+                match poll!(stream.next()) {
+                    Poll::Ready(Some(batch)) => {
+                        assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
+                        let batch = batch.unwrap();
+                        let expected = [
+                            "+--------------+",
+                            "| col_0        |",
+                            "+--------------+",
+                            "| [0]          |",
+                            "| [3, 3, 3, 3] |",
+                            "|              |",
+                            "+--------------+",
+                        ];
+                        assert_batches_eq!(expected, &[batch]);
+                    }
+                    Poll::Ready(None) => {
+                        break;
+                    }
+                    _ => {}
+                }
+            }
+        });
+    }
 }
diff --git a/native/spark-expr/src/array_funcs/array_repeat.rs b/native/spark-expr/src/array_funcs/array_repeat.rs
new file mode 100644
index 0000000000..7ba8f0b910
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/array_repeat.rs
@@ -0,0 +1,216 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+    new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
+    NullBufferBuilder, OffsetSizeTrait, UInt64Array,
+};
+use arrow::buffer::OffsetBuffer;
+use arrow::compute;
+use arrow::compute::cast;
+use arrow::datatypes::DataType::{LargeList, List};
+use arrow::datatypes::{DataType, Field};
+use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
+use datafusion::common::{exec_err, DataFusionError, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use std::sync::Arc;
+
+pub fn make_scalar_function<F>(
+    inner: F,
+) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
+where
+    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
+{
+    move |args: &[ColumnarValue]| {
+        // First, identify whether any of the arguments is an Array. If yes, store its `len`,
+        // as any scalar will need to be converted to an array of length `len`.
+        let len = args
+            .iter()
+            .fold(Option::<usize>::None, |acc, arg| match arg {
+                ColumnarValue::Scalar(_) => acc,
+                ColumnarValue::Array(a) => Some(a.len()),
+            });
+
+        let is_scalar = len.is_none();
+
+        let args = ColumnarValue::values_to_arrays(args)?;
+
+        let result = (inner)(&args);
+
+        if is_scalar {
+            // If all inputs are scalars, keep the output as a scalar
+            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
+            result.map(ColumnarValue::Scalar)
+        } else {
+            result.map(ColumnarValue::Array)
+        }
+    }
+}
+
+pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    make_scalar_function(spark_array_repeat_inner)(args)
+}
+
+/// `array_repeat` SQL function
+fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
+    let element = &args[0];
+    let count_array = &args[1];
+
+    let count_array = match count_array.data_type() {
+        DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
+        DataType::UInt64 => count_array,
+        _ => return exec_err!("count must be an integer type"),
+    };
+
+    let count_array = as_uint64_array(count_array)?;
+
+    match element.data_type() {
+        List(_) => {
+            let list_array = as_list_array(element)?;
+            general_list_repeat::<i32>(list_array, count_array)
+        }
+        LargeList(_) => {
+            let list_array = as_large_list_array(element)?;
+            general_list_repeat::<i64>(list_array, count_array)
+        }
+        _ => general_repeat::<i32>(element, count_array),
+    }
+}
+
+/// For each element of `array[i]` repeat `count_array[i]` times.
+///
+/// Assumptions for the input:
+/// 1. `count[i] >= 0`
+/// 2. `array.len() == count_array.len()`
+///
+/// For example,
+/// ```text
+/// array_repeat(
+///     [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
+/// )
+/// ```
+fn general_repeat<O: OffsetSizeTrait>(
+    array: &ArrayRef,
+    count_array: &UInt64Array,
+) -> datafusion::common::Result<ArrayRef> {
+    let data_type = array.data_type();
+    let mut new_values = vec![];
+
+    let count_vec = count_array
+        .values()
+        .to_vec()
+        .iter()
+        .map(|x| *x as usize)
+        .collect::<Vec<usize>>();
+
+    let mut nulls = NullBufferBuilder::new(count_array.len());
+
+    for (row_index, &count) in count_vec.iter().enumerate() {
+        nulls.append(!count_array.is_null(row_index));
+        let repeated_array = if array.is_null(row_index) {
+            new_null_array(data_type, count)
+        } else {
+            let original_data = array.to_data();
+            let capacity = Capacities::Array(count);
+            let mut mutable =
+                MutableArrayData::with_capacities(vec![&original_data], false, capacity);
+
+            for _ in 0..count {
+                mutable.extend(0, row_index, row_index + 1);
+            }
+
+            let data = mutable.freeze();
+            arrow::array::make_array(data)
+        };
+        new_values.push(repeated_array);
+    }
+
+    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
+    let values = compute::concat(&new_values)?;
+
+    Ok(Arc::new(GenericListArray::<O>::try_new(
+        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
+        OffsetBuffer::from_lengths(count_vec),
+        values,
+        nulls.finish(),
+    )?))
+}
+
+/// Handle List version of `general_repeat`
+///
+/// For each element of `list_array[i]` repeat `count_array[i]` times.
+///
+/// For example,
+/// ```text
+/// array_repeat(
+///     [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
+/// )
+/// ```
+fn general_list_repeat<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    count_array: &UInt64Array,
+) -> datafusion::common::Result<ArrayRef> {
+    let data_type = list_array.data_type();
+    let value_type = list_array.value_type();
+    let mut new_values = vec![];
+
+    let count_vec = count_array
+        .values()
+        .to_vec()
+        .iter()
+        .map(|x| *x as usize)
+        .collect::<Vec<usize>>();
+
+    for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
+        let list_arr = match list_array_row {
+            Some(list_array_row) => {
+                let original_data = list_array_row.to_data();
+                let capacity = Capacities::Array(original_data.len() * count);
+                let mut mutable =
+                    MutableArrayData::with_capacities(vec![&original_data], false, capacity);
+
+                for _ in 0..count {
+                    mutable.extend(0, 0, original_data.len());
+                }
+
+                let data = mutable.freeze();
+                let repeated_array = arrow::array::make_array(data);
+
+                let list_arr = GenericListArray::<O>::try_new(
+                    Arc::new(Field::new_list_field(value_type.clone(), true)),
+                    OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
+                    repeated_array,
+                    None,
+                )?;
+                Arc::new(list_arr) as ArrayRef
+            }
+            None => new_null_array(data_type, count),
+        };
+        new_values.push(list_arr);
+    }
+
+    let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
+    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
+    let values = compute::concat(&new_values)?;
+
+    Ok(Arc::new(ListArray::try_new(
+        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
+        OffsetBuffer::<i32>::from_lengths(lengths),
+        values,
+        None,
+    )?))
+}
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
index 0a215f96cf..cdfb1f4db4 100644
--- a/native/spark-expr/src/array_funcs/mod.rs
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -16,9 +16,11 @@
 // under the License.
 
 mod array_insert;
+mod array_repeat;
 mod get_array_struct_fields;
 mod list_extract;
 
 pub use array_insert::ArrayInsert;
+pub use array_repeat::spark_array_repeat;
 pub use get_array_struct_fields::GetArrayStructFields;
 pub use list_extract::ListExtract;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index f954fdd8ce..cf06d36332 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -17,9 +17,10 @@
 
 use crate::hash_funcs::*;
 use crate::{
-    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div,
-    spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
-    spark_rpad, spark_unhex, spark_unscaled_value, SparkChrFunc,
+    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
+    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
+    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
+    SparkChrFunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -140,6 +141,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_date_sub);
             make_comet_scalar_udf!("date_sub", func, without data_type)
         }
+        "array_repeat" => {
+            let func = Arc::new(spark_array_repeat);
+            make_comet_scalar_udf!("array_repeat", func, without data_type)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index ada7803ed8..3067619f2c 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -419,15 +419,16 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
           val path = new Path(dir.toURI.toString, "test.parquet")
           makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
           spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          spark.sql("select * from t1").printSchema()
 
           checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from t1"))
-//          checkSparkAnswerAndOperator(
-//            sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
-//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null"))
-//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, _3) from t1 where _3 is null"))
-//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1"))
-//          checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
+          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 5) from t1 where _2 is null"))
+          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, _4) from t1 where _3 is null"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_repeat(_3, _4) from t1 where _3 is not null"))
+          checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1"))
+          checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1"))
         }
       }
     }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
index d9c71f147d..ca3bbead11 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
@@ -224,4 +224,69 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
         |""".stripMargin,
       "select c0 from tbl")
   }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - second field") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].b col0 from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].a, c0[0].b, c0[0].c from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - reverse fields") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].c, c0[0].b, c0[0].a from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - skip field") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].a, c0[0].c from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - duplicate first") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].a, c0[0].a from tbl")
+  }
 }

From db4b66ae096f9275ad9bf1dd65dd1053a346b934 Mon Sep 17 00:00:00 2001
From: comphead
Date: Sun, 27 Apr 2025 10:12:43 -0700
Subject: [PATCH 3/4] feat: support `array_repeat`

---
 .../comet/exec/CometNativeReaderSuite.scala   | 65 -------------------
 1 file changed, 65 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
index ca3bbead11..d9c71f147d 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
@@ -224,69 +224,4 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
         |""".stripMargin,
       "select c0 from tbl")
   }
-
-  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - second field") {
-    testSingleLineQuery(
-      """
-        | select array(str0, str1) c0 from
-        | (
-        |   select
-        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
-        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
-        | )
-        |""".stripMargin,
-      "select c0[0].b col0 from tbl")
-  }
-
-  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") {
-    testSingleLineQuery(
-      """
-        | select array(str0, str1) c0 from
-        | (
-        |   select
-        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
-        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
-        | )
-        |""".stripMargin,
-      "select c0[0].a, c0[0].b, c0[0].c from tbl")
-  }
-
-  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - reverse fields") {
-    testSingleLineQuery(
-      """
-        | select array(str0, str1) c0 from
-        | (
-        |   select
-        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
-        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
-        | )
-        |""".stripMargin,
-      "select c0[0].c, c0[0].b, c0[0].a from tbl")
-  }
-
-  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - skip field") {
-    testSingleLineQuery(
-      """
-        | select array(str0, str1) c0 from
-        | (
-        |   select
-        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
-        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
-        | )
-        |""".stripMargin,
-      "select c0[0].a, c0[0].c from tbl")
-  }
-
-  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - duplicate first") {
-    testSingleLineQuery(
-      """
-        | select array(str0, str1) c0 from
-        | (
-        |   select
-        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
-        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
-        | )
-        |""".stripMargin,
-      "select c0[0].a, c0[0].a from tbl")
-  }
 }

From 57c37d7c6b4e0f14960e229e6842d8157b5ba437 Mon Sep 17 00:00:00 2001
From: comphead
Date: Mon, 28 Apr 2025 15:25:03 -0700
Subject: [PATCH 4/4] feat: support `array_repeat`

---
 .../test/scala/org/apache/comet/CometArrayExpressionSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 3067619f2c..c58f1a14fd 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -421,6 +421,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
           spark.read.parquet(path.toString).createOrReplaceTempView("t1")
 
           checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from t1"))
+          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, 0) from t1"))
           checkSparkAnswerAndOperator(
             sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
           checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 5) from t1 where _2 is null"))
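
For reviewers who want to exercise the new kernel outside of Spark, here is a minimal
standalone sketch of the expected `array_repeat` semantics. It assumes the crate is
consumed as `datafusion_comet_spark_expr` and that `spark_array_repeat` keeps the
`fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>` signature introduced
above; the harness itself is illustrative and is not part of the patch series.

use std::sync::Arc;

use arrow::array::{Array, Int32Array, Int64Array, ListArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_array_repeat; // assumed crate/export path

fn main() -> Result<(), DataFusionError> {
    // Per-row elements to repeat and Int64 repeat counts, as Spark would bind them.
    let elements =
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1), Some(2), None])));
    let counts = ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 0, 1])));

    // Expected per the kernel above: [[1, 1], [], [null]] — a null *element* is
    // repeated as nulls, while a null *count* would null out the whole output row.
    let result = spark_array_repeat(&[elements, counts])?;
    let ColumnarValue::Array(array) = result else {
        unreachable!("array inputs produce an array output")
    };
    let list = array.as_any().downcast_ref::<ListArray>().unwrap();
    assert_eq!(list.value(0).len(), 2); // [1, 1]
    assert_eq!(list.value(1).len(), 0); // []
    assert_eq!(list.value(2).len(), 1); // [null]
    assert!(list.value(2).is_null(0));
    Ok(())
}

This mirrors the `test_array_repeat` planner test: counts are cast from `Int64` to
`UInt64` before repeating, scalar-only inputs would come back as a `ScalarValue`
thanks to `make_scalar_function`, and row nullability is driven by the count column,
matching Spark's behavior.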