@@ -23,7 +23,7 @@ use datafusion::logical_expr::{
23
23
BuiltinScalarFunction , Case , Expr , LogicalPlan , Operator ,
24
24
} ;
25
25
use datafusion:: logical_expr:: { expr, Cast , WindowFrameBound , WindowFrameUnits } ;
26
- use datafusion:: logical_expr:: { Extension , LogicalPlanBuilder } ;
26
+ use datafusion:: logical_expr:: { Extension , Like , LogicalPlanBuilder } ;
27
27
use datafusion:: prelude:: JoinType ;
28
28
use datafusion:: sql:: TableReference ;
29
29
use datafusion:: {
@@ -32,7 +32,7 @@ use datafusion::{
32
32
prelude:: { Column , SessionContext } ,
33
33
scalar:: ScalarValue ,
34
34
} ;
35
- use substrait:: proto:: expression:: Literal ;
35
+ use substrait:: proto:: expression:: { Literal , ScalarFunction } ;
36
36
use substrait:: proto:: {
37
37
aggregate_function:: AggregationInvocation ,
38
38
expression:: {
@@ -67,8 +67,12 @@ use crate::variation_const::{
67
67
enum ScalarFunctionType {
68
68
Builtin ( BuiltinScalarFunction ) ,
69
69
Op ( Operator ) ,
70
- // logical negation
70
+ /// [Expr::Not]
71
71
Not ,
72
+ /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive
73
+ Like ,
74
+ /// [Expr::ILike] Case insensitive operator counterpart of `Like`
75
+ ILike ,
72
76
}
73
77
74
78
pub fn name_to_op ( name : & str ) -> Result < Operator > {
@@ -104,7 +108,7 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
104
108
}
105
109
}
106
110
107
- fn name_to_op_or_scalar_function ( name : & str ) -> Result < ScalarFunctionType > {
111
+ fn scalar_function_type_from_str ( name : & str ) -> Result < ScalarFunctionType > {
108
112
if let Ok ( op) = name_to_op ( name) {
109
113
return Ok ( ScalarFunctionType :: Op ( op) ) ;
110
114
}
@@ -113,23 +117,14 @@ fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
113
117
return Ok ( ScalarFunctionType :: Builtin ( fun) ) ;
114
118
}
115
119
116
- Err ( DataFusionError :: NotImplemented ( format ! (
117
- "Unsupported function name: {name:?}"
118
- ) ) )
119
- }
120
-
121
- fn scalar_function_or_not ( name : & str ) -> Result < ScalarFunctionType > {
122
- if let Ok ( fun) = BuiltinScalarFunction :: from_str ( name) {
123
- return Ok ( ScalarFunctionType :: Builtin ( fun) ) ;
124
- }
125
-
126
- if name == "not" {
127
- return Ok ( ScalarFunctionType :: Not ) ;
120
+ match name {
121
+ "not" => Ok ( ScalarFunctionType :: Not ) ,
122
+ "like" => Ok ( ScalarFunctionType :: Like ) ,
123
+ "ilike" => Ok ( ScalarFunctionType :: ILike ) ,
124
+ others => Err ( DataFusionError :: NotImplemented ( format ! (
125
+ "Unsupported function name: {others:?}"
126
+ ) ) ) ,
128
127
}
129
-
130
- Err ( DataFusionError :: NotImplemented ( format ! (
131
- "Unsupported function name: {name:?}"
132
- ) ) )
133
128
}
134
129
135
130
/// Convert Substrait Plan to DataFusion DataFrame
@@ -790,20 +785,46 @@ pub async fn from_substrait_rex(
790
785
else_expr,
791
786
} ) ) )
792
787
}
793
- Some ( RexType :: ScalarFunction ( f) ) => match f. arguments . len ( ) {
794
- // BinaryExpr or ScalarFunction
795
- 2 => match ( & f. arguments [ 0 ] . arg_type , & f. arguments [ 1 ] . arg_type ) {
796
- ( Some ( ArgType :: Value ( l) ) , Some ( ArgType :: Value ( r) ) ) => {
797
- let op_or_fun = match extensions. get ( & f. function_reference ) {
798
- Some ( fname) => name_to_op_or_scalar_function ( fname) ,
799
- None => Err ( DataFusionError :: NotImplemented ( format ! (
800
- "Aggregated function not found: function reference = {:?}" ,
801
- f. function_reference
802
- ) ) ) ,
803
- } ;
804
- match op_or_fun {
805
- Ok ( ScalarFunctionType :: Op ( op) ) => {
806
- return Ok ( Arc :: new ( Expr :: BinaryExpr ( BinaryExpr {
788
+ Some ( RexType :: ScalarFunction ( f) ) => {
789
+ let fn_name = extensions. get ( & f. function_reference ) . ok_or_else ( || {
790
+ DataFusionError :: NotImplemented ( format ! (
791
+ "Aggregated function not found: function reference = {:?}" ,
792
+ f. function_reference
793
+ ) )
794
+ } ) ?;
795
+ let fn_type = scalar_function_type_from_str ( fn_name) ?;
796
+ match fn_type {
797
+ ScalarFunctionType :: Builtin ( fun) => {
798
+ let mut args = Vec :: with_capacity ( f. arguments . len ( ) ) ;
799
+ for arg in & f. arguments {
800
+ let arg_expr = match & arg. arg_type {
801
+ Some ( ArgType :: Value ( e) ) => {
802
+ from_substrait_rex ( e, input_schema, extensions) . await
803
+ }
804
+ _ => Err ( DataFusionError :: NotImplemented (
805
+ "Aggregated function argument non-Value type not supported"
806
+ . to_string ( ) ,
807
+ ) ) ,
808
+ } ;
809
+ args. push ( arg_expr?. as_ref ( ) . clone ( ) ) ;
810
+ }
811
+ Ok ( Arc :: new ( Expr :: ScalarFunction ( expr:: ScalarFunction {
812
+ fun,
813
+ args,
814
+ } ) ) )
815
+ }
816
+ ScalarFunctionType :: Op ( op) => {
817
+ if f. arguments . len ( ) != 2 {
818
+ return Err ( DataFusionError :: NotImplemented ( format ! (
819
+ "Expect two arguments for binary operator {op:?}" ,
820
+ ) ) ) ;
821
+ }
822
+ let lhs = & f. arguments [ 0 ] . arg_type ;
823
+ let rhs = & f. arguments [ 1 ] . arg_type ;
824
+
825
+ match ( lhs, rhs) {
826
+ ( Some ( ArgType :: Value ( l) ) , Some ( ArgType :: Value ( r) ) ) => {
827
+ Ok ( Arc :: new ( Expr :: BinaryExpr ( BinaryExpr {
807
828
left : Box :: new (
808
829
from_substrait_rex ( l, input_schema, extensions)
809
830
. await ?
@@ -819,116 +840,38 @@ pub async fn from_substrait_rex(
819
840
) ,
820
841
} ) ) )
821
842
}
822
- Ok ( ScalarFunctionType :: Builtin ( fun) ) => {
823
- Ok ( Arc :: new ( Expr :: ScalarFunction ( expr:: ScalarFunction {
824
- fun,
825
- args : vec ! [
826
- from_substrait_rex( l, input_schema, extensions)
827
- . await ?
828
- . as_ref( )
829
- . clone( ) ,
830
- from_substrait_rex( r, input_schema, extensions)
831
- . await ?
832
- . as_ref( )
833
- . clone( ) ,
834
- ] ,
835
- } ) ) )
836
- }
837
- Ok ( ScalarFunctionType :: Not ) => {
838
- Err ( DataFusionError :: NotImplemented (
839
- "Not expected function type: Not" . to_string ( ) ,
840
- ) )
841
- }
842
- Err ( e) => Err ( e) ,
843
- }
844
- }
845
- ( l, r) => Err ( DataFusionError :: NotImplemented ( format ! (
846
- "Invalid arguments for binary expression: {l:?} and {r:?}"
847
- ) ) ) ,
848
- } ,
849
- // ScalarFunction or Expr::Not
850
- 1 => {
851
- let fun = match extensions. get ( & f. function_reference ) {
852
- Some ( fname) => scalar_function_or_not ( fname) ,
853
- None => Err ( DataFusionError :: NotImplemented ( format ! (
854
- "Function not found: function reference = {:?}" ,
855
- f. function_reference
856
- ) ) ) ,
857
- } ;
858
-
859
- match fun {
860
- Ok ( ScalarFunctionType :: Op ( _) ) => {
861
- Err ( DataFusionError :: NotImplemented (
862
- "Not expected function type: Op" . to_string ( ) ,
863
- ) )
864
- }
865
- Ok ( scalar_function_type) => {
866
- match & f. arguments . first ( ) . unwrap ( ) . arg_type {
867
- Some ( ArgType :: Value ( e) ) => {
868
- let expr =
869
- from_substrait_rex ( e, input_schema, extensions)
870
- . await ?
871
- . as_ref ( )
872
- . clone ( ) ;
873
- match scalar_function_type {
874
- ScalarFunctionType :: Builtin ( fun) => Ok ( Arc :: new (
875
- Expr :: ScalarFunction ( expr:: ScalarFunction {
876
- fun,
877
- args : vec ! [ expr] ,
878
- } ) ,
879
- ) ) ,
880
- ScalarFunctionType :: Not => {
881
- Ok ( Arc :: new ( Expr :: Not ( Box :: new ( expr) ) ) )
882
- }
883
- _ => Err ( DataFusionError :: NotImplemented (
884
- "Invalid arguments for Not expression"
885
- . to_string ( ) ,
886
- ) ) ,
887
- }
888
- }
889
- _ => Err ( DataFusionError :: NotImplemented (
890
- "Invalid arguments for Not expression" . to_string ( ) ,
891
- ) ) ,
892
- }
843
+ ( l, r) => Err ( DataFusionError :: NotImplemented ( format ! (
844
+ "Invalid arguments for binary expression: {l:?} and {r:?}"
845
+ ) ) ) ,
893
846
}
894
- Err ( e) => Err ( e) ,
895
847
}
896
- }
897
- // ScalarFunction
898
- _ => {
899
- let fun = match extensions. get ( & f. function_reference ) {
900
- Some ( fname) => BuiltinScalarFunction :: from_str ( fname) ,
901
- None => Err ( DataFusionError :: NotImplemented ( format ! (
902
- "Aggregated function not found: function reference = {:?}" ,
903
- f. function_reference
904
- ) ) ) ,
905
- } ;
906
-
907
- let mut args: Vec < Expr > = vec ! [ ] ;
908
- for arg in f. arguments . iter ( ) {
848
+ ScalarFunctionType :: Not => {
849
+ let arg = f. arguments . first ( ) . ok_or_else ( || {
850
+ DataFusionError :: Substrait (
851
+ "expect one argument for `NOT` expr" . to_string ( ) ,
852
+ )
853
+ } ) ?;
909
854
match & arg. arg_type {
910
855
Some ( ArgType :: Value ( e) ) => {
911
- args. push (
912
- from_substrait_rex ( e, input_schema, extensions)
913
- . await ?
914
- . as_ref ( )
915
- . clone ( ) ,
916
- ) ;
917
- }
918
- e => {
919
- return Err ( DataFusionError :: NotImplemented ( format ! (
920
- "Invalid arguments for scalar function: {e:?}"
921
- ) ) )
856
+ let expr = from_substrait_rex ( e, input_schema, extensions)
857
+ . await ?
858
+ . as_ref ( )
859
+ . clone ( ) ;
860
+ Ok ( Arc :: new ( Expr :: Not ( Box :: new ( expr) ) ) )
922
861
}
862
+ _ => Err ( DataFusionError :: NotImplemented (
863
+ "Invalid arguments for Not expression" . to_string ( ) ,
864
+ ) ) ,
923
865
}
924
866
}
925
-
926
- Ok ( Arc :: new ( Expr :: ScalarFunction ( expr:: ScalarFunction {
927
- fun : fun?,
928
- args,
929
- } ) ) )
867
+ ScalarFunctionType :: Like => {
868
+ make_datafusion_like ( false , f, input_schema, extensions) . await
869
+ }
870
+ ScalarFunctionType :: ILike => {
871
+ make_datafusion_like ( true , f, input_schema, extensions) . await
872
+ }
930
873
}
931
- } ,
874
+ }
932
875
Some ( RexType :: Literal ( lit) ) => {
933
876
let scalar_value = from_substrait_literal ( lit) ?;
934
877
Ok ( Arc :: new ( Expr :: Literal ( scalar_value) ) )
@@ -1342,3 +1285,67 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
1342
1285
) )
1343
1286
}
1344
1287
}
1288
+
1289
+ async fn make_datafusion_like (
1290
+ case_insensitive : bool ,
1291
+ f : & ScalarFunction ,
1292
+ input_schema : & DFSchema ,
1293
+ extensions : & HashMap < u32 , & String > ,
1294
+ ) -> Result < Arc < Expr > > {
1295
+ let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" } ;
1296
+ if f. arguments . len ( ) != 3 {
1297
+ return Err ( DataFusionError :: NotImplemented ( format ! (
1298
+ "Expect three arguments for `{fn_name}` expr"
1299
+ ) ) ) ;
1300
+ }
1301
+
1302
+ let Some ( ArgType :: Value ( expr_substrait) ) = & f. arguments [ 0 ] . arg_type else {
1303
+ return Err ( DataFusionError :: NotImplemented (
1304
+ format ! ( "Invalid arguments type for `{fn_name}` expr" )
1305
+ ) )
1306
+ } ;
1307
+ let expr = from_substrait_rex ( expr_substrait, input_schema, extensions)
1308
+ . await ?
1309
+ . as_ref ( )
1310
+ . clone ( ) ;
1311
+ let Some ( ArgType :: Value ( pattern_substrait) ) = & f. arguments [ 1 ] . arg_type else {
1312
+ return Err ( DataFusionError :: NotImplemented (
1313
+ format ! ( "Invalid arguments type for `{fn_name}` expr" )
1314
+ ) )
1315
+ } ;
1316
+ let pattern = from_substrait_rex ( pattern_substrait, input_schema, extensions)
1317
+ . await ?
1318
+ . as_ref ( )
1319
+ . clone ( ) ;
1320
+ let Some ( ArgType :: Value ( escape_char_substrait) ) = & f. arguments [ 2 ] . arg_type else {
1321
+ return Err ( DataFusionError :: NotImplemented (
1322
+ format ! ( "Invalid arguments type for `{fn_name}` expr" )
1323
+ ) )
1324
+ } ;
1325
+ let escape_char_expr =
1326
+ from_substrait_rex ( escape_char_substrait, input_schema, extensions)
1327
+ . await ?
1328
+ . as_ref ( )
1329
+ . clone ( ) ;
1330
+ let Expr :: Literal ( ScalarValue :: Utf8 ( escape_char) ) = escape_char_expr else {
1331
+ return Err ( DataFusionError :: Substrait ( format ! (
1332
+ "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" ,
1333
+ ) ) )
1334
+ } ;
1335
+
1336
+ if case_insensitive {
1337
+ Ok ( Arc :: new ( Expr :: ILike ( Like {
1338
+ negated : false ,
1339
+ expr : Box :: new ( expr) ,
1340
+ pattern : Box :: new ( pattern) ,
1341
+ escape_char : escape_char. map ( |c| c. chars ( ) . next ( ) . unwrap ( ) ) ,
1342
+ } ) ) )
1343
+ } else {
1344
+ Ok ( Arc :: new ( Expr :: Like ( Like {
1345
+ negated : false ,
1346
+ expr : Box :: new ( expr) ,
1347
+ pattern : Box :: new ( pattern) ,
1348
+ escape_char : escape_char. map ( |c| c. chars ( ) . next ( ) . unwrap ( ) ) ,
1349
+ } ) ) )
1350
+ }
1351
+ }
0 commit comments