From f772baa61f74fb9ff2746f69cb95a1d6aad0c25b Mon Sep 17 00:00:00 2001 From: amory Date: Wed, 11 Dec 2024 16:55:32 +0800 Subject: [PATCH] [fix](function) fixed some nested type func's param type which is not suitable and make result wrong (#44923) nereids funcs signature for some nested type func's param type defined not right which will make result wrong such as ``` mysql> SELECT array_position([1,258],257),array_position([2],258); +-------------------------------+-------------------------------------------+ | array_position([1, 258], 257) | array_position([2], cast(258 as TINYINT)) | +-------------------------------+-------------------------------------------+ | 0 | 1 | +-------------------------------+-------------------------------------------+ 1 row in set (0.14 sec) mysql> select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258); +------------------------------+---------------------------------------------------+ | array_apply([258], '>', 257) | array_apply([1, 2, 3], '>', cast(258 as TINYINT)) | +------------------------------+---------------------------------------------------+ | [258] | [3] | +------------------------------+---------------------------------------------------+ 1 row in set (0.11 sec) mysql> select array_contains([258], 257), array_contains([1,2,3], 258); +----------------------------+-------------------------------------------------+ | array_contains([258], 257) | array_contains([1, 2, 3], cast(258 as TINYINT)) | +----------------------------+-------------------------------------------------+ | 0 | 1 | +----------------------------+-------------------------------------------------+ 1 row in set (0.12 sec) mysql> mysql> select array_pushfront([258], 257), array_pushfront([1,2,3], 258); +-----------------------------+--------------------------------------------------+ | array_pushfront([258], 257) | array_pushfront([1, 2, 3], cast(258 as TINYINT)) | +-----------------------------+--------------------------------------------------+ | [257, 258] | [2, 1, 2, 3] | +-----------------------------+--------------------------------------------------+ 1 row in set (0.12 sec) mysql> select array_pushback([1,258], 257), array_pushback([1,2,3], 258); +-------------------------------+-------------------------------------------------+ | array_pushback([1, 258], 257) | array_pushback([1, 2, 3], cast(258 as TINYINT)) | +-------------------------------+-------------------------------------------------+ | [1, 258, 257] | [1, 2, 3, 2] | +-------------------------------+-------------------------------------------------+ 1 row in set (0.10 sec) mysql> select array_remove([1,258], 257), array_remove([1,2,3], 258); +-----------------------------+-----------------------------------------------+ | array_remove([1, 258], 257) | array_remove([1, 2, 3], cast(258 as TINYINT)) | +-----------------------------+-----------------------------------------------+ | [1, 258] | [1, 3] | +-----------------------------+-----------------------------------------------+ 1 row in set (0.12 sec) mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(1,2), 258); +-----------------------------------------------------+---------------------------------------------------+ | map_contains_key(map(1, 258), cast(257 as TINYINT)) | map_contains_key(map(1, 2), cast(258 as TINYINT)) | +-----------------------------------------------------+---------------------------------------------------+ | 1 | 0 | +-----------------------------------------------------+---------------------------------------------------+ 1 row in set (0.12 sec) mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(1,3), 258); +-----------------------------------------------------+---------------------------------------------------+ | map_contains_key(map(1, 258), cast(257 as TINYINT)) | map_contains_key(map(1, 3), cast(258 as TINYINT)) | +-----------------------------------------------------+---------------------------------------------------+ | 1 | 0 | +-----------------------------------------------------+---------------------------------------------------+ 1 row in set (0.11 sec) mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(3,1), 258); +-----------------------------------------------------+---------------------------------------------------+ | map_contains_key(map(1, 258), cast(257 as TINYINT)) | map_contains_key(map(3, 1), cast(258 as TINYINT)) | +-----------------------------------------------------+---------------------------------------------------+ | 1 | 0 | +-----------------------------------------------------+---------------------------------------------------+ ``` but true result is ``` mysql> SELECT array_position([1,258],257),array_position([2],258); +-------------------------------+---------------------------------------------------+ | array_position([1, 258], 257) | array_position(cast([2] as ARRAY), 258) | +-------------------------------+---------------------------------------------------+ | 0 | 0 | +-------------------------------+---------------------------------------------------+ 1 row in set (0.73 sec) mysql> select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258); +------------------------------+-----------------------------------------------------------+ | array_apply([258], '>', 257) | array_apply(cast([1, 2, 3] as ARRAY), '>', 258) | +------------------------------+-----------------------------------------------------------+ | [258] | [] | +------------------------------+-----------------------------------------------------------+ 1 row in set (0.10 sec) mysql> select array_contains([258], 257), array_contains([1,2,3], 258); +----------------------------+---------------------------------------------------------+ | array_contains([258], 257) | array_contains(cast([1, 2, 3] as ARRAY), 258) | +----------------------------+---------------------------------------------------------+ | 0 | 0 | +----------------------------+---------------------------------------------------------+ 1 row in set (0.09 sec) mysql> select array_pushfront([258], 257), array_pushfront([1,2,3], 258); +-----------------------------+----------------------------------------------------------+ | array_pushfront([258], 257) | array_pushfront(cast([1, 2, 3] as ARRAY), 258) | +-----------------------------+----------------------------------------------------------+ | [257, 258] | [258, 1, 2, 3] | +-----------------------------+----------------------------------------------------------+ 1 row in set (0.11 sec) mysql> select array_remove([1,258], 257), array_remove([1,2,3], 258); +-----------------------------+-------------------------------------------------------+ | array_remove([1, 258], 257) | array_remove(cast([1, 2, 3] as ARRAY), 258) | +-----------------------------+-------------------------------------------------------+ | [1, 258] | [1, 2, 3] | +-----------------------------+-------------------------------------------------------+ 1 row in set (0.09 sec) mysql> select countequal([1,258], 257), array_count_equal([1,2,3], 258); ERROR 1105 (HY000): errCode = 2, detailMessage = Can not found function 'array_count_equal' mysql> select countequal([1,258], 257), countequal([1,2,3], 258); +---------------------------+-----------------------------------------------------+ | countequal([1, 258], 257) | countequal(cast([1, 2, 3] as ARRAY), 258) | +---------------------------+-----------------------------------------------------+ | 0 | 0 | +---------------------------+-----------------------------------------------------+ 1 row in set (0.08 sec) mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(2,1), 258); +--------------------------------------------------------------------+-----------------------------------------------------------------+ | map_contains_key(cast(map(1, 258) as MAP), 257) | map_contains_key(cast(map(2, 1) as MAP), 258) | +--------------------------------------------------------------------+-----------------------------------------------------------------+ | 0 | 0 | +--------------------------------------------------------------------+-----------------------------------------------------------------+ 1 row in set (0.09 sec) mysql> select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258); +-------------------------------------------------------------------+-------------------------------------------------------------------+ | map_contains_value(cast(map(1, 1) as MAP), 257) | map_contains_value(cast(map(1, 2) as MAP), 258) | +-------------------------------------------------------------------+-------------------------------------------------------------------+ | 0 | 0 | +-------------------------------------------------- ``` --- .../functions/scalar/ArrayApply.java | 16 +++++++-- .../functions/scalar/ArrayContains.java | 15 ++++++-- .../functions/scalar/ArrayPosition.java | 16 +++++++-- .../functions/scalar/ArrayPushBack.java | 16 +++++++-- .../functions/scalar/ArrayPushFront.java | 16 +++++++-- .../functions/scalar/ArrayRemove.java | 16 +++++++-- .../functions/scalar/CountEqual.java | 19 ++++++++-- .../functions/scalar/MapContainsKey.java | 17 +++++++-- .../functions/scalar/MapContainsValue.java | 17 +++++++-- .../apache/doris/nereids/types/DataType.java | 35 +++++++++++++++++++ .../scalar_function/Array.out | 27 ++++++++++++++ .../scalar_function/Array.groovy | 17 +++++++++ 12 files changed, 209 insertions(+), 18 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java index 86138b82cac170..fea589a7e14fd7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java @@ -41,11 +41,16 @@ */ public class ArrayApply extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.retArgType(0) .args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT, new FollowToAnyDataType(0))); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0) + .args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT, + new AnyDataType(0))); + /** * constructor */ @@ -93,6 +98,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(2).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayContains.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayContains.java index 494a4a863c528d..4eee7066584da2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayContains.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayContains.java @@ -38,10 +38,14 @@ public class ArrayContains extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.ret(BooleanType.INSTANCE) .args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BooleanType.INSTANCE) + .args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0)) + ); /** * constructor with 2 arguments. @@ -71,6 +75,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPosition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPosition.java index 3a94e84d51159e..b782a560719c5d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPosition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPosition.java @@ -38,11 +38,16 @@ public class ArrayPosition extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE) .args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE) + .args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0)) + ); + /** * constructor with 2 arguments. */ @@ -71,6 +76,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushBack.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushBack.java index 932c2446132903..a96c48c4b274a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushBack.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushBack.java @@ -38,11 +38,16 @@ public class ArrayPushBack extends ScalarFunction implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.retArgType(0) .args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0) + .args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0)) + ); + /** * constructor with 1 argument. */ @@ -66,6 +71,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushFront.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushFront.java index 26e1cdd91e3d21..e20fea5cc20543 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushFront.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayPushFront.java @@ -38,11 +38,16 @@ public class ArrayPushFront extends ScalarFunction implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.retArgType(0) .args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0) + .args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0)) + ); + /** * constructor with 1 argument. */ @@ -66,6 +71,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayRemove.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayRemove.java index acf731b8bb46c6..f5996776867dec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayRemove.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayRemove.java @@ -38,11 +38,16 @@ public class ArrayRemove extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.retArgType(0).args( ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0).args( + ArrayType.of(new AnyDataType(0)), new AnyDataType(0)) + ); + /** * constructor with 2 arguments. */ @@ -66,6 +71,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CountEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CountEqual.java index 2ba0c9072435a7..bb76aecde31f07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CountEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CountEqual.java @@ -38,11 +38,16 @@ public class CountEqual extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE) .args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE) + .args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0)) + ); + /** * constructor with 2 arguments. */ @@ -71,6 +76,16 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + // to find out element type in array vs param type, + // if they are different, return first array element type, + // else return least common type between element type and param + if (getArgument(0).getDataType().isArrayType() + && + ((ArrayType) getArgument(0).getDataType()).getItemType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsKey.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsKey.java index b489bd47963497..798101e1a2bde5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsKey.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsKey.java @@ -38,12 +38,18 @@ public class MapContainsKey extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.ret(BooleanType.INSTANCE) .args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BooleanType.INSTANCE) + .args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX), + new AnyDataType(0)) + ); + /** * constructor with 2 arguments. */ @@ -72,6 +78,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isMapType() + && + ((MapType) getArgument(0).getDataType()).getKeyType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsValue.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsValue.java index de77386c889902..0a8bcdc49c5e71 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsValue.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapContainsValue.java @@ -38,12 +38,18 @@ public class MapContainsValue extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature { - public static final List SIGNATURES = ImmutableList.of( + public static final List FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of( FunctionSignature.ret(BooleanType.INSTANCE) .args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)), new FollowToAnyDataType(0)) ); + public static final List MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BooleanType.INSTANCE) + .args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)), + new AnyDataType(0)) + ); + /** * constructor with 2 arguments. */ @@ -72,6 +78,13 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public List getSignatures() { - return SIGNATURES; + if (getArgument(0).getDataType().isMapType() + && + ((MapType) getArgument(0).getDataType()).getValueType() + .isSameTypeForComplexTypeParam(getArgument(1).getDataType())) { + // return least common type + return MIN_COMMON_TYPE_SIGNATURES; + } + return FOLLOW_DATATYPE_SIGNATURE; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index 1d5f69c5366e0d..9b77017f6de8cb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -644,6 +644,41 @@ public DataType promotion() { } } + /** + * whether the param dataType is same-like type for nested in complex type + * same-like type means: string-like, date-like, number type + */ + public boolean isSameTypeForComplexTypeParam(DataType paramType) { + if (this.isArrayType() && paramType.isArrayType()) { + return ((ArrayType) this).getItemType() + .isSameTypeForComplexTypeParam(((ArrayType) paramType).getItemType()); + } else if (this.isMapType() && paramType.isMapType()) { + MapType thisMapType = (MapType) this; + MapType otherMapType = (MapType) paramType; + return thisMapType.getKeyType().isSameTypeForComplexTypeParam(otherMapType.getKeyType()) + && thisMapType.getValueType().isSameTypeForComplexTypeParam(otherMapType.getValueType()); + } else if (this.isStructType() && paramType.isStructType()) { + StructType thisStructType = (StructType) this; + StructType otherStructType = (StructType) paramType; + if (thisStructType.getFields().size() != otherStructType.getFields().size()) { + return false; + } + for (int i = 0; i < thisStructType.getFields().size(); i++) { + if (!thisStructType.getFields().get(i).getDataType().isSameTypeForComplexTypeParam( + otherStructType.getFields().get(i).getDataType())) { + return false; + } + } + return true; + } else if (this.isStringLikeType() && paramType.isStringLikeType()) { + return true; + } else if (this.isDateLikeType() && paramType.isDateLikeType()) { + return true; + } else { + return this.isNumericType() && paramType.isNumericType(); + } + } + /** getAllPromotions */ public List getAllPromotions() { if (this instanceof ArrayType) { diff --git a/regression-test/data/nereids_function_p0/scalar_function/Array.out b/regression-test/data/nereids_function_p0/scalar_function/Array.out index 9570ac178d7ab7..b2e11fab37e47a 100644 --- a/regression-test/data/nereids_function_p0/scalar_function/Array.out +++ b/regression-test/data/nereids_function_p0/scalar_function/Array.out @@ -16896,3 +16896,30 @@ false -- !sql_array_match_all_18 -- true +-- !sql -- +0 0 + +-- !sql -- +[258] [] + +-- !sql -- +false false + +-- !sql -- +[257, 258] [258, 1, 2, 3] + +-- !sql -- +[1, 258, 257] [1, 2, 3, 258] + +-- !sql -- +[1, 258] [1, 2, 3] + +-- !sql -- +0 0 + +-- !sql -- +false false + +-- !sql -- +false false + diff --git a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy index defa553279c84e..7178ef4436e94b 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy @@ -1407,4 +1407,21 @@ suite("nereids_scalar_fn_Array") { qt_sql_array_match_any_18 "select array_match_any(x->x=2, array())" qt_sql_array_match_all_18 "select array_match_all(x->x=2, array())" + + // tests for nereids array functions for number overflow cases + qt_sql """ SELECT array_position([1,258],257),array_position([2],258);""" + qt_sql """ select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258);""" + qt_sql """ select array_contains([258], 257), array_contains([1,2,3], 258);""" + // pushfront and pushback + qt_sql """ select array_pushfront([258], 257), array_pushfront([1,2,3], 258);""" + qt_sql """ select array_pushback([1,258], 257), array_pushback([1,2,3], 258);""" + // array_remove + qt_sql """ select array_remove([1,258], 257), array_remove([1,2,3], 258);""" + // countequal + qt_sql """ select countequal([1,258], 257), countequal([1,2,3], 258);""" + // map_contains_key + qt_sql """ select map_contains_key(map(1,258), 257), map_contains_key(map(2,1), 258);""" + // map_contains_value + qt_sql """ select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);""" + }