Skip to content

Commit

Permalink
[fix](function) fixed some nested type func's param type which is not…
Browse files Browse the repository at this point in the history
… suitable and make result wrong (apache#44923)

nereids funcs signature for some nested type func's param type defined
not right which will make result wrong such as
```
mysql> SELECT array_position([1,258],257),array_position([2],258);
+-------------------------------+-------------------------------------------+
| array_position([1, 258], 257) | array_position([2], cast(258 as TINYINT)) |
+-------------------------------+-------------------------------------------+
|                             0 |                                         1 |
+-------------------------------+-------------------------------------------+
1 row in set (0.14 sec)

mysql> select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258);
+------------------------------+---------------------------------------------------+
| array_apply([258], '>', 257) | array_apply([1, 2, 3], '>', cast(258 as TINYINT)) |
+------------------------------+---------------------------------------------------+
| [258]                        | [3]                                               |
+------------------------------+---------------------------------------------------+
1 row in set (0.11 sec)

mysql>  select array_contains([258], 257), array_contains([1,2,3], 258);
+----------------------------+-------------------------------------------------+
| array_contains([258], 257) | array_contains([1, 2, 3], cast(258 as TINYINT)) |
+----------------------------+-------------------------------------------------+
|                          0 |                                               1 |
+----------------------------+-------------------------------------------------+
1 row in set (0.12 sec)

mysql>
mysql> select array_pushfront([258], 257), array_pushfront([1,2,3], 258);
+-----------------------------+--------------------------------------------------+
| array_pushfront([258], 257) | array_pushfront([1, 2, 3], cast(258 as TINYINT)) |
+-----------------------------+--------------------------------------------------+
| [257, 258]                  | [2, 1, 2, 3]                                     |
+-----------------------------+--------------------------------------------------+
1 row in set (0.12 sec)

mysql> select array_pushback([1,258], 257), array_pushback([1,2,3], 258);
+-------------------------------+-------------------------------------------------+
| array_pushback([1, 258], 257) | array_pushback([1, 2, 3], cast(258 as TINYINT)) |
+-------------------------------+-------------------------------------------------+
| [1, 258, 257]                 | [1, 2, 3, 2]                                    |
+-------------------------------+-------------------------------------------------+
1 row in set (0.10 sec)

mysql> select array_remove([1,258], 257), array_remove([1,2,3], 258);
+-----------------------------+-----------------------------------------------+
| array_remove([1, 258], 257) | array_remove([1, 2, 3], cast(258 as TINYINT)) |
+-----------------------------+-----------------------------------------------+
| [1, 258]                    | [1, 3]                                        |
+-----------------------------+-----------------------------------------------+
1 row in set (0.12 sec)
mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(1,2), 258);
+-----------------------------------------------------+---------------------------------------------------+
| map_contains_key(map(1, 258), cast(257 as TINYINT)) | map_contains_key(map(1, 2), cast(258 as TINYINT)) |
+-----------------------------------------------------+---------------------------------------------------+
|                                                   1 |                                                 0 |
+-----------------------------------------------------+---------------------------------------------------+
1 row in set (0.12 sec)

mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(1,3), 258);
+-----------------------------------------------------+---------------------------------------------------+
| map_contains_key(map(1, 258), cast(257 as TINYINT)) | map_contains_key(map(1, 3), cast(258 as TINYINT)) |
+-----------------------------------------------------+---------------------------------------------------+
|                                                   1 |                                                 0 |
+-----------------------------------------------------+---------------------------------------------------+
1 row in set (0.11 sec)

mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(3,1), 258);
+-----------------------------------------------------+---------------------------------------------------+
| map_contains_key(map(1, 258), cast(257 as TINYINT)) | map_contains_key(map(3, 1), cast(258 as TINYINT)) |
+-----------------------------------------------------+---------------------------------------------------+
|                                                   1 |                                                 0 |
+-----------------------------------------------------+---------------------------------------------------+
```
but true result is 
```
mysql> SELECT array_position([1,258],257),array_position([2],258);
+-------------------------------+---------------------------------------------------+
| array_position([1, 258], 257) | array_position(cast([2] as ARRAY<SMALLINT>), 258) |
+-------------------------------+---------------------------------------------------+
|                             0 |                                                 0 |
+-------------------------------+---------------------------------------------------+
1 row in set (0.73 sec)

mysql> select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258);
+------------------------------+-----------------------------------------------------------+
| array_apply([258], '>', 257) | array_apply(cast([1, 2, 3] as ARRAY<SMALLINT>), '>', 258) |
+------------------------------+-----------------------------------------------------------+
| [258]                        | []                                                        |
+------------------------------+-----------------------------------------------------------+
1 row in set (0.10 sec)

mysql> select array_contains([258], 257), array_contains([1,2,3], 258);
+----------------------------+---------------------------------------------------------+
| array_contains([258], 257) | array_contains(cast([1, 2, 3] as ARRAY<SMALLINT>), 258) |
+----------------------------+---------------------------------------------------------+
|                          0 |                                                       0 |
+----------------------------+---------------------------------------------------------+
1 row in set (0.09 sec)

mysql> select array_pushfront([258], 257), array_pushfront([1,2,3], 258);
+-----------------------------+----------------------------------------------------------+
| array_pushfront([258], 257) | array_pushfront(cast([1, 2, 3] as ARRAY<SMALLINT>), 258) |
+-----------------------------+----------------------------------------------------------+
| [257, 258]                  | [258, 1, 2, 3]                                           |
+-----------------------------+----------------------------------------------------------+
1 row in set (0.11 sec)

mysql> select array_remove([1,258], 257), array_remove([1,2,3], 258);
+-----------------------------+-------------------------------------------------------+
| array_remove([1, 258], 257) | array_remove(cast([1, 2, 3] as ARRAY<SMALLINT>), 258) |
+-----------------------------+-------------------------------------------------------+
| [1, 258]                    | [1, 2, 3]                                             |
+-----------------------------+-------------------------------------------------------+
1 row in set (0.09 sec)

mysql> select countequal([1,258], 257), array_count_equal([1,2,3], 258);
ERROR 1105 (HY000): errCode = 2, detailMessage = Can not found function 'array_count_equal'
mysql> select countequal([1,258], 257), countequal([1,2,3], 258);
+---------------------------+-----------------------------------------------------+
| countequal([1, 258], 257) | countequal(cast([1, 2, 3] as ARRAY<SMALLINT>), 258) |
+---------------------------+-----------------------------------------------------+
|                         0 |                                                   0 |
+---------------------------+-----------------------------------------------------+
1 row in set (0.08 sec)

mysql> select map_contains_key(map(1,258), 257), map_contains_key(map(2,1), 258);
+--------------------------------------------------------------------+-----------------------------------------------------------------+
| map_contains_key(cast(map(1, 258) as MAP<SMALLINT,SMALLINT>), 257) | map_contains_key(cast(map(2, 1) as MAP<SMALLINT,TINYINT>), 258) |
+--------------------------------------------------------------------+-----------------------------------------------------------------+
|                                                                  0 |                                                               0 |
+--------------------------------------------------------------------+-----------------------------------------------------------------+
1 row in set (0.09 sec)

mysql> select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);
+-------------------------------------------------------------------+-------------------------------------------------------------------+
| map_contains_value(cast(map(1, 1) as MAP<TINYINT,SMALLINT>), 257) | map_contains_value(cast(map(1, 2) as MAP<TINYINT,SMALLINT>), 258) |
+-------------------------------------------------------------------+-------------------------------------------------------------------+
|                                                                 0 |                                                                 0 |
+--------------------------------------------------
```
  • Loading branch information
amorynan authored Dec 11, 2024
1 parent 502542d commit f772baa
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@
*/
public class ArrayApply extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT,
new FollowToAnyDataType(0)));

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT,
new AnyDataType(0)));

