Add an optional 1-based, inclusive `start` argument to array_position
(SQL expression, Scala API, PySpark, Spark Connect).

Notes on this revision of the patch:
- ARRAY_FUNCTION_DIFF_TYPES keeps its `leftType` / `rightType` message parameters.
- Generated code clamps the start position to 1 so that it agrees with interpreted
  evaluation for start <= 0.
- PySpark (classic and Connect) share one signature: start is Optional and the third
  argument is only sent when the caller supplies it.
- The new Scala overload is tagged @since 4.1.0.
- Doctest "Example 7" and the a4 unit tests actually exercise `start`.

diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 85d78ccac3015..ed7080134c4fb 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -1782,8 +1782,12 @@ def cardinality(col: "ColumnOrName") -> Column:
 cardinality.__doc__ = pysparkfuncs.cardinality.__doc__
 
 
-def array_position(col: "ColumnOrName", value: Any) -> Column:
-    return _invoke_function("array_position", _to_col(col), lit(value))
+def array_position(
+    col: "ColumnOrName", value: Any, start: Optional[Union[Column, int]] = None
+) -> Column:
+    if start is None:
+        return _invoke_function("array_position", _to_col(col), lit(value))
+    return _invoke_function("array_position", _to_col(col), lit(value), lit(start))
 
 
 array_position.__doc__ = pysparkfuncs.array_position.__doc__
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 37b65c3203da8..43a1f2edb6b80 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -18453,7 +18453,9 @@ def concat(*cols: "ColumnOrName") -> Column:
 
 
 @_try_remote_functions
-def array_position(col: "ColumnOrName", value: Any) -> Column:
+def array_position(
+    col: "ColumnOrName", value: Any, start: Optional[Union[Column, int]] = None
+) -> Column:
     """
     Array function: Locates the position of the first occurrence of the given value
     in the given array. Returns null if either of the arguments are null.
@@ -18463,6 +18465,9 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     .. versionchanged:: 3.4.0
         Supports Spark Connect.
 
+    .. versionchanged:: 4.1.0
+        Supports start index.
+
     Notes
     -----
     The position is not zero based, but 1 based index. Returns 0 if the given
@@ -18477,6 +18482,8 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
 
         .. versionchanged:: 4.0.0
             `value` now also accepts a Column type.
+    start : :class:`~pyspark.sql.Column` or int, optional
+        1-based position to start searching from (inclusive). Defaults to the first element.
 
     Returns
     -------
@@ -18545,7 +18552,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     Example 6: Finding the position of a column's value in an array of integers
 
     >>> from pyspark.sql import functions as sf
-    >>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
+    >>> df = spark.createDataFrame([([10, 20, 20], 20)], ['data', 'col'])
     >>> df.select(sf.array_position(df.data, df.col)).show()
     +-------------------------+
     |array_position(data, col)|
@@ -18553,8 +18560,22 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     |                        2|
     +-------------------------+
 
+    Example 7: Finding the position of a column's value in an array, searching from a start index
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([([10, 20, 20], 20, 3)], ['data', 'col', 'start'])
+    >>> df.select(sf.array_position(df.data, df.col, df.start)).show()
+    +--------------------------------+
+    |array_position(data, col, start)|
+    +--------------------------------+
+    |                               3|
+    +--------------------------------+
+
     """
