diff --git a/native/Cargo.lock b/native/Cargo.lock
index b34ed54bc..8cb39f5b7 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -942,7 +942,6 @@ dependencies = [
  "regex",
  "thiserror",
  "twox-hash",
- "unicode-segmentation",
 ]
 
 [[package]]
diff --git a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
index 70cbdebae..1203f90d7 100644
--- a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
+++ b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
@@ -21,8 +21,8 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
 };
 use datafusion_comet_spark_expr::scalar_funcs::{
     spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_murmur3_hash, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, spark_xxhash64,
-    SparkChrFunc,
+    spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
+    spark_xxhash64, SparkChrFunc,
 };
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
@@ -67,9 +67,9 @@ pub fn create_comet_physical_fun(
         "floor" => {
             make_comet_scalar_udf!("floor", spark_floor, data_type)
         }
-        "rpad" => {
-            let func = Arc::new(spark_rpad);
-            make_comet_scalar_udf!("rpad", func, without data_type)
+        "read_side_padding" => {
+            let func = Arc::new(spark_read_side_padding);
+            make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index a16ceda8c..b604e98ba 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1724,11 +1724,16 @@ impl PhysicalPlanner {
 
                 let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
                     Some(t) => t,
-                    None => self
-                        .session_ctx
-                        .udf(fun_name)?
-                        .inner()
-                        .return_type(&input_expr_types)?,
+                    None => {
+                        let fun_name = match fun_name.as_str() {
+                            "read_side_padding" => "rpad", // use the same return type as rpad
+                            other => other,
+                        };
+                        self.session_ctx
+                            .udf(fun_name)?
+                            .inner()
+                            .return_type(&input_expr_types)?
+                    }
                 };
 
                 let fun_expr =
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 96eae39ff..1a8c8aeb4 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"
 
 [dev-dependencies]
 arrow-data = {workspace = true}
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index 7cbaf12aa..ffd6fd212 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{cmp::min, sync::Arc};
-
 use arrow::{
     array::{
-        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
-        Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
+use arrow_array::builder::GenericStringBuilder;
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
 use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};
 
 mod unhex;
 pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
 }
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match args[0].data_type() {
-                DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad",
+                    "Unsupported data type {other:?} for function read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad",
+            "Unsupported arguments {other:?} for function read_side_padding",
         ))),
     }
 }
 
-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
 
-    let result = string_array
-        .iter()
-        .map(|string| match string {
+    for string in string_array.iter() {
+        match string {
             Some(string) => {
-                let length = if length < 0 { 0 } else { length as usize };
-                if length == 0 {
-                    Ok(Some("".to_string()))
+                // It looks Spark's UTF8String is closer to chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
                 } else {
-                    let graphemes = string.graphemes(true).collect::<Vec<&str>>();
-                    if length < graphemes.len() {
-                        Ok(Some(string.to_string()))
-                    } else {
-                        let mut s = string.to_string();
-                        s.push_str(" ".repeat(length - graphemes.len()).as_str());
-                        Ok(Some(s))
-                    }
+                    // write_str updates only the value buffer, not null nor offset buffer
+                    // This is convenient for concatenating str(s)
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
                 }
             }
-            _ => Ok(None),
-        })
-        .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
-    Ok(ColumnarValue::Array(Arc::new(result)))
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8f08eeba8..5f3cc7a2e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2178,7 +2178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       }
 
       // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
-      // char types. Use rpad to achieve the behavior.
+      // char types.
       // See https://github.com/apache/spark/pull/38151
       case s: StaticInvoke
           if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
@@ -2194,7 +2194,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
         if (argsExpr.forall(_.isDefined)) {
           val builder = ExprOuterClass.ScalarFunc.newBuilder()
-          builder.setFunc("rpad")
+          builder.setFunc("read_side_padding")
           argsExpr.foreach(arg => builder.addArgs(arg.get))
 
           Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
diff --git a/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
new file mode 100644
index 000000000..8a5359d4c
--- /dev/null
+++ b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
@@ -0,0 +1,7 @@
+SELECT
+  cd_gender
+FROM customer_demographics
+WHERE
+  cd_gender = 'M' AND
+  cd_marital_status = 'S' AND
+  cd_education_status = 'College'
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index cce487198..ded5bc5c5 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1911,6 +1911,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("readSidePadding") {
+    // https://stackoverflow.com/a/46290728
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 CHAR(2)) using parquet")
+      sql(s"insert into $table values('é')") // unicode 'e\\u{301}'
+      sql(s"insert into $table values('é')") // unicode '\\u{e9}'
+      sql(s"insert into $table values('')")
+      sql(s"insert into $table values('ab')")
+
+      checkSparkAnswerAndOperator(s"SELECT * FROM $table")
+    }
+  }
+
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
index b09e0486f..aa0c91155 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
@@ -63,6 +63,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
     "agg_sum_integers_no_grouping",
     "case_when_column_or_null",
     "case_when_scalar",
+    "char_type",
     "filter_highly_selective",
     "filter_less_selective",
     "if_column_or_null",
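For reference, a minimal standalone sketch (not part of the patch; the helper `read_side_pad` and the `main` driver are hypothetical) of the padding rule that `spark_read_side_padding_internal` applies: pad to the declared length counting Rust `char`s (Unicode scalar values, matching Spark's `UTF8String` semantics) and never truncate a string that is already longer. This is also why the test above inserts both the combining-accent form `e\u{301}` and the precomposed `\u{e9}` into a `CHAR(2)` column: they count as 2 and 1 chars respectively, so only the latter receives a trailing space.

```rust
// Hypothetical helper mirroring the read-side padding rule from the diff:
// count chars, pad up to `length`, never truncate.
fn read_side_pad(s: &str, length: usize) -> String {
    let char_len = s.chars().count();
    if length <= char_len {
        // Already at or past the declared CHAR(N) width: keep as-is.
        s.to_string()
    } else {
        // Append ASCII spaces for the missing chars (spaces are 1 byte each).
        let mut padded = String::with_capacity(s.len() + (length - char_len));
        padded.push_str(s);
        padded.extend(std::iter::repeat(' ').take(length - char_len));
        padded
    }
}

fn main() {
    // 'e' + U+0301 (combining acute): 2 chars, 1 grapheme -> already full for CHAR(2)
    assert_eq!(read_side_pad("e\u{301}", 2), "e\u{301}");
    // U+00E9 (precomposed 'é'): 1 char -> padded with one space
    assert_eq!(read_side_pad("\u{e9}", 2), "\u{e9} ");
    // Longer than the declared length: kept as-is, no truncation
    assert_eq!(read_side_pad("abc", 2), "abc");
    // Empty string: padded to the full length
    assert_eq!(read_side_pad("", 2), "  ");
}
```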