/**
* constructor
*/
Expand Down Expand Up @@ -93,6 +98,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(2).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@
public class ArrayContains extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
Expand Down Expand Up @@ -71,6 +75,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayPosition extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -71,6 +76,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayPushBack extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 1 argument.
*/
Expand All @@ -66,6 +71,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayPushFront extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 1 argument.
*/
Expand All @@ -66,6 +71,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayRemove extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0).args(
ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(
ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand All @@ -66,6 +71,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class CountEqual extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -71,6 +76,16 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
// to find out element type in array vs param type,
// if they are different, return first array element type,
// else return least common type between element type and param
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@
public class MapContainsKey extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX),
new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX),
new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -72,6 +78,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isMapType()
&&
((MapType) getArgument(0).getDataType()).getKeyType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@
public class MapContainsValue extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)),
new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)),
new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -72,6 +78,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isMapType()
&&
((MapType) getArgument(0).getDataType()).getValueType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,41 @@ public DataType promotion() {
}
}

/**
* whether the param dataType is same-like type for nested in complex type
* same-like type means: string-like, date-like, number type
*/
public boolean isSameTypeForComplexTypeParam(DataType paramType) {
if (this.isArrayType() && paramType.isArrayType()) {
return ((ArrayType) this).getItemType()
.isSameTypeForComplexTypeParam(((ArrayType) paramType).getItemType());
} else if (this.isMapType() && paramType.isMapType()) {
MapType thisMapType = (MapType) this;
MapType otherMapType = (MapType) paramType;
return thisMapType.getKeyType().isSameTypeForComplexTypeParam(otherMapType.getKeyType())
&& thisMapType.getValueType().isSameTypeForComplexTypeParam(otherMapType.getValueType());
} else if (this.isStructType() && paramType.isStructType()) {
StructType thisStructType = (StructType) this;
StructType otherStructType = (StructType) paramType;
if (thisStructType.getFields().size() != otherStructType.getFields().size()) {
return false;
}
for (int i = 0; i < thisStructType.getFields().size(); i++) {
if (!thisStructType.getFields().get(i).getDataType().isSameTypeForComplexTypeParam(
otherStructType.getFields().get(i).getDataType())) {
return false;
}
}
return true;
} else if (this.isStringLikeType() && paramType.isStringLikeType()) {
return true;
} else if (this.isDateLikeType() && paramType.isDateLikeType()) {
return true;
} else {
return this.isNumericType() && paramType.isNumericType();
}
}

