Skip to content

Commit a00922b

Browse files
authored
Implement native support StringView for substr_index (#11974)
Signed-off-by: Chojan Shang <[email protected]>
1 parent 94034be commit a00922b

File tree

3 files changed

+144
-20
lines changed

3 files changed

+144
-20
lines changed

datafusion/functions/src/unicode/substrindex.rs

+63-20
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder};
22-
use arrow::datatypes::DataType;
21+
use arrow::array::{
22+
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23+
PrimitiveArray, StringBuilder,
24+
};
25+
use arrow::datatypes::{DataType, Int32Type, Int64Type};
2326

24-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
2527
use datafusion_common::{exec_err, Result};
2628
use datafusion_expr::TypeSignature::Exact;
2729
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -46,6 +48,7 @@ impl SubstrIndexFunc {
4648
Self {
4749
signature: Signature::one_of(
4850
vec![
51+
Exact(vec![Utf8View, Utf8View, Int64]),
4952
Exact(vec![Utf8, Utf8, Int64]),
5053
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
5154
],
@@ -74,15 +77,7 @@ impl ScalarUDFImpl for SubstrIndexFunc {
7477
}
7578

7679
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
77-
match args[0].data_type() {
78-
DataType::Utf8 => make_scalar_function(substr_index::<i32>, vec![])(args),
79-
DataType::LargeUtf8 => {
80-
make_scalar_function(substr_index::<i64>, vec![])(args)
81-
}
82-
other => {
83-
exec_err!("Unsupported data type {other:?} for function substr_index")
84-
}
85-
}
80+
make_scalar_function(substr_index, vec![])(args)
8681
}
8782

8883
fn aliases(&self) -> &[String] {
@@ -95,23 +90,71 @@ impl ScalarUDFImpl for SubstrIndexFunc {
9590
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
9691
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
9792
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
98-
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
93+
fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
9994
if args.len() != 3 {
10095
return exec_err!(
10196
"substr_index was called with {} arguments. It requires 3.",
10297
args.len()
10398
);
10499
}
105100

106-
let string_array = as_generic_string_array::<T>(&args[0])?;
107-
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
108-
let count_array = as_int64_array(&args[2])?;
101+
match args[0].data_type() {
102+
DataType::Utf8 => {
103+
let string_array = args[0].as_string::<i32>();
104+
let delimiter_array = args[1].as_string::<i32>();
105+
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
106+
substr_index_general::<Int32Type, _, _>(
107+
string_array,
108+
delimiter_array,
109+
count_array,
110+
)
111+
}
112+
DataType::LargeUtf8 => {
113+
let string_array = args[0].as_string::<i64>();
114+
let delimiter_array = args[1].as_string::<i64>();
115+
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
116+
substr_index_general::<Int64Type, _, _>(
117+
string_array,
118+
delimiter_array,
119+
count_array,
120+
)
121+
}
122+
DataType::Utf8View => {
123+
let string_array = args[0].as_string_view();
124+
let delimiter_array = args[1].as_string_view();
125+
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
126+
substr_index_general::<Int32Type, _, _>(
127+
string_array,
128+
delimiter_array,
129+
count_array,
130+
)
131+
}
132+
other => {
133+
exec_err!("Unsupported data type {other:?} for function substr_index")
134+
}
135+
}
136+
}
109137

138+
pub fn substr_index_general<
139+
'a,
140+
T: ArrowPrimitiveType,
141+
V: ArrayAccessor<Item = &'a str>,
142+
P: ArrayAccessor<Item = i64>,
143+
>(
144+
string_array: V,
145+
delimiter_array: V,
146+
count_array: P,
147+
) -> Result<ArrayRef>
148+
where
149+
T::Native: OffsetSizeTrait,
150+
{
110151
let mut builder = StringBuilder::new();
111-
string_array
112-
.iter()
113-
.zip(delimiter_array.iter())
114-
.zip(count_array.iter())
152+
let string_iter = ArrayIter::new(string_array);
153+
let delimiter_array_iter = ArrayIter::new(delimiter_array);
154+
let count_array_iter = ArrayIter::new(count_array);
155+
string_iter
156+
.zip(delimiter_array_iter)
157+
.zip(count_array_iter)
115158
.for_each(|((string, delimiter), n)| match (string, delimiter, n) {
116159
(Some(string), Some(delimiter), Some(n)) => {
117160
// In MySQL, these cases will return an empty string.

datafusion/sqllogictest/test_files/functions.slt

+59
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,65 @@ arrow.apache.org 100 arrow.apache.org
10421042
. 3 .
10431043
. 100 .
10441044

1045+
query I
1046+
SELECT levenshtein(NULL, NULL)
1047+
----
1048+
NULL
1049+
1050+
# Test substring_index using '.' as delimiter with utf8view
1051+
query TIT
1052+
SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM
1053+
(VALUES
1054+
ROW('arrow.apache.org'),
1055+
ROW('.'),
1056+
ROW('...'),
1057+
ROW(NULL)
1058+
) AS strings(str),
1059+
(VALUES
1060+
ROW(1),
1061+
ROW(2),
1062+
ROW(3),
1063+
ROW(100),
1064+
ROW(-1),
1065+
ROW(-2),
1066+
ROW(-3),
1067+
ROW(-100)
1068+
) AS occurrences(n)
1069+
ORDER BY str DESC, n;
1070+
----
1071+
NULL -100 NULL
1072+
NULL -3 NULL
1073+
NULL -2 NULL
1074+
NULL -1 NULL
1075+
NULL 1 NULL
1076+
NULL 2 NULL
1077+
NULL 3 NULL
1078+
NULL 100 NULL
1079+
arrow.apache.org -100 arrow.apache.org
1080+
arrow.apache.org -3 arrow.apache.org
1081+
arrow.apache.org -2 apache.org
1082+
arrow.apache.org -1 org
1083+
arrow.apache.org 1 arrow
1084+
arrow.apache.org 2 arrow.apache
1085+
arrow.apache.org 3 arrow.apache.org
1086+
arrow.apache.org 100 arrow.apache.org
1087+
... -100 ...
1088+
... -3 ..
1089+
... -2 .
1090+
... -1 (empty)
1091+
... 1 (empty)
1092+
... 2 .
1093+
... 3 ..
1094+
... 100 ...
1095+
. -100 .
1096+
. -3 .
1097+
. -2 .
1098+
. -1 (empty)
1099+
. 1 (empty)
1100+
. 2 .
1101+
. 3 .
1102+
. 100 .
1103+
10451104
# Test substring_index using 'ac' as delimiter
10461105
query TIT
10471106
SELECT str, n, substring_index(str, 'ac', n) AS c FROM

datafusion/sqllogictest/test_files/string_view.slt

+22
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,28 @@ logical_plan
984984
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1
985985
03)----TableScan: test projection=[column1_utf8view]
986986

987+
## Ensure no casts for SUBSTRINDEX
988+
query TT
989+
EXPLAIN SELECT
990+
SUBSTR_INDEX(column1_utf8view, 'a', 1) as c,
991+
SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2
992+
FROM test;
993+
----
994+
logical_plan
995+
01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2
996+
02)--TableScan: test projection=[column1_utf8view]
997+
998+
query TT
999+
SELECT
1000+
SUBSTR_INDEX(column1_utf8view, 'a', 1) as c,
1001+
SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2
1002+
FROM test;
1003+
----
1004+
Andrew Andrew
1005+
Xi Xiangpeng
1006+
R Raph
1007+
NULL NULL
1008+
9871009
## Ensure no casts on columns for STARTS_WITH
9881010
query TT
9891011
EXPLAIN SELECT

0 commit comments

Comments
 (0)