-    return _invoke_function_over_columns("array_position", col, lit(value))
+    if start is None:
+        return _invoke_function_over_columns("array_position", col, lit(value))
+    else:
+        return _invoke_function_over_columns("array_position", col, lit(value), lit(start))
 
 
 @_try_remote_functions
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index ce5c76807b5c1..b61e8360a6c62 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -6343,6 +6343,20 @@ object functions {
   def array_position(column: Column, value: Any): Column =
     Column.fn("array_position", column, lit(value))
 
+  /**
+   * Locates the position of the first occurrence of the value in the given array, searching from
+   * the given 1-based start position (inclusive). Returns null if any of the arguments are null.
+   *
+   * @note
+   *   The position is not zero based, but 1 based index. Returns 0 if value could not be found in
+   *   array.
+   *
+   * @group array_funcs
+   * @since 4.1.0
+   */
+  def array_position(column: Column, value: Any, pos: Any): Column =
+    Column.fn("array_position", column, lit(value), lit(pos))
+
   /**
    * Returns element of array at given index in value if column is array. Returns value for the
    * given key in value if column is map.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b4978fbe1f70a..89fabf14b9c7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2433,17 +2433,18 @@ case class ArrayMax(child: Expression)
 
 
 /**
- * Returns the position of the first occurrence of element in the given array as long.
- * Returns 0 if the given value could not be found in the array. Returns null if either of
- * the arguments are null
+ * Returns the position of the first occurrence of element in the given array at or after the
+ * given start index as long.
+ * Returns 0 if the given value could not be found in the array at or after the given start
+ * index. Returns null if any of the arguments are null
  *
  * NOTE: that this is not zero based, but 1-based index. The first element in the array has
  * index 1.
  */
 @ExpressionDescription(
   usage = """
-    _FUNC_(array, element) - Returns the (1-based) index of the first matching element of
-      the array as long, or 0 if no match is found.
+    _FUNC_(array, element[, startExpr]) - Returns the (1-based) index of the first matching element
+      of the array at or after the start index as long, or 0 if no match is found.
   """,
   examples = """
     Examples:
@@ -2454,20 +2455,24 @@ case class ArrayMax(child: Expression)
   """,
   group = "array_funcs",
   since = "2.4.0")
-case class ArrayPosition(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with QueryErrorsBase {
+case class ArrayPosition(array: Expression, element: Expression, startExpr: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes with QueryErrorsBase {
   override def nullIntolerant: Boolean = true
 
+  def this(array: Expression, element: Expression) = {
+    this(array, element, Literal(1))
+  }
+
   @transient private lazy val ordering: Ordering[Any] =
-    TypeUtils.getInterpretedOrdering(right.dataType)
+    TypeUtils.getInterpretedOrdering(element.dataType)
 
   override def dataType: DataType = LongType
 
   override def inputTypes: Seq[AbstractDataType] = {
-    (left.dataType, right.dataType) match {
-      case (ArrayType(e1, hasNull), e2) =>
+    (array.dataType, element.dataType, startExpr.dataType) match {
+      case (ArrayType(e1, hasNull), e2, e3: IntegralType) if (e3 != LongType) =>
         TypeCoercion.findTightestCommonType(e1, e2) match {
-          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt, IntegerType)
           case _ => Seq.empty
         }
       case _ => Seq.empty
@@ -2475,21 +2480,30 @@ case class ArrayPosition(left: Expression, right: Expression)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    (left.dataType, right.dataType) match {
-      case (NullType, _) | (_, NullType) =>
+    (array.dataType, element.dataType, startExpr.dataType) match {
+      case (NullType, _, _) | (_, NullType, _) =>
         DataTypeMismatch(
           errorSubClass = "NULL_TYPE",
           Map("functionName" -> toSQLId(prettyName)))
-      case (t, _) if !ArrayType.acceptsType(t) =>
+      case (_, _, t) if t != IntegerType =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(2),
+            "requiredType" -> toSQLType(IntegerType),
+            "inputSql" -> toSQLExpr(startExpr),
+            "inputType" -> toSQLType(startExpr.dataType))
+        )
+      case (t, _, _) if !ArrayType.acceptsType(t) =>
         DataTypeMismatch(
           errorSubClass = "UNEXPECTED_INPUT_TYPE",
           messageParameters = Map(
             "paramIndex" -> ordinalNumber(0),
             "requiredType" -> toSQLType(ArrayType),
-            "inputSql" -> toSQLExpr(left),
-            "inputType" -> toSQLType(left.dataType))
+            "inputSql" -> toSQLExpr(array),
+            "inputType" -> toSQLType(array.dataType))
         )
-      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+      case (ArrayType(e1, _), e2, _) if DataTypeUtils.sameType(e1, e2) =>
         TypeUtils.checkForOrderingExpr(e2, prettyName)
       case _ =>
         DataTypeMismatch(
@@ -2497,16 +2511,20 @@ case class ArrayPosition(left: Expression, right: Expression)
           messageParameters = Map(
             "functionName" -> toSQLId(prettyName),
             "dataType" -> toSQLType(ArrayType),
-            "leftType" -> toSQLType(left.dataType),
-            "rightType" -> toSQLType(right.dataType)
+            "leftType" -> toSQLType(array.dataType),
+            "rightType" -> toSQLType(element.dataType)
           )
         )
     }
   }
 
-  override def nullSafeEval(arr: Any, value: Any): Any = {
-    arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
-      if (v != null && ordering.equiv(v, value)) {
+  override def first: Expression = array
+  override def second: Expression = element
+  override def third: Expression = startExpr
+
+  override def nullSafeEval(arr: Any, elem: Any, start: Any): Any = {
+    arr.asInstanceOf[ArrayData].foreach(element.dataType, (i, v) =>
+      if (i + 1 >= start.asInstanceOf[Int] && v != null && ordering.equiv(v, elem)) {
         return (i + 1).toLong
       }
     )
@@ -2516,14 +2534,14 @@ case class ArrayPosition(left: Expression, right: Expression)
   override def prettyName: String = "array_position"
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (arr, value) => {
+    nullSafeCodeGen(ctx, ev, (arr, elem, start) => {
       val pos = ctx.freshName("arrayPosition")
       val i = ctx.freshName("i")
-      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
+      val getValue = CodeGenerator.getValue(arr, element.dataType, i)
       s"""
          |int $pos = 0;
-         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
-         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
+         |for (int $i = Math.max($start, 1) - 1; $i < $arr.numElements(); $i ++) {
+         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(element.dataType, elem, getValue)}) {
          |    $pos = $i + 1;
          |    break;
          |  }
@@ -2534,8 +2552,8 @@ case class ArrayPosition(left: Expression, right: Expression)
   }
 
   override protected def withNewChildrenInternal(
-      newLeft: Expression, newRight: Expression): ArrayPosition =
-    copy(left = newLeft, right = newRight)
+      newFirst: Expression, newSecond: Expression, newThird: Expression): ArrayPosition =
+    copy(array = newFirst, element = newSecond, startExpr = newThird)
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1907ec7c23aa6..5694a6f32ed69 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1862,29 +1862,39 @@ class CollectionExpressionsSuite
     val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
     val a2 = Literal.create(Seq(null), ArrayType(LongType))
     val a3 = Literal.create(null, ArrayType(StringType))
+    val a4 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
 
-    checkEvaluation(ArrayPosition(a0, Literal(3)), 4L)
-    checkEvaluation(ArrayPosition(a0, Literal(1)), 1L)
-    checkEvaluation(ArrayPosition(a0, Literal(0)), 0L)
-    checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null)
+    checkEvaluation(new ArrayPosition(a0, Literal(3)), 4L)
+    checkEvaluation(new ArrayPosition(a0, Literal(1)), 1L)
+    checkEvaluation(new ArrayPosition(a0, Literal(0)), 0L)
+    checkEvaluation(new ArrayPosition(a0, Literal.create(null, IntegerType)), null)
 
-    checkEvaluation(ArrayPosition(a1, Literal("")), 2L)
-    checkEvaluation(ArrayPosition(a1, Literal("a")), 0L)
-    checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null)
+    checkEvaluation(new ArrayPosition(a1, Literal("")), 2L)
+    checkEvaluation(new ArrayPosition(a1, Literal("a")), 0L)
+    checkEvaluation(new ArrayPosition(a1, Literal.create(null, StringType)), null)
 
-    checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L)
-    checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null)
+    checkEvaluation(new ArrayPosition(a2, Literal(1L)), 0L)
+    checkEvaluation(new ArrayPosition(a2, Literal.create(null, LongType)), null)
 
-    checkEvaluation(ArrayPosition(a3, Literal("")), null)
-    checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+    checkEvaluation(new ArrayPosition(a3, Literal("")), null)
+    checkEvaluation(new ArrayPosition(a3, Literal.create(null, StringType)), null)
+
+    checkEvaluation(new ArrayPosition(a4, Literal(2), Literal(1)), 3L)
+    checkEvaluation(new ArrayPosition(a4, Literal(3), Literal(4)), 4L)
+    checkEvaluation(new ArrayPosition(a4, Literal(1), Literal(2)), 0L)
+    checkEvaluation(new ArrayPosition(a4, Literal(1), Literal(0)), 1L)
+    checkEvaluation(new ArrayPosition(a4, Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(new ArrayPosition(a4, Literal(1), Literal.create(null, IntegerType)), null)
 
     val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
       ArrayType(ArrayType(IntegerType)))
     val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
       ArrayType(ArrayType(IntegerType)))
     val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
-    checkEvaluation(ArrayPosition(aa0, aae), 1L)
-    checkEvaluation(ArrayPosition(aa1, aae), 0L)
+    checkEvaluation(new ArrayPosition(aa0, aae), 1L)
+    checkEvaluation(new ArrayPosition(aa1, aae), 0L)
+    checkEvaluation(new ArrayPosition(aa0, aae, Literal(1)), 1L)
+    checkEvaluation(new ArrayPosition(aa0, aae, Literal(2)), 0L)
   }
 
   test("elementAt") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index fc6d3023ed072..41944e01cb5a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2263,7 +2263,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
   test("array position function") {
     val df = Seq(
-      (Seq[Int](1, 2), "x", 1),
+      (Seq[Int](1, 2, 1), "x", 1),
       (Seq[Int](), "x", 1)
     ).toDF("a", "b", "c")
 
@@ -2271,10 +2271,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       df.select(array_position(df("a"), 1)),
       Seq(Row(1L), Row(0L))
     )
+    checkAnswer(
+      df.select(array_position(df("a"), 1, 2)),
+      Seq(Row(3L), Row(0L))
+    )
     checkAnswer(
       df.selectExpr("array_position(a, 1)"),
       Seq(Row(1L), Row(0L))
     )
+    checkAnswer(
+      df.selectExpr("array_position(a, 1, 2)"),
+      Seq(Row(3L), Row(0L))
+    )
     checkAnswer(
       df.selectExpr("array_position(a, c)"),
       Seq(Row(1L), Row(0L))
@@ -2302,6 +2310,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Seq(Row(1L))
     )
 
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1, 1), 1.0D)"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1), 1.0D, 2)"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1, 1), 1.0D, 2)"),
+      Seq(Row(2L))
+    )
+
     checkAnswer(
       OneRowRelation().selectExpr("array_position(array(1.D), 1)"),
       Seq(Row(1L))
@@ -2326,6 +2349,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"),
       Seq(Row(1L))
     )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1, array(1)[0])"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1, 1), null)[0], 1, 2)"),
+      Seq(Row(2L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1, 1), null)[0], 1, array(2)[0])"),
+      Seq(Row(2L))
+    )
+
     checkAnswer(
       OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"),
       Seq(Row(1L))
@@ -2333,56 +2372,71 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
     checkError(
       exception = intercept[AnalysisException] {
-        Seq((null, "a")).toDF().selectExpr("array_position(_1, _2)")
+        Seq((null, "a")).toDF().selectExpr("array_position(_1, _2, 1)")
       },
       condition = "DATATYPE_MISMATCH.NULL_TYPE",
       parameters = Map(
-        "sqlExpr" -> "\"array_position(_1, _2)\"",
+        "sqlExpr" -> "\"array_position(_1, _2, 1)\"",
         "functionName" -> "`array_position`"
       ),
-      queryContext = Array(ExpectedContext("", "", 0, 21, "array_position(_1, _2)"))
+      queryContext = Array(ExpectedContext("", "", 0, 24, "array_position(_1, _2, 1)"))
     )
 
     checkError(
       exception = intercept[AnalysisException] {
-        Seq(("a string element", null)).toDF().selectExpr("array_position(_1, _2)")
+        Seq(("a string element", null)).toDF().selectExpr("array_position(_1, _2, 1)")
       },
       condition = "DATATYPE_MISMATCH.NULL_TYPE",
       parameters = Map(
-        "sqlExpr" -> "\"array_position(_1, _2)\"",
+        "sqlExpr" -> "\"array_position(_1, _2, 1)\"",
         "functionName" -> "`array_position`"
       ),
-      queryContext = Array(ExpectedContext("", "", 0, 21, "array_position(_1, _2)"))
+      queryContext = Array(ExpectedContext("", "", 0, 24, "array_position(_1, _2, 1)"))
     )
 
     checkError(
       exception = intercept[AnalysisException] {
-        Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
+        Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2, 1)")
       },
       condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
       parameters = Map(
-        "sqlExpr" -> "\"array_position(_1, _2)\"",
+        "sqlExpr" -> "\"array_position(_1, _2, 1)\"",
         "paramIndex" -> "first",
         "requiredType" -> "\"ARRAY\"",
         "inputSql" -> "\"_1\"",
         "inputType" -> "\"STRING\""
       ),
-      queryContext = Array(ExpectedContext("", "", 0, 21, "array_position(_1, _2)"))
+      queryContext = Array(ExpectedContext("", "", 0, 24, "array_position(_1, _2, 1)"))
     )
 
     checkError(
       exception = intercept[AnalysisException] {
-        OneRowRelation().selectExpr("array_position(array(1), '1')")
+        OneRowRelation().selectExpr("array_position(array(1), '1', 1)")
       },
       condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
       parameters = Map(
-        "sqlExpr" -> "\"array_position(array(1), 1)\"",
+        "sqlExpr" -> "\"array_position(array(1), 1, 1)\"",
         "functionName" -> "`array_position`",
         "dataType" -> "\"ARRAY\"",
         "leftType" -> "\"ARRAY\"",
         "rightType" -> "\"STRING\""
       ),
-      queryContext = Array(ExpectedContext("", "", 0, 28, "array_position(array(1), '1')"))
+      queryContext = Array(ExpectedContext("", "", 0, 31, "array_position(array(1), '1', 1)"))
     )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        Seq(("a string element", "a")).toDF().selectExpr("array_position(array(1), 1, _1)")
+      },
+      condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"array_position(array(1), 1, _1)\"",
+        "paramIndex" -> "third",
+        "requiredType" -> "\"INT\"",
+        "inputSql" -> "\"_1\"",
+        "inputType" -> "\"STRING\""
+      ),
+      queryContext = Array(ExpectedContext("", "", 0, 30, "array_position(array(1), 1, _1)"))
+    )
   }
 