Skip to content

[SPARK-51051][SQL] Add an optional parameter to the array_position function to specify the starting index for matching #51317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,9 +1782,8 @@ def cardinality(col: "ColumnOrName") -> Column:
cardinality.__doc__ = pysparkfuncs.cardinality.__doc__


def array_position(col: "ColumnOrName", value: Any) -> Column:
return _invoke_function("array_position", _to_col(col), lit(value))

def array_position(
    col: "ColumnOrName", value: Any, start: Union[Column, int, None] = None
) -> Column:
    # Keep the signature in step with the classic (non-Connect) API: `start` is
    # optional, and when it is omitted the two-argument form is sent so the
    # server-side expression applies its own default start index.
    if start is None:
        return _invoke_function("array_position", _to_col(col), lit(value))
    return _invoke_function("array_position", _to_col(col), lit(value), lit(start))

array_position.__doc__ = pysparkfuncs.array_position.__doc__

Expand Down
25 changes: 22 additions & 3 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18453,7 +18453,7 @@ def concat(*cols: "ColumnOrName") -> Column:


@_try_remote_functions
def array_position(col: "ColumnOrName", value: Any) -> Column:
def array_position(col: "ColumnOrName", value: Any, start: Union[Column, int] = None) -> Column:
"""
Array function: Locates the position of the first occurrence of the given value
in the given array. Returns null if either of the arguments are null.
Expand All @@ -18463,6 +18463,9 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
.. versionchanged:: 3.4.0
Supports Spark Connect.

.. versionchanged:: 4.1.0
Supports start index.

Notes
-----
The position is not zero based, but 1 based index. Returns 0 if the given
Expand All @@ -18477,6 +18480,8 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:

.. versionchanged:: 4.0.0
`value` now also accepts a Column type.
start : :class:`~pyspark.sql.Column` or int, optional
the 1-based position in the array at which the search starts (inclusive).
When omitted, the search starts at the first element.

Returns
-------
Expand Down Expand Up @@ -18545,16 +18550,30 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
Example 6: Finding the position of a column's value in an array of integers

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
>>> df = spark.createDataFrame([([10, 20, 20], 20)], ['data', 'col'])
>>> df.select(sf.array_position(df.data, df.col)).show()
+-------------------------+
|array_position(data, col)|
+-------------------------+
| 2|
+-------------------------+

Example 7: Finding the position of a column's value in an array of integers starting from the index 3

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([([10, 20, 20], 20, 3)], ['data', 'col', 'start'])
>>> df.select(sf.array_position(df.data, df.col, df.start)).show()
+--------------------------------+
|array_position(data, col, start)|
+--------------------------------+
|                               3|
+--------------------------------+

"""
return _invoke_function_over_columns("array_position", col, lit(value))
if start is None:
return _invoke_function_over_columns("array_position", col, lit(value))
else:
return _invoke_function_over_columns("array_position", col, lit(value), lit(start))


