From 113ec6d8ffb38e3599f6bab03050f135f4c691cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= Date: Wed, 2 Oct 2024 16:00:19 +0200 Subject: [PATCH] Add fail_on_error to be passed up to make_comet_scalar_udf --- .../datafusion/expressions/comet_scalar_funcs.rs | 13 ++++++++++++- native/core/src/execution/datafusion/planner.rs | 11 +++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs index 06717aabe..8f58ba7ca 100644 --- a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs +++ b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs @@ -34,6 +34,16 @@ use std::fmt::Debug; use std::sync::Arc; macro_rules! make_comet_scalar_udf { + ($name:expr, $func:ident, $data_type:ident, $fail_on_error:ident) => {{ + let scalar_func = CometScalarFunction::new( + $name.to_string(), + Signature::variadic_any(Volatility::Immutable), + $data_type.clone(), + Arc::new(move |args| $func(args, &$data_type)), + ); + // TODO Check for overflow + Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func))) + }}; ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( $name.to_string(), @@ -59,6 +69,7 @@ pub fn create_comet_physical_fun( fun_name: &str, data_type: DataType, registry: &dyn FunctionRegistry, + _fail_on_error: &bool, ) -> Result, DataFusionError> { match fun_name { "ceil" => { @@ -72,7 +83,7 @@ pub fn create_comet_physical_fun( make_comet_scalar_udf!("read_side_padding", func, without data_type) } "round" => { - make_comet_scalar_udf!("round", spark_round, data_type) + make_comet_scalar_udf!("round", spark_round, data_type, _fail_on_error) } "unscaled_value" => { let func = Arc::new(spark_unscaled_value); diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 15de7c9ad..822301674 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -775,10 +775,12 @@ impl PhysicalPlanner { Ok(DataType::Decimal128(_p2, _s2)), ) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); + let fail_on_error = false; let fun_expr = create_comet_physical_fun( "decimal_div", data_type.clone(), &self.session_ctx.state(), + &fail_on_error, )?; Ok(Arc::new(ScalarFunctionExpr::new( "decimal_div", @@ -1872,6 +1874,7 @@ impl PhysicalPlanner { .collect::, _>>()?; let fun_name = &expr.func; + let fail_on_error = &expr.fail_on_error; let input_expr_types = args .iter() .map(|x| x.data_type(input_schema.as_ref())) @@ -1897,8 +1900,12 @@ impl PhysicalPlanner { } }; - let fun_expr = - create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?; + let fun_expr = create_comet_physical_fun( + fun_name, + data_type.clone(), + &self.session_ctx.state(), + fail_on_error, + )?; let args = args .into_iter()