/** getAllPromotions */
public List<DataType> getAllPromotions() {
if (this instanceof ArrayType) {
Expand Down
27 changes: 27 additions & 0 deletions regression-test/data/nereids_function_p0/scalar_function/Array.out
Original file line number Diff line number Diff line change
Expand Up @@ -16896,3 +16896,30 @@ false
-- !sql_array_match_all_18 --
true

-- !sql --
0 0

-- !sql --
[258] []

-- !sql --
false false

-- !sql --
[257, 258] [258, 1, 2, 3]

-- !sql --
[1, 258, 257] [1, 2, 3, 258]

-- !sql --
[1, 258] [1, 2, 3]

-- !sql --
0 0

-- !sql --
false false

-- !sql --
false false

Original file line number Diff line number Diff line change
Expand Up @@ -1407,4 +1407,21 @@ suite("nereids_scalar_fn_Array") {
qt_sql_array_match_any_18 "select array_match_any(x->x=2, array())"
qt_sql_array_match_all_18 "select array_match_all(x->x=2, array())"


// tests for nereids array functions for number overflow cases
qt_sql """ SELECT array_position([1,258],257),array_position([2],258);"""
qt_sql """ select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258);"""
qt_sql """ select array_contains([258], 257), array_contains([1,2,3], 258);"""
// pushfront and pushback
qt_sql """ select array_pushfront([258], 257), array_pushfront([1,2,3], 258);"""
qt_sql """ select array_pushback([1,258], 257), array_pushback([1,2,3], 258);"""
// array_remove
qt_sql """ select array_remove([1,258], 257), array_remove([1,2,3], 258);"""
// countequal
qt_sql """ select countequal([1,258], 257), countequal([1,2,3], 258);"""
// map_contains_key
qt_sql """ select map_contains_key(map(1,258), 257), map_contains_key(map(2,1), 258);"""
// map_contains_value
qt_sql """ select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);"""

}

0 comments on commit f772baa

Please sign in to comment.