@_try_remote_functions
Expand Down
14 changes: 14 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6343,6 +6343,20 @@ object functions {
def array_position(column: Column, value: Any): Column =
Column.fn("array_position", column, lit(value))

/**
 * Locates the position of the first occurrence of the value in the given array, searching from
 * the given start position onwards (the start position itself is included in the search).
 * Returns null if any of the arguments is null.
 *
 * @param column
 *   the array column to search.
 * @param value
 *   the value to look for; wrapped with `lit`.
 * @param pos
 *   the 1-based position at which the search starts; wrapped with `lit`.
 *
 * @note
 *   The position is not zero based, but 1 based index. Returns 0 if value could not be found in
 *   array.
 *
 * @group array_funcs
 * @since 4.1.0
 */
def array_position(column: Column, value: Any, pos: Any): Column =
Column.fn("array_position", column, lit(value), lit(pos))

/**
* Returns element of array at given index in value if column is array. Returns value for the
* given key in value if column is map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2433,17 +2433,18 @@ case class ArrayMax(child: Expression)


/**
* Returns the position of the first occurrence of element in the given array as long.
* Returns 0 if the given value could not be found in the array. Returns null if either of
* the arguments are null
* Returns the position of the first occurrence of element in the given array after the
* given start index as long.
* Returns 0 if the given value could not be found in the array after the given start
* index. Returns null if either of the arguments are null
*
* NOTE: that this is not zero based, but 1-based index. The first element in the array has
* index 1.
* NOTE: that this is not zero based, but 1-based index. The first element in the array
* has index 1.
*/
@ExpressionDescription(
usage = """
_FUNC_(array, element) - Returns the (1-based) index of the first matching element of
the array as long, or 0 if no match is found.
_FUNC_(array, element[, startExpr]) - Returns the (1-based) index of the first matching element
of the array after the start index as long, or 0 if no match is found.
""",
examples = """
Examples:
Expand All @@ -2454,59 +2455,76 @@ case class ArrayMax(child: Expression)
""",
group = "array_funcs",
since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes with QueryErrorsBase {
case class ArrayPosition(array: Expression, element: Expression, startExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with QueryErrorsBase {
override def nullIntolerant: Boolean = true

def this(array: Expression, element: Expression) = {
this(array, element, Literal(1))
}

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
TypeUtils.getInterpretedOrdering(element.dataType)

override def dataType: DataType = LongType

override def inputTypes: Seq[AbstractDataType] = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, hasNull), e2) =>
(array.dataType, element.dataType, startExpr.dataType) match {
case (ArrayType(e1, hasNull), e2, e3: IntegralType) if (e3 != LongType) =>
TypeCoercion.findTightestCommonType(e1, e2) match {
case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
case Some(dt) => Seq(ArrayType(dt, hasNull), dt, IntegerType)
case _ => Seq.empty
}
case _ => Seq.empty
}
}

/**
 * Validates the three operands of `array_position(array, element, start)`.
 *
 * Checks, in order:
 *  - neither the array nor the element may be of `NullType`;
 *  - the start expression must be of `IntegerType`;
 *  - the first argument must be an array;
 *  - the array's element type must match the searched element's type and be orderable.
 *
 * Anything else is reported as an array/element type mismatch.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  (array.dataType, element.dataType, startExpr.dataType) match {
    case (NullType, _, _) | (_, NullType, _) =>
      DataTypeMismatch(
        errorSubClass = "NULL_TYPE",
        Map("functionName" -> toSQLId(prettyName)))
    case (_, _, t) if t != IntegerType =>
      // paramIndex is zero-based here (the array check below uses 0), so 2 is the start argument.
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(2),
          "requiredType" -> toSQLType(IntegerType),
          "inputSql" -> toSQLExpr(startExpr),
          "inputType" -> toSQLType(startExpr.dataType))
      )
    case (t, _, _) if !ArrayType.acceptsType(t) =>
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(0),
          "requiredType" -> toSQLType(ArrayType),
          "inputSql" -> toSQLExpr(array),
          "inputType" -> toSQLType(array.dataType))
      )
    case (ArrayType(e1, _), e2, _) if DataTypeUtils.sameType(e1, e2) =>
      TypeUtils.checkForOrderingExpr(e2, prettyName)
    case _ =>
      // The message-parameter keys belong to the ARRAY_FUNCTION_DIFF_TYPES error definition,
      // not to this class's field names, so they stay "leftType"/"rightType" as they were
      // before the fields were renamed to `array`/`element`.
      DataTypeMismatch(
        errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> toSQLType(ArrayType),
          "leftType" -> toSQLType(array.dataType),
          "rightType" -> toSQLType(element.dataType)
        )
      )
  }
}

override def nullSafeEval(arr: Any, value: Any): Any = {
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v != null && ordering.equiv(v, value)) {
override def first: Expression = array
override def second: Expression = element
override def third: Expression = startExpr

override def nullSafeEval(arr: Any, elem: Any, start: Any): Any = {
arr.asInstanceOf[ArrayData].foreach(element.dataType, (i, v) =>
if (i + 1 >= start.asInstanceOf[Int] && v != null && ordering.equiv(v, elem)) {
return (i + 1).toLong
}
)
Expand All @@ -2516,14 +2534,14 @@ case class ArrayPosition(left: Expression, right: Expression)
override def prettyName: String = "array_position"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (arr, value) => {
nullSafeCodeGen(ctx, ev, (arr, elem, start) => {
val pos = ctx.freshName("arrayPosition")
val i = ctx.freshName("i")
val getValue = CodeGenerator.getValue(arr, right.dataType, i)
val getValue = CodeGenerator.getValue(arr, element.dataType, i)
s"""
|int $pos = 0;
|for (int $i = 0; $i < $arr.numElements(); $i ++) {
| if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
|for (int $i = $start-1; $i < $arr.numElements(); $i ++) {
| if (!$arr.isNullAt($i) && ${ctx.genEqual(element.dataType, elem, getValue)}) {
| $pos = $i + 1;
| break;
| }
Expand All @@ -2534,8 +2552,8 @@ case class ArrayPosition(left: Expression, right: Expression)
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): ArrayPosition =
copy(left = newLeft, right = newRight)
newFirst: Expression, newSecond: Expression, newThird: Expression): ArrayPosition =
copy(array = newFirst, element = newSecond, startExpr = newThird)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1862,29 +1862,37 @@ class CollectionExpressionsSuite
val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a2 = Literal.create(Seq(null), ArrayType(LongType))
val a3 = Literal.create(null, ArrayType(StringType))
val a4 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))

checkEvaluation(ArrayPosition(a0, Literal(3)), 4L)
checkEvaluation(ArrayPosition(a0, Literal(1)), 1L)
checkEvaluation(ArrayPosition(a0, Literal(0)), 0L)
checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null)
checkEvaluation(new ArrayPosition(a0, Literal(3)), 4L)
checkEvaluation(new ArrayPosition(a0, Literal(1)), 1L)
checkEvaluation(new ArrayPosition(a0, Literal(0)), 0L)
checkEvaluation(new ArrayPosition(a0, Literal.create(null, IntegerType)), null)

checkEvaluation(ArrayPosition(a1, Literal("")), 2L)
checkEvaluation(ArrayPosition(a1, Literal("a")), 0L)
checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null)
checkEvaluation(new ArrayPosition(a1, Literal("")), 2L)
checkEvaluation(new ArrayPosition(a1, Literal("a")), 0L)
checkEvaluation(new ArrayPosition(a1, Literal.create(null, StringType)), null)

checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L)
checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null)
checkEvaluation(new ArrayPosition(a2, Literal(1L)), 0L)
checkEvaluation(new ArrayPosition(a2, Literal.create(null, LongType)), null)

checkEvaluation(ArrayPosition(a3, Literal("")), null)
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
checkEvaluation(new ArrayPosition(a3, Literal("")), null)
checkEvaluation(new ArrayPosition(a3, Literal.create(null, StringType)), null)

checkEvaluation(new ArrayPosition(a4, Literal(""), Literal(1)), null)
checkEvaluation(new ArrayPosition(a4, Literal.create(null, StringType), Literal(1)), null)
checkEvaluation(new ArrayPosition(a4, Literal.create(1, IntegerType),
Literal.create(null, IntegerType)), null)

val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
ArrayType(ArrayType(IntegerType)))
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
checkEvaluation(ArrayPosition(aa0, aae), 1L)
checkEvaluation(ArrayPosition(aa1, aae), 0L)
checkEvaluation(new ArrayPosition(aa0, aae), 1L)
checkEvaluation(new ArrayPosition(aa1, aae), 0L)
checkEvaluation(new ArrayPosition(aa0, aae, Literal(1)), 1L)
checkEvaluation(new ArrayPosition(aa0, aae, Literal(2)), 0L)
}

test("elementAt") {
Expand Down
Loading