Skip to content

Commit 16303ad

Browse files
waynexianseekhao
andauthored
feat: implement substrait for LIKE/ILIKE expr (#6840)
* feat: implement substrait for LIKE/ILIKE expr Signed-off-by: Ruihang Xia <[email protected]> * fix clippy Signed-off-by: Ruihang Xia <[email protected]> * Apply suggestions from code review Co-authored-by: Nuttiiya Seekhao <[email protected]> * Update datafusion/substrait/src/logical_plan/consumer.rs Co-authored-by: Nuttiiya Seekhao <[email protected]> * style: rename function Signed-off-by: Ruihang Xia <[email protected]> * apply CR sugg. Signed-off-by: Ruihang Xia <[email protected]> --------- Signed-off-by: Ruihang Xia <[email protected]> Co-authored-by: Nuttiiya Seekhao <[email protected]>
1 parent 33fc013 commit 16303ad

File tree

3 files changed

+249
-136
lines changed

3 files changed

+249
-136
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

+143-136
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use datafusion::logical_expr::{
2323
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
2424
};
2525
use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
26-
use datafusion::logical_expr::{Extension, LogicalPlanBuilder};
26+
use datafusion::logical_expr::{Extension, Like, LogicalPlanBuilder};
2727
use datafusion::prelude::JoinType;
2828
use datafusion::sql::TableReference;
2929
use datafusion::{
@@ -32,7 +32,7 @@ use datafusion::{
3232
prelude::{Column, SessionContext},
3333
scalar::ScalarValue,
3434
};
35-
use substrait::proto::expression::Literal;
35+
use substrait::proto::expression::{Literal, ScalarFunction};
3636
use substrait::proto::{
3737
aggregate_function::AggregationInvocation,
3838
expression::{
@@ -67,8 +67,12 @@ use crate::variation_const::{
6767
enum ScalarFunctionType {
6868
Builtin(BuiltinScalarFunction),
6969
Op(Operator),
70-
// logical negation
70+
/// [Expr::Not]
7171
Not,
72+
/// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive
73+
Like,
74+
/// [Expr::ILike] Case insensitive operator counterpart of `Like`
75+
ILike,
7276
}
7377

7478
pub fn name_to_op(name: &str) -> Result<Operator> {
@@ -104,7 +108,7 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
104108
}
105109
}
106110

107-
fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
111+
fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
108112
if let Ok(op) = name_to_op(name) {
109113
return Ok(ScalarFunctionType::Op(op));
110114
}
@@ -113,23 +117,14 @@ fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
113117
return Ok(ScalarFunctionType::Builtin(fun));
114118
}
115119

116-
Err(DataFusionError::NotImplemented(format!(
117-
"Unsupported function name: {name:?}"
118-
)))
119-
}
120-
121-
fn scalar_function_or_not(name: &str) -> Result<ScalarFunctionType> {
122-
if let Ok(fun) = BuiltinScalarFunction::from_str(name) {
123-
return Ok(ScalarFunctionType::Builtin(fun));
124-
}
125-
126-
if name == "not" {
127-
return Ok(ScalarFunctionType::Not);
120+
match name {
121+
"not" => Ok(ScalarFunctionType::Not),
122+
"like" => Ok(ScalarFunctionType::Like),
123+
"ilike" => Ok(ScalarFunctionType::ILike),
124+
others => Err(DataFusionError::NotImplemented(format!(
125+
"Unsupported function name: {others:?}"
126+
))),
128127
}
129-
130-
Err(DataFusionError::NotImplemented(format!(
131-
"Unsupported function name: {name:?}"
132-
)))
133128
}
134129

135130
/// Convert Substrait Plan to DataFusion DataFrame
@@ -790,20 +785,46 @@ pub async fn from_substrait_rex(
790785
else_expr,
791786
})))
792787
}
793-
Some(RexType::ScalarFunction(f)) => match f.arguments.len() {
794-
// BinaryExpr or ScalarFunction
795-
2 => match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
796-
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
797-
let op_or_fun = match extensions.get(&f.function_reference) {
798-
Some(fname) => name_to_op_or_scalar_function(fname),
799-
None => Err(DataFusionError::NotImplemented(format!(
800-
"Aggregated function not found: function reference = {:?}",
801-
f.function_reference
802-
))),
803-
};
804-
match op_or_fun {
805-
Ok(ScalarFunctionType::Op(op)) => {
806-
return Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
788+
Some(RexType::ScalarFunction(f)) => {
789+
let fn_name = extensions.get(&f.function_reference).ok_or_else(|| {
790+
DataFusionError::NotImplemented(format!(
791+
"Aggregated function not found: function reference = {:?}",
792+
f.function_reference
793+
))
794+
})?;
795+
let fn_type = scalar_function_type_from_str(fn_name)?;
796+
match fn_type {
797+
ScalarFunctionType::Builtin(fun) => {
798+
let mut args = Vec::with_capacity(f.arguments.len());
799+
for arg in &f.arguments {
800+
let arg_expr = match &arg.arg_type {
801+
Some(ArgType::Value(e)) => {
802+
from_substrait_rex(e, input_schema, extensions).await
803+
}
804+
_ => Err(DataFusionError::NotImplemented(
805+
"Aggregated function argument non-Value type not supported"
806+
.to_string(),
807+
)),
808+
};
809+
args.push(arg_expr?.as_ref().clone());
810+
}
811+
Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
812+
fun,
813+
args,
814+
})))
815+
}
816+
ScalarFunctionType::Op(op) => {
817+
if f.arguments.len() != 2 {
818+
return Err(DataFusionError::NotImplemented(format!(
819+
"Expect two arguments for binary operator {op:?}",
820+
)));
821+
}
822+
let lhs = &f.arguments[0].arg_type;
823+
let rhs = &f.arguments[1].arg_type;
824+
825+
match (lhs, rhs) {
826+
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
827+
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
807828
left: Box::new(
808829
from_substrait_rex(l, input_schema, extensions)
809830
.await?
@@ -819,116 +840,38 @@ pub async fn from_substrait_rex(
819840
),
820841
})))
821842
}
822-
Ok(ScalarFunctionType::Builtin(fun)) => {
823-
Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
824-
fun,
825-
args: vec![
826-
from_substrait_rex(l, input_schema, extensions)
827-
.await?
828-
.as_ref()
829-
.clone(),
830-
from_substrait_rex(r, input_schema, extensions)
831-
.await?
832-
.as_ref()
833-
.clone(),
834-
],
835-
})))
836-
}
837-
Ok(ScalarFunctionType::Not) => {
838-
Err(DataFusionError::NotImplemented(
839-
"Not expected function type: Not".to_string(),
840-
))
841-
}
842-
Err(e) => Err(e),
843-
}
844-
}
845-
(l, r) => Err(DataFusionError::NotImplemented(format!(
846-
"Invalid arguments for binary expression: {l:?} and {r:?}"
847-
))),
848-
},
849-
// ScalarFunction or Expr::Not
850-
1 => {
851-
let fun = match extensions.get(&f.function_reference) {
852-
Some(fname) => scalar_function_or_not(fname),
853-
None => Err(DataFusionError::NotImplemented(format!(
854-
"Function not found: function reference = {:?}",
855-
f.function_reference
856-
))),
857-
};
858-
859-
match fun {
860-
Ok(ScalarFunctionType::Op(_)) => {
861-
Err(DataFusionError::NotImplemented(
862-
"Not expected function type: Op".to_string(),
863-
))
864-
}
865-
Ok(scalar_function_type) => {
866-
match &f.arguments.first().unwrap().arg_type {
867-
Some(ArgType::Value(e)) => {
868-
let expr =
869-
from_substrait_rex(e, input_schema, extensions)
870-
.await?
871-
.as_ref()
872-
.clone();
873-
match scalar_function_type {
874-
ScalarFunctionType::Builtin(fun) => Ok(Arc::new(
875-
Expr::ScalarFunction(expr::ScalarFunction {
876-
fun,
877-
args: vec![expr],
878-
}),
879-
)),
880-
ScalarFunctionType::Not => {
881-
Ok(Arc::new(Expr::Not(Box::new(expr))))
882-
}
883-
_ => Err(DataFusionError::NotImplemented(
884-
"Invalid arguments for Not expression"
885-
.to_string(),
886-
)),
887-
}
888-
}
889-
_ => Err(DataFusionError::NotImplemented(
890-
"Invalid arguments for Not expression".to_string(),
891-
)),
892-
}
843+
(l, r) => Err(DataFusionError::NotImplemented(format!(
844+
"Invalid arguments for binary expression: {l:?} and {r:?}"
845+
))),
893846
}
894-
Err(e) => Err(e),
895847
}
896-
}
897-
// ScalarFunction
898-
_ => {
899-
let fun = match extensions.get(&f.function_reference) {
900-
Some(fname) => BuiltinScalarFunction::from_str(fname),
901-
None => Err(DataFusionError::NotImplemented(format!(
902-
"Aggregated function not found: function reference = {:?}",
903-
f.function_reference
904-
))),
905-
};
906-
907-
let mut args: Vec<Expr> = vec![];
908-
for arg in f.arguments.iter() {
848+
ScalarFunctionType::Not => {
849+
let arg = f.arguments.first().ok_or_else(|| {
850+
DataFusionError::Substrait(
851+
"expect one argument for `NOT` expr".to_string(),
852+
)
853+
})?;
909854
match &arg.arg_type {
910855
Some(ArgType::Value(e)) => {
911-
args.push(
912-
from_substrait_rex(e, input_schema, extensions)
913-
.await?
914-
.as_ref()
915-
.clone(),
916-
);
917-
}
918-
e => {
919-
return Err(DataFusionError::NotImplemented(format!(
920-
"Invalid arguments for scalar function: {e:?}"
921-
)))
856+
let expr = from_substrait_rex(e, input_schema, extensions)
857+
.await?
858+
.as_ref()
859+
.clone();
860+
Ok(Arc::new(Expr::Not(Box::new(expr))))
922861
}
862+
_ => Err(DataFusionError::NotImplemented(
863+
"Invalid arguments for Not expression".to_string(),
864+
)),
923865
}
924866
}
925-
926-
Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
927-
fun: fun?,
928-
args,
929-
})))
867+
ScalarFunctionType::Like => {
868+
make_datafusion_like(false, f, input_schema, extensions).await
869+
}
870+
ScalarFunctionType::ILike => {
871+
make_datafusion_like(true, f, input_schema, extensions).await
872+
}
930873
}
931-
},
874+
}
932875
Some(RexType::Literal(lit)) => {
933876
let scalar_value = from_substrait_literal(lit)?;
934877
Ok(Arc::new(Expr::Literal(scalar_value)))
@@ -1342,3 +1285,67 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
13421285
))
13431286
}
13441287
}
1288+
1289+
async fn make_datafusion_like(
1290+
case_insensitive: bool,
1291+
f: &ScalarFunction,
1292+
input_schema: &DFSchema,
1293+
extensions: &HashMap<u32, &String>,
1294+
) -> Result<Arc<Expr>> {
1295+
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
1296+
if f.arguments.len() != 3 {
1297+
return Err(DataFusionError::NotImplemented(format!(
1298+
"Expect three arguments for `{fn_name}` expr"
1299+
)));
1300+
}
1301+
1302+
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
1303+
return Err(DataFusionError::NotImplemented(
1304+
format!("Invalid arguments type for `{fn_name}` expr")
1305+
))
1306+
};
1307+
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
1308+
.await?
1309+
.as_ref()
1310+
.clone();
1311+
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
1312+
return Err(DataFusionError::NotImplemented(
1313+
format!("Invalid arguments type for `{fn_name}` expr")
1314+
))
1315+
};
1316+
let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions)
1317+
.await?
1318+
.as_ref()
1319+
.clone();
1320+
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
1321+
return Err(DataFusionError::NotImplemented(
1322+
format!("Invalid arguments type for `{fn_name}` expr")
1323+
))
1324+
};
1325+
let escape_char_expr =
1326+
from_substrait_rex(escape_char_substrait, input_schema, extensions)
1327+
.await?
1328+
.as_ref()
1329+
.clone();
1330+
let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
1331+
return Err(DataFusionError::Substrait(format!(
1332+
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}",
1333+
)))
1334+
};
1335+
1336+
if case_insensitive {
1337+
Ok(Arc::new(Expr::ILike(Like {
1338+
negated: false,
1339+
expr: Box::new(expr),
1340+
pattern: Box::new(pattern),
1341+
escape_char: escape_char.map(|c| c.chars().next().unwrap()),
1342+
})))
1343+
} else {
1344+
Ok(Arc::new(Expr::Like(Like {
1345+
negated: false,
1346+
expr: Box::new(expr),
1347+
pattern: Box::new(pattern),
1348+
escape_char: escape_char.map(|c| c.chars().next().unwrap()),
1349+
})))
1350+
}
1351+
}

0 commit comments

Comments
 (0)