diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml
index a5854d96a4d1a..6e2e5709bbd18 100644
--- a/.github/workflows/publish_snapshot.yml
+++ b/.github/workflows/publish_snapshot.yml
@@ -28,7 +28,7 @@ on:
description: 'list of branches to publish (JSON)'
required: true
# keep in sync with default value of strategy matrix 'branch'
- default: '["master", "branch-3.5"]'
+ default: '["master", "branch-4.0", "branch-3.5"]'
jobs:
publish-snapshot:
@@ -38,7 +38,7 @@ jobs:
fail-fast: false
matrix:
# keep in sync with default value of workflow_dispatch input 'branch'
- branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5"]' ) }}
+ branch: ${{ fromJSON( inputs.branch || '["master", "branch-4.0", "branch-3.5"]' ) }}
steps:
- name: Checkout Spark repository
uses: actions/checkout@v4
diff --git a/LICENSE-binary b/LICENSE-binary
index 2892c4b0ecce8..d2eea83525caf 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -296,7 +296,6 @@ jakarta.inject:jakarta.inject-api
jakarta.validation:jakarta.validation-api
javax.jdo:jdo-api
joda-time:joda-time
-net.java.dev.jna:jna
net.sf.opencsv:opencsv
net.sf.supercsv:super-csv
net.sf.jpam:jpam
diff --git a/assembly/pom.xml b/assembly/pom.xml
index df0740e6c6949..525f3b2569bb3 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -323,6 +323,7 @@
hive-provided
provided
+ provided
provided
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 3bac6638f7060..92da13df5ff13 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -117,6 +117,13 @@
},
"sqlState" : "42604"
},
+ "AVRO_CANNOT_WRITE_NULL_FIELD" : {
+ "message" : [
+ "Cannot write null value for field defined as non-null Avro data type .",
+ "To allow null value for this field, specify its avro schema as a union type with \"null\" using `avroSchema` option."
+ ],
+ "sqlState" : "22004"
+ },
"AVRO_INCOMPATIBLE_READ_TYPE" : {
"message" : [
"Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which would lead to an incorrect answer.",
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index c3a1af68d1c82..b53b89ad68c5d 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -549,6 +549,7 @@ private[spark] object LogKeys {
case object NUM_ROWS extends LogKey
case object NUM_RULE_OF_RUNS extends LogKey
case object NUM_SEQUENCES extends LogKey
+ case object NUM_SKIPPED extends LogKey
case object NUM_SLOTS extends LogKey
case object NUM_SPILLS extends LogKey
case object NUM_SPILL_WRITERS extends LogKey
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala
index 4b60cb20f0732..110c5f0934286 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala
@@ -17,6 +17,7 @@
package org.apache.spark.internal
+import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import org.apache.logging.log4j.{CloseableThreadContext, Level, LogManager}
@@ -27,6 +28,7 @@ import org.apache.logging.log4j.core.filter.AbstractFilter
import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.internal.Logging.SparkShellLoggingFilter
+import org.apache.spark.internal.LogKeys
import org.apache.spark.util.SparkClassUtils
/**
@@ -531,3 +533,158 @@ private[spark] object Logging {
override def isStopped: Boolean = status == LifeCycle.State.STOPPED
}
}
+
+/**
+ * A thread-safe token bucket-based throttler implementation with nanosecond accuracy.
+ *
+ * Each instance must be shared across all scopes it should throttle.
+ * For global throttling that means either by extending this class in an `object` or
+ * by creating the instance as a field of an `object`.
+ *
+ * @param bucketSize This corresponds to the largest possible burst without throttling,
+ * in number of executions.
+ * @param tokenRecoveryInterval Time between two tokens being added back to the bucket.
+ * This is the reciprocal of the long-term average unthrottled rate.
+ *
+ * Example: With a bucket size of 100 and a recovery interval of 1s, we could log up to 100 events
+ * in under a second without throttling, but at that point the bucket is exhausted and we only
+ * regain the ability to log more events at 1 event per second. If we log less than 1 event/s
+ * the bucket will slowly refill until it's back at 100.
+ * Either way, we can always log at least 1 event/s.
+ */
+class LogThrottler(
+ val bucketSize: Int = 100,
+ val tokenRecoveryInterval: FiniteDuration = 1.second,
+ val timeSource: NanoTimeTimeSource = SystemNanoTimeSource) extends Logging {
+
+ private var remainingTokens = bucketSize
+ private var nextRecovery: DeadlineWithTimeSource =
+ DeadlineWithTimeSource.now(timeSource) + tokenRecoveryInterval
+ private var numSkipped: Long = 0
+
+ /**
+ * Run `thunk` as long as there are tokens remaining in the bucket,
+ * otherwise skip and remember number of skips.
+ *
+ * The argument to `thunk` is how many previous invocations have been skipped since the last time
+ * an invocation actually ran.
+ *
+ * Note: This method is `synchronized`, so it is concurrency safe.
+ * However, that also means no heavy-lifting should be done as part of this
+ * if the throttler is shared between concurrent threads.
+ * This also means that the synchronized block of the `thunk` that *does* execute will still
+ * hold up concurrent `thunk`s that will actually get rejected once they hold the lock.
+ * This is fine at low concurrency/low recovery rates. But if we need this to be more efficient at
+ * some point, we will need to decouple the check from the `thunk` execution.
+ */
+ def throttled(thunk: Long => Unit): Unit = this.synchronized {
+ tryRecoverTokens()
+ if (remainingTokens > 0) {
+ thunk(numSkipped)
+ numSkipped = 0
+ remainingTokens -= 1
+ } else {
+ numSkipped += 1L
+ }
+ }
+
+ /**
+ * Same as [[throttled]] but turns the number of skipped invocations into a logging message
+ * that can be appended to the item being logged in `thunk`.
+ */
+ def throttledWithSkippedLogMessage(thunk: MessageWithContext => Unit): Unit = {
+ this.throttled { numSkipped =>
+ val skippedStr = if (numSkipped != 0L) {
+ log"[${MDC(LogKeys.NUM_SKIPPED, numSkipped)} similar messages were skipped.]"
+ } else {
+ log""
+ }
+ thunk(skippedStr)
+ }
+ }
+
+ /**
+ * Try to recover tokens, if the rate allows.
+ *
+ * Only call from within a `this.synchronized` block!
+ */
+ private[spark] def tryRecoverTokens(): Unit = {
+ try {
+ // Doing it one-by-one is a bit inefficient for long periods, but it's easy to avoid jumps
+ // and rounding errors this way. The inefficiency shouldn't matter as long as the bucketSize
+ // isn't huge.
+ while (remainingTokens < bucketSize && nextRecovery.isOverdue()) {
+ remainingTokens += 1
+ nextRecovery += tokenRecoveryInterval
+ }
+
+ val currentTime = DeadlineWithTimeSource.now(timeSource)
+ if (remainingTokens == bucketSize &&
+ (currentTime - nextRecovery) > tokenRecoveryInterval) {
+ // Reset the recovery time, so we don't accumulate infinite recovery while nothing is
+ // going on.
+ nextRecovery = currentTime + tokenRecoveryInterval
+ }
+ } catch {
+ case _: IllegalArgumentException =>
+ // Adding FiniteDuration throws IllegalArgumentException instead of wrapping on overflow.
+ // Given that this happens every ~300 years, we can afford some non-linearity here,
+ // rather than taking the effort to properly work around that.
+ nextRecovery = DeadlineWithTimeSource(Duration(-Long.MaxValue, NANOSECONDS), timeSource)
+ }
+ }
+
+ /**
+ * Resets throttler state to initial state.
+ * Visible for testing.
+ */
+ def reset(): Unit = this.synchronized {
+ remainingTokens = bucketSize
+ nextRecovery = DeadlineWithTimeSource.now(timeSource) + tokenRecoveryInterval
+ numSkipped = 0
+ }
+}
+
+/**
+ * This is essentially the same as Scala's [[Deadline]],
+ * just with a custom source of nanoTime so it can actually be tested properly.
+ */
+case class DeadlineWithTimeSource(
+ time: FiniteDuration,
+ timeSource: NanoTimeTimeSource = SystemNanoTimeSource) {
+ // Only implemented the methods LogThrottler actually needs for now.
+
+ /**
+ * Return a deadline advanced (i.e., moved into the future) by the given duration.
+ */
+ def +(other: FiniteDuration): DeadlineWithTimeSource = copy(time = time + other)
+
+ /**
+ * Calculate time difference between this and the other deadline, where the result is directed
+ * (i.e., may be negative).
+ */
+ def -(other: DeadlineWithTimeSource): FiniteDuration = time - other.time
+
+ /**
+ * Determine whether the deadline lies in the past at the point where this method is called.
+ */
+ def isOverdue(): Boolean = (time.toNanos - timeSource.nanoTime()) <= 0
+}
+
+object DeadlineWithTimeSource {
+ /**
+ * Construct a deadline due exactly at the point where this method is called. Useful for then
+ * advancing it to obtain a future deadline, or for sampling the current time exactly once and
+ * then comparing it to multiple deadlines (using subtraction).
+ */
+ def now(timeSource: NanoTimeTimeSource = SystemNanoTimeSource): DeadlineWithTimeSource =
+ DeadlineWithTimeSource(Duration(timeSource.nanoTime(), NANOSECONDS), timeSource)
+}
+
+/** Generalisation of [[System.nanoTime()]]. */
+private[spark] trait NanoTimeTimeSource {
+ def nanoTime(): Long
+}
+private[spark] object SystemNanoTimeSource extends NanoTimeTimeSource {
+ override def nanoTime(): Long = System.nanoTime()
+}
diff --git a/common/utils/src/test/scala/org/apache/spark/util/LogThrottlingSuite.scala b/common/utils/src/test/scala/org/apache/spark/util/LogThrottlingSuite.scala
new file mode 100644
index 0000000000000..47d7f85093a38
--- /dev/null
+++ b/common/utils/src/test/scala/org/apache/spark/util/LogThrottlingSuite.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.concurrent.duration._
+
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.internal.{DeadlineWithTimeSource, Logging, LogThrottler, NanoTimeTimeSource}
+
+
+class LogThrottlingSuite
+ extends AnyFunSuite // scalastyle:ignore funsuite
+ with Logging {
+
+ // Make sure that the helper works right.
+ test("time control") {
+ val nanoTimeControl = new MockedNanoTime
+ assert(nanoTimeControl.nanoTime() === 0L)
+ assert(nanoTimeControl.nanoTime() === 0L)
+
+ nanoTimeControl.advance(112L.nanos)
+ assert(nanoTimeControl.nanoTime() === 112L)
+ assert(nanoTimeControl.nanoTime() === 112L)
+ }
+
+ test("deadline with time control") {
+ val nanoTimeControl = new MockedNanoTime
+ assert(DeadlineWithTimeSource.now(nanoTimeControl).isOverdue())
+ val deadline = DeadlineWithTimeSource.now(nanoTimeControl) + 5.nanos
+ assert(!deadline.isOverdue())
+ nanoTimeControl.advance(5.nanos)
+ assert(deadline.isOverdue())
+ nanoTimeControl.advance(5.nanos)
+ assert(deadline.isOverdue())
+ // Check addition.
+ assert(deadline + 0.nanos === deadline)
+ val increasedDeadline = deadline + 10.nanos
+ assert(!increasedDeadline.isOverdue())
+ nanoTimeControl.advance(5.nanos)
+ assert(increasedDeadline.isOverdue())
+ // Ensure that wrapping keeps throwing this exact exception, since we rely on it in
+ // LogThrottler.tryRecoverTokens
+ assertThrows[IllegalArgumentException] {
+ deadline + Long.MaxValue.nanos
+ }
+ // Check difference and ordering.
+ assert(deadline - deadline === 0.nanos)
+ assert(increasedDeadline - deadline === 10.nanos)
+ assert(increasedDeadline - deadline > 9.nanos)
+ assert(increasedDeadline - deadline < 11.nanos)
+ assert(deadline - increasedDeadline === -10.nanos)
+ assert(deadline - increasedDeadline < -9.nanos)
+ assert(deadline - increasedDeadline > -11.nanos)
+ }
+
+ test("unthrottled, no burst") {
+ val nanoTimeControl = new MockedNanoTime
+ val throttler = new LogThrottler(
+ bucketSize = 1,
+ tokenRecoveryInterval = 5.nanos,
+ timeSource = nanoTimeControl)
+ val numInvocations = 100
+ var timesExecuted = 0
+ for (i <- 0 until numInvocations) {
+ throttler.throttled { skipped =>
+ assert(skipped === 0L)
+ timesExecuted += 1
+ }
+ nanoTimeControl.advance(5.nanos)
+ }
+ assert(timesExecuted === numInvocations)
+ }
+
+ test("unthrottled, burst") {
+ val nanoTimeControl = new MockedNanoTime
+ val throttler = new LogThrottler(
+ bucketSize = 100,
+ tokenRecoveryInterval = 1000000.nanos, // Just to make it obvious that it's a large number.
+ timeSource = nanoTimeControl)
+ val numInvocations = 100
+ var timesExecuted = 0
+ for (_ <- 0 until numInvocations) {
+ throttler.throttled { skipped =>
+ assert(skipped === 0L)
+ timesExecuted += 1
+ }
+ nanoTimeControl.advance(5.nanos)
+ }
+ assert(timesExecuted === numInvocations)
+ }
+
+ test("throttled, no burst") {
+ val nanoTimeControl = new MockedNanoTime
+ val throttler = new LogThrottler(
+ bucketSize = 1,
+ tokenRecoveryInterval = 5.nanos,
+ timeSource = nanoTimeControl)
+ val numInvocations = 100
+ var timesExecuted = 0
+ for (i <- 0 until numInvocations) {
+ throttler.throttled { skipped =>
+ if (timesExecuted == 0) {
+ assert(skipped === 0L)
+ } else {
+ assert(skipped === 4L)
+ }
+ timesExecuted += 1
+ }
+ nanoTimeControl.advance(1.nanos)
+ }
+ assert(timesExecuted === numInvocations / 5)
+ }
+
+ test("throttled, single burst") {
+ val nanoTimeControl = new MockedNanoTime
+ val throttler = new LogThrottler(
+ bucketSize = 5,
+ tokenRecoveryInterval = 10.nanos,
+ timeSource = nanoTimeControl)
+ val numInvocations = 100
+ var timesExecuted = 0
+ for (i <- 0 until numInvocations) {
+ throttler.throttled { skipped =>
+ if (i < 5) {
+ // First burst
+ assert(skipped === 0L)
+ } else if (i == 10) {
+ // First token recovery
+ assert(skipped === 5L)
+ } else {
+ // All other token recoveries
+ assert(skipped === 9L)
+ }
+ timesExecuted += 1
+ }
+ nanoTimeControl.advance(1.nano)
+ }
+ // A burst of 5 and then 1 every 10ns/invocations.
+ assert(timesExecuted === 5 + (numInvocations - 10) / 10)
+ }
+
+ test("throttled, bursty") {
+ val nanoTimeControl = new MockedNanoTime
+ val throttler = new LogThrottler(
+ bucketSize = 5,
+ tokenRecoveryInterval = 10.nanos,
+ timeSource = nanoTimeControl)
+ val numBursts = 10
+ val numInvocationsPerBurst = 10
+ var timesExecuted = 0
+ for (burst <- 0 until numBursts) {
+ for (i <- 0 until numInvocationsPerBurst) {
+ throttler.throttled { skipped =>
+ if (i == 0 && burst != 0) {
+ // first after recovery
+ assert(skipped === 5L)
+ } else {
+ // either first burst, or post-recovery on every other burst.
+ assert(skipped === 0L)
+ }
+ timesExecuted += 1
+ }
+ nanoTimeControl.advance(1.nano)
+ }
+ nanoTimeControl.advance(100.nanos)
+ }
+ // Bursts of 5.
+ assert(timesExecuted === 5 * numBursts)
+ }
+
+ test("wraparound") {
+ val nanoTimeControl = new MockedNanoTime
+ val throttler = new LogThrottler(
+ bucketSize = 1,
+ tokenRecoveryInterval = 100.nanos,
+ timeSource = nanoTimeControl)
+ def executeThrottled(expectedSkipped: Long = 0L): Boolean = {
+ var executed = false
+ throttler.throttled { skipped =>
+ assert(skipped === expectedSkipped)
+ executed = true
+ }
+ executed
+ }
+ assert(executeThrottled())
+ assert(!executeThrottled())
+
+ // Move to 2 ns before wrapping.
+ nanoTimeControl.advance((Long.MaxValue - 1L).nanos)
+ assert(executeThrottled(expectedSkipped = 1L))
+ assert(!executeThrottled())
+
+ nanoTimeControl.advance(1.nano)
+ assert(!executeThrottled())
+
+ // Wrapping
+ nanoTimeControl.advance(1.nano)
+ assert(!executeThrottled())
+
+ // Recover
+ nanoTimeControl.advance(100.nanos)
+ assert(executeThrottled(expectedSkipped = 3L))
+ }
+}
+
+/**
+ * Use a mocked object to replace calls to `System.nanoTime()` with a custom value that can be
+ * controlled by calling `advance(nanos)` on an instance of this class.
+ */
+class MockedNanoTime extends NanoTimeTimeSource {
+ private var currentTimeNs: Long = 0L
+
+ override def nanoTime(): Long = currentTimeNs
+
+ def advance(time: FiniteDuration): Unit = {
+ currentTimeNs += time.toNanos
+ }
+}
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 8342ca4e84275..4ddf6503d99ec 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -25,7 +25,7 @@ import java.util.UUID
import scala.jdk.CollectionConverters._
-import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, SchemaFormatter}
+import org.apache.avro.{Schema, SchemaBuilder, SchemaFormatter}
import org.apache.avro.Schema.{Field, Type}
import org.apache.avro.Schema.Type._
import org.apache.avro.file.{DataFileReader, DataFileWriter}
@@ -33,7 +33,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.commons.io.FileUtils
-import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
+import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkRuntimeException, SparkThrowable, SparkUpgradeException}
import org.apache.spark.TestUtils.assertExceptionMsg
import org.apache.spark.sql._
import org.apache.spark.sql.TestingUDT.IntervalData
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
import org.apache.spark.sql.execution.{FormattedMode, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
import org.apache.spark.sql.internal.LegacyBehaviorPolicy._
import org.apache.spark.sql.internal.SQLConf
@@ -100,6 +100,14 @@ abstract class AvroSuite
SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema)
}
+ private def getRootCause(ex: Throwable): Throwable = {
+ var rootCause = ex
+ while (rootCause.getCause != null) {
+ rootCause = rootCause.getCause
+ }
+ rootCause
+ }
+
// Check whether an Avro schema of union type is converted to SQL in an expected way, when the
// stable ID option is on.
//
@@ -1317,7 +1325,16 @@ abstract class AvroSuite
dfWithNull.write.format("avro")
.option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
}
- assertExceptionMsg[AvroTypeException](e1, "value null is not a SuitEnumType")
+
+ val expectedDatatype = "{\"type\":\"enum\",\"name\":\"SuitEnumType\"," +
+ "\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"
+
+ checkError(
+ getRootCause(e1).asInstanceOf[SparkThrowable],
+ condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+ parameters = Map(
+ "name" -> "`Suit`",
+ "dataType" -> expectedDatatype))
// Writing df containing data not in the enum will throw an exception
val e2 = intercept[SparkException] {
@@ -1332,6 +1349,50 @@ abstract class AvroSuite
}
}
+ test("to_avro nested struct schema nullability mismatch") {
+ Seq((true, false), (false, true)).foreach {
+ case (innerNull, outerNull) =>
+ val innerSchema = StructType(Seq(StructField("field1", IntegerType, innerNull)))
+ val outerSchema = StructType(Seq(StructField("innerStruct", innerSchema, outerNull)))
+ val nestedSchema = StructType(Seq(StructField("outerStruct", outerSchema, false)))
+
+ val rowWithNull = if (innerNull) Row(Row(null)) else Row(null)
+ val data = Seq(Row(Row(Row(1))), Row(rowWithNull), Row(Row(Row(3))))
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(data), nestedSchema)
+
+ val avroTypeStruct = s"""{
+ | "type": "record",
+ | "name": "outerStruct",
+ | "fields": [
+ | {
+ | "name": "innerStruct",
+ | "type": {
+ | "type": "record",
+ | "name": "innerStruct",
+ | "fields": [
+ | {"name": "field1", "type": "int"}
+ | ]
+ | }
+ | }
+ | ]
+ |}
+ """.stripMargin // nullability mismatch for innerStruct
+
+ val expectedErrorName = if (outerNull) "`innerStruct`" else "`field1`"
+ val expectedErrorSchema = if (outerNull) "{\"type\":\"record\",\"name\":\"innerStruct\"" +
+ ",\"fields\":[{\"name\":\"field1\",\"type\":\"int\"}]}" else "\"int\""
+
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ df.select(avro.functions.to_avro($"outerStruct", avroTypeStruct)).collect()
+ },
+ condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+ parameters = Map(
+ "name" -> expectedErrorName,
+ "dataType" -> expectedErrorSchema))
+ }
+ }
+
test("support user provided avro schema for writing nullable fixed type") {
withTempPath { tempDir =>
val avroSchema =
@@ -1517,9 +1578,12 @@ abstract class AvroSuite
.save(s"$tempDir/${UUID.randomUUID()}")
}
assert(ex.getCondition == "TASK_WRITE_FAILED")
- assert(ex.getCause.isInstanceOf[java.lang.NullPointerException])
- assert(ex.getCause.getMessage.contains(
- "null value for (non-nullable) string at test_schema.Name"))
+ checkError(
+ ex.getCause.asInstanceOf[SparkThrowable],
+ condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+ parameters = Map(
+ "name" -> "`Name`",
+ "dataType" -> "\"string\""))
}
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index fd7efb1efb764..04637c1b55631 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -171,7 +171,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
- """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """
)
// scalastyle:on
df.collect()
@@ -186,7 +186,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
- """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
)
// scalastyle:on
df.collect()
@@ -203,7 +203,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
- """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """
)
// scalastyle:on
df.collect()
@@ -220,7 +220,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
- """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
)
// scalastyle:on
df.collect()
diff --git a/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala b/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala
index 3a813f4d8b53c..a7968b6f2b022 100644
--- a/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala
+++ b/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import java.io.File
import java.util.UUID
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShutdownHookManager, Utils}
trait LocalRootDirsTest extends SparkFunSuite with LocalSparkContext {
@@ -46,7 +46,13 @@ trait LocalRootDirsTest extends SparkFunSuite with LocalSparkContext {
override def afterEach(): Unit = {
try {
- Utils.deleteRecursively(tempDir)
+ // SPARK-51030: Only perform manual cleanup of the `tempDir` if it has
+ // not been registered for cleanup via a shutdown hook, to avoid potential
+ // IOException due to race conditions during multithreaded cleanup.
+ if (!ShutdownHookManager.hasShutdownDeleteDir(tempDir) &&
+ !ShutdownHookManager.hasRootAsShutdownDeleteDir(tempDir)) {
+ Utils.deleteRecursively(tempDir)
+ }
Utils.clearLocalRootDirs()
} finally {
super.afterEach()
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 27a53e8205201..e3171116a3e14 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -112,7 +112,7 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
bytes.toByteArray()
}.collect()(0)
- assert(content != encrypted)
+ assert(!content.getBytes(UTF_8).sameElements(encrypted))
val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
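
The replaced assertion was vacuous: `content` is a String while `encrypted` is an Array[Byte], so `!=` across the two types is always true and never inspected the cipher output. A standalone sketch of the distinction (the names are illustrative):

    import java.nio.charset.StandardCharsets.UTF_8

    object ByteComparisonSketch {
      def main(args: Array[String]): Unit = {
        val content = "plaintext"
        val sameBytes: Array[Byte] = content.getBytes(UTF_8)
        // Always true: a String never equals an Array[Byte], regardless of contents.
        println(content != sameBytes)
        // Compares the actual byte contents, which is what the updated test relies on.
        println(content.getBytes(UTF_8).sameElements(sameBytes))
      }
    }
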
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
deleted file mode 100644
index c76c97d071418..0000000000000
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui
-
-import scala.xml.Node
-
-import jakarta.servlet.http.HttpServletRequest
-import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}
-
-import org.apache.spark._
-import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
-import org.apache.spark.internal.config.Status._
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.scheduler._
-import org.apache.spark.status.AppStatusStore
-import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus}
-import org.apache.spark.ui.jobs.{StagePage, StagesTab}
-
-class StagePageSuite extends SparkFunSuite with LocalSparkContext {
-
- private val peakExecutionMemory = 10
-
- test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") {
- val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
- val statusStore = AppStatusStore.createLiveStore(conf)
- try {
- val stageData = new StageData(
- status = StageStatus.ACTIVE,
- stageId = 1,
- attemptId = 1,
- numTasks = 1,
- numActiveTasks = 1,
- numCompleteTasks = 1,
- numFailedTasks = 1,
- numKilledTasks = 1,
- numCompletedIndices = 1,
-
- submissionTime = None,
- firstTaskLaunchedTime = None,
- completionTime = None,
- failureReason = None,
-
- executorDeserializeTime = 1L,
- executorDeserializeCpuTime = 1L,
- executorRunTime = 1L,
- executorCpuTime = 1L,
- resultSize = 1L,
- jvmGcTime = 1L,
- resultSerializationTime = 1L,
- memoryBytesSpilled = 1L,
- diskBytesSpilled = 1L,
- peakExecutionMemory = 1L,
- inputBytes = 1L,
- inputRecords = 1L,
- outputBytes = 1L,
- outputRecords = 1L,
- shuffleRemoteBlocksFetched = 1L,
- shuffleLocalBlocksFetched = 1L,
- shuffleFetchWaitTime = 1L,
- shuffleRemoteBytesRead = 1L,
- shuffleRemoteBytesReadToDisk = 1L,
- shuffleLocalBytesRead = 1L,
- shuffleReadBytes = 1L,
- shuffleReadRecords = 1L,
- shuffleCorruptMergedBlockChunks = 1L,
- shuffleMergedFetchFallbackCount = 1L,
- shuffleMergedRemoteBlocksFetched = 1L,
- shuffleMergedLocalBlocksFetched = 1L,
- shuffleMergedRemoteChunksFetched = 1L,
- shuffleMergedLocalChunksFetched = 1L,
- shuffleMergedRemoteBytesRead = 1L,
- shuffleMergedLocalBytesRead = 1L,
- shuffleRemoteReqsDuration = 1L,
- shuffleMergedRemoteReqsDuration = 1L,
- shuffleWriteBytes = 1L,
- shuffleWriteTime = 1L,
- shuffleWriteRecords = 1L,
-
- name = "stage1",
- description = Some("description"),
- details = "detail",
- schedulingPool = "pool1",
-
- rddIds = Seq(1),
- accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")),
- tasks = None,
- executorSummary = None,
- speculationSummary = None,
- killedTasksSummary = Map.empty,
- ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID,
- peakExecutorMetrics = None,
- taskMetricsDistributions = None,
- executorMetricsDistributions = None,
- isShufflePushEnabled = false,
- shuffleMergersCount = 0
- )
- } finally {
- statusStore.close()
- }
- }
-
- /**
- * Render a stage page started with the given conf and return the HTML.
- * This also runs a dummy stage to populate the page with useful content.
- */
- private def renderStagePage(): Seq[Node] = {
- val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
- val statusStore = AppStatusStore.createLiveStore(conf)
- val listener = statusStore.listener.get
-
- try {
- val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
- when(tab.store).thenReturn(statusStore)
-
- val request = mock(classOf[HttpServletRequest])
- when(tab.conf).thenReturn(conf)
- when(tab.appName).thenReturn("testing")
- when(tab.headerTabs).thenReturn(Seq.empty)
- when(request.getParameter("id")).thenReturn("0")
- when(request.getParameter("attempt")).thenReturn("0")
- val page = new StagePage(tab, statusStore)
-
- // Simulate a stage in job progress listener
- val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details",
- resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)
- // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
- (1 to 2).foreach {
- taskId =>
- val taskInfo = new TaskInfo(taskId, taskId, 0, taskId, 0,
- "0", "localhost", TaskLocality.ANY, false)
- listener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
- listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
- taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis())
- val taskMetrics = TaskMetrics.empty
- val executorMetrics = new ExecutorMetrics
- taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
- listener.onTaskEnd(SparkListenerTaskEnd(0, 0, "result", Success, taskInfo,
- executorMetrics, taskMetrics))
- }
- listener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
- page.render(request)
- } finally {
- statusStore.close()
- }
- }
-
-}
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 96d5f9d477143..c4228bc543ebb 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -86,7 +86,6 @@ hive-cli/2.3.10//hive-cli-2.3.10.jar
hive-common/2.3.10//hive-common-2.3.10.jar
hive-exec/2.3.10/core/hive-exec-2.3.10-core.jar
hive-jdbc/2.3.10//hive-jdbc-2.3.10.jar
-hive-llap-common/2.3.10//hive-llap-common-2.3.10.jar
hive-metastore/2.3.10//hive-metastore-2.3.10.jar
hive-serde/2.3.10//hive-serde-2.3.10.jar
hive-service-rpc/4.0.0//hive-service-rpc-4.0.0.jar
@@ -121,7 +120,7 @@ jakarta.validation-api/3.0.2//jakarta.validation-api-3.0.2.jar
jakarta.ws.rs-api/3.0.0//jakarta.ws.rs-api-3.0.0.jar
jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar
janino/3.1.9//janino-3.1.9.jar
-java-diff-utils/4.12//java-diff-utils-4.12.jar
+java-diff-utils/4.15//java-diff-utils-4.15.jar
java-xmlbuilder/1.2//java-xmlbuilder-1.2.jar
javassist/3.30.2-GA//javassist-3.30.2-GA.jar
javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar
@@ -145,8 +144,7 @@ jjwt-api/0.12.6//jjwt-api-0.12.6.jar
jjwt-impl/0.12.6//jjwt-impl-0.12.6.jar
jjwt-jackson/0.12.6//jjwt-jackson-0.12.6.jar
jline/2.14.6//jline-2.14.6.jar
-jline/3.26.3//jline-3.26.3.jar
-jna/5.14.0//jna-5.14.0.jar
+jline/3.27.1/jdk8/jline-3.27.1-jdk8.jar
joda-time/2.13.0//joda-time-2.13.0.jar
jodd-core/3.5.2//jodd-core-3.5.2.jar
jpam/1.1//jpam-1.1.jar
@@ -254,11 +252,11 @@ py4j/0.10.9.9//py4j-0.10.9.9.jar
remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar
rocksdbjni/9.8.4//rocksdbjni-9.8.4.jar
scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar
-scala-compiler/2.13.15//scala-compiler-2.13.15.jar
-scala-library/2.13.15//scala-library-2.13.15.jar
+scala-compiler/2.13.16//scala-compiler-2.13.16.jar
+scala-library/2.13.16//scala-library-2.13.16.jar
scala-parallel-collections_2.13/1.2.0//scala-parallel-collections_2.13-1.2.0.jar
scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar
-scala-reflect/2.13.15//scala-reflect-2.13.15.jar
+scala-reflect/2.13.16//scala-reflect-2.13.16.jar
scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar
slf4j-api/2.0.16//slf4j-api-2.0.16.jar
snakeyaml-engine/2.8//snakeyaml-engine-2.8.jar
diff --git a/docs/_config.yml b/docs/_config.yml
index 67c4237392acd..7a28a7c76099d 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -22,7 +22,7 @@ include:
SPARK_VERSION: 4.1.0-SNAPSHOT
SPARK_VERSION_SHORT: 4.1.0
SCALA_BINARY_VERSION: "2.13"
-SCALA_VERSION: "2.13.15"
+SCALA_VERSION: "2.13.16"
SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
SPARK_GITHUB_URL: https://github.com/apache/spark
# Before a new release, we should:
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 254c54a414a7e..f459a88d8e148 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -31,6 +31,7 @@ license: |
- Since Spark 4.0, any read of SQL tables takes into consideration the SQL configs `spark.sql.files.ignoreCorruptFiles`/`spark.sql.files.ignoreMissingFiles` instead of the core config `spark.files.ignoreCorruptFiles`/`spark.files.ignoreMissingFiles`.
- Since Spark 4.0, when reading SQL tables hits `org.apache.hadoop.security.AccessControlException` and `org.apache.hadoop.hdfs.BlockMissingException`, the exception will be thrown and fail the task, even if `spark.sql.files.ignoreCorruptFiles` is set to `true`.
- Since Spark 4.0, `spark.sql.hive.metastore` drops the support of Hive prior to 2.0.0 as they require JDK 8 that Spark does not support anymore. Users should migrate to higher versions.
+- Since Spark 4.0, Spark removes the `hive-llap-common` dependency. To restore the previous behavior, add the `hive-llap-common` jar to the class path.
- Since Spark 4.0, `spark.sql.parquet.compression.codec` drops the support of codec name `lz4raw`, please use `lz4_raw` instead.
- Since Spark 4.0, when overflowing during casting timestamp to byte/short/int under non-ansi mode, Spark will return null instead of a wrapping value.
- Since Spark 4.0, the `encode()` and `decode()` functions support only the following charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'. To restore the previous behavior when the function accepts charsets of the current JDK used by Spark, set `spark.sql.legacy.javaCharsets` to `true`.
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 45b8e9a4dcea7..7c7e2a6909574 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -503,6 +503,7 @@ Below is a list of all the keywords in Spark SQL.
|DOUBLE|non-reserved|non-reserved|reserved|
|DROP|non-reserved|non-reserved|reserved|
|ELSE|reserved|non-reserved|reserved|
+|ELSEIF|non-reserved|non-reserved|non-reserved|
|END|reserved|non-reserved|reserved|
|ESCAPE|reserved|non-reserved|reserved|
|ESCAPED|non-reserved|non-reserved|non-reserved|
diff --git a/pom.xml b/pom.xml
index d83d9ceaaaec5..94e2b8650bac5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -169,7 +169,7 @@
3.2.2
4.4
- <scala.version>2.13.15</scala.version>
+ <scala.version>2.13.16</scala.version>
2.13
2.2.0
4.9.2
@@ -230,7 +230,7 @@
and ./python/packaging/connect/setup.py too.
-->
18.1.0
- 3.0.0
+ 3.0.1
0.12.6
@@ -274,7 +274,7 @@
compile
compile
compile
- compile
+ test
compile
compile
compile
@@ -335,7 +335,7 @@
12.8.1.jre11
23.6.0.24.10
2.7.1
- 3.21.0
+ 3.22.0
20.00.00.39
${project.version}
@@ -2255,7 +2255,7 @@
<groupId>${hive.group}</groupId>
<artifactId>hive-llap-common</artifactId>
<version>${hive.version}</version>
- <scope>${hive.deps.scope}</scope>
+ <scope>${hive.llap.scope}</scope>
<groupId>${hive.group}</groupId>
@@ -2657,14 +2657,6 @@
${java.version}
test
provided
-
-
- org.jline.terminal.impl.ffm.*
-
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 9b132f5d693fd..34299bdb7740c 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
import os
+import platform
import tempfile
import unittest
from typing import Callable, Union
@@ -508,6 +509,9 @@ def write(self, iterator):
):
df.write.format("test").mode("append").saveAsTable("test_table")
+ @unittest.skipIf(
+ "pypy" in platform.python_implementation().lower(), "cannot run in environment pypy"
+ )
def test_data_source_segfault(self):
import ctypes
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index ef029adfb476b..95a81236fbf52 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
import os
+import platform
import shutil
import tempfile
import unittest
@@ -2761,6 +2762,9 @@ def eval(self, n):
res = self.spark.sql("select i, to_json(v['v1']) from test_udtf_struct(8)")
assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n in range(8)])
+ @unittest.skipIf(
+ "pypy" in platform.python_implementation().lower(), "cannot run in environment pypy"
+ )
def test_udtf_segfault(self):
for enabled, expected in [
(True, "Segmentation fault"),
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index dafeed48aef11..360854d81e384 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -218,6 +218,7 @@ DO: 'DO';
DOUBLE: 'DOUBLE';
DROP: 'DROP';
ELSE: 'ELSE';
+ELSEIF: 'ELSEIF';
END: 'END';
ESCAPE: 'ESCAPE';
ESCAPED: 'ESCAPED';
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 8e6edac2108a6..9b438a667a3e5 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -85,7 +85,7 @@ whileStatement
ifElseStatement
: IF booleanExpression THEN conditionalBodies+=compoundBody
- (ELSE IF booleanExpression THEN conditionalBodies+=compoundBody)*
+ (ELSEIF booleanExpression THEN conditionalBodies+=compoundBody)*
(ELSE elseBody=compoundBody)? END IF
;
@@ -1642,6 +1642,7 @@ ansiNonReserved
| DO
| DOUBLE
| DROP
+ | ELSEIF
| ESCAPED
| EVOLUTION
| EXCHANGE
@@ -1987,6 +1988,7 @@ nonReserved
| DOUBLE
| DROP
| ELSE
+ | ELSEIF
| END
| ESCAPE
| ESCAPED
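
With this change the grammar accepts ELSEIF as a single keyword inside IF ... END IF; the rule no longer spells the branch as the two tokens ELSE IF. A sketch of the user-facing syntax, mirroring the parser suite changes below (the wrapping object is just scaffolding):

    object ElseIfSyntaxExample {
      // Parsing this script yields one IfElseStatement with two conditions
      // ("1 = 1" and "2 = 2") plus an ELSE body, as the updated
      // SqlScriptingParserSuite cases assert.
      val script: String =
        """BEGIN
          |  IF 1 = 1 THEN
          |    SELECT 1;
          |  ELSEIF 2 = 2 THEN
          |    SELECT 2;
          |  ELSE
          |    SELECT 3;
          |  END IF;
          |END""".stripMargin
    }
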
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index 9fbfb9e679e58..23de9c222724b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -127,11 +127,13 @@ object CurrentOrigin {
}
}
- private val sparkCodePattern = Pattern.compile("(org\\.apache\\.spark\\.sql\\." +
- "(?:(classic|connect)\\.)?" +
- "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)" +
- "(?:|\\..*|\\$.*))" +
- "|(scala\\.collection\\..*)")
+ private val sparkCodePattern = Pattern.compile(
+ "(org\\.apache\\.spark\\.sql\\." +
+ "(?:(classic|connect)\\.)?" +
+ "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder" +
+ "|SparkSession|ColumnNodeToProtoConverter)" +
+ "(?:|\\..*|\\$.*))" +
+ "|(scala\\.collection\\..*)")
private def sparkCode(ste: StackTraceElement): Boolean = {
sparkCodePattern.matcher(ste.getClassName).matches()
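
The broadened pattern now also classifies SparkSession and ColumnNodeToProtoConverter stack frames as Spark-internal, presumably so call-site origin capture keeps skipping past them. A quick standalone check that copies the pattern from the change above (the sample class names are illustrative):

    import java.util.regex.Pattern

    object SparkCodePatternCheck {
      private val sparkCodePattern = Pattern.compile(
        "(org\\.apache\\.spark\\.sql\\." +
          "(?:(classic|connect)\\.)?" +
          "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder" +
          "|SparkSession|ColumnNodeToProtoConverter)" +
          "(?:|\\..*|\\$.*))" +
          "|(scala\\.collection\\..*)")

      def main(args: Array[String]): Unit = {
        Seq(
          "org.apache.spark.sql.SparkSession",                        // now treated as Spark code
          "org.apache.spark.sql.connect.ColumnNodeToProtoConverter$", // now treated as Spark code
          "com.example.MyApp"                                         // still user code
        ).foreach(name => println(s"$name -> ${sparkCodePattern.matcher(name).matches()}"))
      }
    }
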
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index ad00a5216b4c9..b3bd86149ba91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -81,9 +81,9 @@ case class CompoundBody(
/**
* Logical operator for IF ELSE statement.
* @param conditions Collection of conditions. First condition corresponds to IF clause,
- * while others (if any) correspond to following ELSE IF clauses.
+ * while others (if any) correspond to following ELSEIF clauses.
* @param conditionalBodies Collection of bodies that have a corresponding condition,
- * in IF or ELSE IF branches.
+ * in IF or ELSEIF branches.
* @param elseBody Body that is executed if none of the conditions are met,
* i.e. ELSE branch.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index e129c6dbba052..40ba7809e5cee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -501,12 +501,12 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
.getText == "SELECT 2")
}
- test("if else if") {
+ test("if elseif") {
val sqlScriptText =
"""BEGIN
|IF 1 = 1 THEN
| SELECT 1;
- |ELSE IF 2 = 2 THEN
+ |ELSEIF 2 = 2 THEN
| SELECT 2;
|ELSE
| SELECT 3;
@@ -541,14 +541,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
.getText == "SELECT 3")
}
- test("if multi else if") {
+ test("if multi elseif") {
val sqlScriptText =
"""BEGIN
|IF 1 = 1 THEN
| SELECT 1;
- |ELSE IF 2 = 2 THEN
+ |ELSEIF 2 = 2 THEN
| SELECT 2;
- |ELSE IF 3 = 3 THEN
+ |ELSEIF 3 = 3 THEN
| SELECT 3;
|END IF;
|END
@@ -584,6 +584,87 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
.getText == "SELECT 3")
}
+ test("if - multi elseif - else nested") {
+ val sqlScriptText =
+ """BEGIN
+ |IF 1 = 1 THEN
+ | SELECT 1;
+ |ELSEIF 2 = 2 THEN
+ | SELECT 2;
+ |ELSE
+ | IF 3 = 3 THEN
+ | SELECT 3;
+ | ELSEIF 4 = 4 THEN
+ | SELECT 4;
+ | ELSE
+ | IF 5 = 5 THEN
+ | SELECT 5;
+ | ELSE
+ | SELECT 6;
+ | END IF;
+ | END IF;
+ |END IF;
+ |END
+ """.stripMargin
+ val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[IfElseStatement])
+
+ val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
+ assert(ifStmt.conditions.length == 2)
+ assert(ifStmt.conditionalBodies.length == 2)
+ assert(ifStmt.elseBody.nonEmpty)
+
+ assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.conditions.head.getText == "1 = 1")
+
+ assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 1")
+
+ assert(ifStmt.conditions(1).isInstanceOf[SingleStatement])
+ assert(ifStmt.conditions(1).getText == "2 = 2")
+
+ assert(ifStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 2")
+
+ assert(ifStmt.elseBody.get.collection.head.isInstanceOf[IfElseStatement])
+ val nestedIf_1 = ifStmt.elseBody.get.collection.head.asInstanceOf[IfElseStatement]
+
+ assert(nestedIf_1.conditions.length == 2)
+ assert(nestedIf_1.conditionalBodies.length == 2)
+ assert(nestedIf_1.elseBody.nonEmpty)
+
+
+ assert(nestedIf_1.conditions.head.isInstanceOf[SingleStatement])
+ assert(nestedIf_1.conditions.head.getText == "3 = 3")
+
+ assert(nestedIf_1.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+ assert(nestedIf_1.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 3")
+
+ assert(nestedIf_1.conditions(1).isInstanceOf[SingleStatement])
+ assert(nestedIf_1.conditions(1).getText == "4 = 4")
+
+ assert(nestedIf_1.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
+ assert(nestedIf_1.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 4")
+
+ assert(nestedIf_1.elseBody.get.collection.head.isInstanceOf[IfElseStatement])
+ val nestedIf_2 = nestedIf_1.elseBody.get.collection.head.asInstanceOf[IfElseStatement]
+
+ assert(nestedIf_2.conditions.length == 1)
+ assert(nestedIf_2.conditionalBodies.length == 1)
+ assert(nestedIf_2.elseBody.nonEmpty)
+
+ assert(nestedIf_2.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 5")
+
+ assert(nestedIf_2.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 6")
+ }
+
test("if nested") {
val sqlScriptText =
"""
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
index 8f6c6ef07b3df..7b79387fbfde9 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
@@ -23,8 +23,8 @@ import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
import org.apache.spark.sql
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.functions
/**
* Functionality for working with missing data in `DataFrame`s.
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions
*/
final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
extends sql.DataFrameNaFunctions {
- import sparkSession.RichColumn
override protected def drop(minNonNulls: Option[Int]): DataFrame =
buildDropDataFrame(None, minNonNulls)
@@ -103,7 +102,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
sparkSession.newDataFrame { builder =>
val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
values.map { case (colName, replaceValue) =>
- fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral)
+ fillNaBuilder.addCols(colName).addValues(toLiteral(replaceValue).getLiteral)
}
}
}
@@ -143,8 +142,8 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
replacementMap.map { case (oldValue, newValue) =>
Replacement
.newBuilder()
- .setOldValue(functions.lit(oldValue).expr.getLiteral)
- .setNewValue(functions.lit(newValue).expr.getLiteral)
+ .setOldValue(toLiteral(oldValue).getLiteral)
+ .setNewValue(toLiteral(newValue).getLiteral)
.build()
}
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
index f3c3f82a233ae..a510afc716a77 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
@@ -23,9 +23,9 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder
-import org.apache.spark.sql.functions.lit
/**
* Statistic functions for `DataFrame`s.
@@ -120,20 +120,19 @@ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)
/** @inheritdoc */
def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
- import sparkSession.RichColumn
require(
fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
sparkSession.newDataFrame { builder =>
val sampleByBuilder = builder.getSampleByBuilder
.setInput(root)
- .setCol(col.expr)
+ .setCol(toExpr(col))
.setSeed(seed)
fractions.foreach { case (k, v) =>
sampleByBuilder.addFractions(
StatSampleBy.Fraction
.newBuilder()
- .setStratum(lit(k).expr.getLiteral)
+ .setStratum(toLiteral(k).getLiteral)
.setFraction(v))
}
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
index 42cf2cdfad58a..06d339487bfb8 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.Column
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
/**
* Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
@@ -33,7 +34,6 @@ import org.apache.spark.sql.Column
@Experimental
final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
extends sql.DataFrameWriterV2[T] {
- import ds.sparkSession.RichColumn
private val builder = proto.WriteOperationV2
.newBuilder()
@@ -73,7 +73,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
/** @inheritdoc */
@scala.annotation.varargs
override def partitionedBy(column: Column, columns: Column*): this.type = {
- builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava)
+ builder.addAllPartitioningColumns((column +: columns).map(toExpr).asJava)
this
}
@@ -106,7 +106,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
/** @inheritdoc */
def overwrite(condition: Column): Unit = {
- builder.setOverwriteCondition(condition.expr)
+ builder.setOverwriteCondition(toExpr(condition))
executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index 36003283a3369..419ac3b7f74ae 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
@@ -140,7 +141,6 @@ class Dataset[T] private[sql] (
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
extends sql.Dataset[T] {
- import sparkSession.RichColumn
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)
@@ -336,7 +336,7 @@ class Dataset[T] private[sql] (
buildJoin(right, Seq(joinExprs)) { builder =>
builder
.setJoinType(toJoinType(joinType))
- .setJoinCondition(joinExprs.expr)
+ .setJoinCondition(toExpr(joinExprs))
}
}
@@ -375,7 +375,7 @@ class Dataset[T] private[sql] (
.setLeft(plan.getRoot)
.setRight(other.plan.getRoot)
.setJoinType(joinTypeValue)
- .setJoinCondition(condition.expr)
+ .setJoinCondition(toExpr(condition))
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
.setIsLeftStruct(this.agnosticEncoder.isStruct)
.setIsRightStruct(other.agnosticEncoder.isStruct))
@@ -396,7 +396,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
val lateralJoinBuilder = builder.getLateralJoinBuilder
lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
- joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
+ joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(toExpr(c)))
lateralJoinBuilder.setJoinType(joinTypeValue)
}
}
@@ -440,7 +440,7 @@ class Dataset[T] private[sql] (
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
- .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+ .addAllParameters(parameters.map(p => toLiteral(p)).asJava)
}
private def getPlanId: Option[Long] =
@@ -486,7 +486,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
- .addExpressions(col.typedExpr(this.encoder))
+ .addExpressions(toTypedExpr(col, this.encoder))
}
}
@@ -504,14 +504,14 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder, cols) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
- .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
+ .addAllExpressions(cols.map(c => toTypedExpr(c, this.encoder)).asJava)
}
}
/** @inheritdoc */
def filter(condition: Column): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
- builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ builder.getFilterBuilder.setInput(plan.getRoot).setCondition(toExpr(condition))
}
}
@@ -523,12 +523,12 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
- .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
+ .addAllIds(ids.toImmutableArraySeq.map(toExpr).asJava)
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
- .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+ .addAllValues(values.toImmutableArraySeq.map(toExpr).asJava)
}
}
}
@@ -537,7 +537,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(indices) { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
- transpose.addIndexColumns(indexColumn.expr)
+ transpose.addIndexColumns(toExpr(indexColumn))
}
}
@@ -553,7 +553,7 @@ class Dataset[T] private[sql] (
function = func,
inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
outputEncoder = agnosticEncoder)
- val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
+ val reduceExpr = toExpr(Column.fn("reduce", udf.apply(col("*"), col("*"))))
val result = sparkSession
.newDataset(agnosticEncoder) { builder =>
@@ -590,7 +590,7 @@ class Dataset[T] private[sql] (
val groupingSetMsgs = groupingSets.map { groupingSet =>
val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
- groupingSetMsg.addGroupingSet(groupCol.expr)
+ groupingSetMsg.addGroupingSet(toExpr(groupCol))
}
groupingSetMsg.build()
}
@@ -779,7 +779,7 @@ class Dataset[T] private[sql] (
s"The size of column names: ${names.size} isn't equal to " +
s"the size of columns: ${values.size}")
val aliases = values.zip(names).map { case (value, name) =>
- value.name(name).expr.getAlias
+ toExpr(value.name(name)).getAlias
}
sparkSession.newDataFrame(values) { builder =>
builder.getWithColumnsBuilder
@@ -812,7 +812,7 @@ class Dataset[T] private[sql] (
def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
val newAlias = proto.Expression.Alias
.newBuilder()
- .setExpr(col(columnName).expr)
+ .setExpr(toExpr(col(columnName)))
.addName(columnName)
.setMetadata(metadata.json)
sparkSession.newDataFrame { builder =>
@@ -845,7 +845,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(cols) { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
- .addAllColumns(cols.map(_.expr).asJava)
+ .addAllColumns(cols.map(toExpr).asJava)
}
}
@@ -915,7 +915,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset[T](agnosticEncoder) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
- .setCondition(udf.apply(col("*")).expr)
+ .setCondition(toExpr(udf.apply(col("*"))))
}
}
@@ -944,7 +944,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(outputEncoder) { builder =>
builder.getMapPartitionsBuilder
.setInput(plan.getRoot)
- .setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction)
+ .setFunc(toExpr(udf.apply(col("*"))).getCommonInlineUserDefinedFunction)
}
}
@@ -1020,7 +1020,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
- .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+ .addAllPartitionExprs(partitionExprs.map(toExpr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}
}
@@ -1036,7 +1036,7 @@ class Dataset[T] private[sql] (
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
- val sortOrders = partitionExprs.filter(_.expr.hasSortOrder)
+ val sortOrders = partitionExprs.filter(e => toExpr(e).hasSortOrder)
if (sortOrders.nonEmpty) {
throw new IllegalArgumentException(
s"Invalid partitionExprs specified: $sortOrders\n" +
@@ -1050,7 +1050,7 @@ class Dataset[T] private[sql] (
partitionExprs: Seq[Column]): Dataset[T] = {
require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
val sortExprs = partitionExprs.map {
- case e if e.expr.hasSortOrder => e
+ case e if toExpr(e).hasSortOrder => e
case e => e.asc
}
buildRepartitionByExpression(numPartitions, sortExprs)
@@ -1158,7 +1158,7 @@ class Dataset[T] private[sql] (
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
- .addAllMetrics((expr +: exprs).map(_.expr).asJava)
+ .addAllMetrics((expr +: exprs).map(toExpr).asJava)
}
}
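For reference, a minimal sketch of the calling pattern this file-wide change adopts. The helper names below are hypothetical stand-ins; the commented-out form is the one that relied on the `sparkSession.RichColumn` implicit removed above:

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.{Column, Encoder}
    import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}

    // Hypothetical helpers mirroring the new call shape used throughout Dataset.scala.
    def asCondition(condition: Column): proto.Expression =
      toExpr(condition)                 // was: condition.expr (via the removed RichColumn implicit)

    def asTypedAgg[T](c: Column, enc: Encoder[T]): proto.Expression =
      toTypedExpr(c, enc)               // was: c.typedExpr(enc)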
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
index c984582ed6ae1..dc494649b397c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql
import org.apache.spark.sql.{Column, Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
-import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
@@ -394,7 +394,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
private val valueMapFunc: Option[IV => V],
private val keysFunc: () => Dataset[IK])
extends KeyValueGroupedDataset[K, V] {
- import sparkSession.RichColumn
override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
new KeyValueGroupedDatasetImpl[L, V, IK, IV](
@@ -436,7 +435,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
sparkSession.newDataset[U](outputEncoder) { builder =>
builder.getGroupMapBuilder
.setInput(plan.getRoot)
- .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava)
+ .addAllSortingExpressions(sortExprs.map(toExpr).asJava)
.addAllGroupingExpressions(groupingExprs)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder))
}
@@ -453,10 +452,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
builder.getCoGroupMapBuilder
.setInput(plan.getRoot)
.addAllInputGroupingExpressions(groupingExprs)
- .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava)
+ .addAllInputSortingExpressions(thisSortExprs.map(toExpr).asJava)
.setOther(otherImpl.plan.getRoot)
.addAllOtherGroupingExpressions(otherImpl.groupingExprs)
- .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava)
+ .addAllOtherSortingExpressions(otherSortExprs.map(toExpr).asJava)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder))
}
}
@@ -469,7 +468,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
.setInput(plan.getRoot)
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
.addAllGroupingExpressions(groupingExprs)
- .addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava)
+ .addAllAggregateExpressions(columns.map(c => toTypedExpr(c, vEncoder)).asJava)
}
}
@@ -534,7 +533,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
- udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
+ toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}
private def getUdf[U: Encoder, S: Encoder](
@@ -549,7 +548,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
- udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
+ toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}
/**
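A hedged usage sketch (the case class, grouping key, and session are hypothetical); each typed aggregate column passed to `agg` is now converted with `toTypedExpr` against the value encoder rather than through the removed implicit:

    import org.apache.spark.sql.functions.count

    case class Event(user: String, bytes: Long)
    // Assuming an existing Connect session `spark` with its implicits in scope.
    import spark.implicits._

    val ds = Seq(Event("a", 10L), Event("a", 5L), Event("b", 1L)).toDS()
    ds.groupByKey(_.user)
      .agg(count("*").as[Long])   // each TypedColumn is converted via toTypedExpr(c, vEncoder)
      .show()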
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
index c245a8644a3cb..66354e63ca8af 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
@@ -24,6 +24,7 @@ import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCo
import org.apache.spark.connect.proto.MergeAction.ActionType._
import org.apache.spark.sql
import org.apache.spark.sql.Column
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
import org.apache.spark.sql.functions.expr
/**
@@ -44,13 +45,12 @@ import org.apache.spark.sql.functions.expr
@Experimental
class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column)
extends sql.MergeIntoWriter[T] {
- import ds.sparkSession.RichColumn
private val builder = MergeIntoTableCommand
.newBuilder()
.setTargetTableName(table)
.setSourceTablePlan(ds.plan.getRoot)
- .setMergeCondition(on.expr)
+ .setMergeCondition(toExpr(on))
/**
* Executes the merge operation.
@@ -121,12 +121,12 @@ class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column
condition: Option[Column],
assignments: Map[String, Column] = Map.empty): Expression = {
val builder = proto.MergeAction.newBuilder().setActionType(actionType)
- condition.foreach(c => builder.setCondition(c.expr))
+ condition.foreach(c => builder.setCondition(toExpr(c)))
assignments.foreach { case (k, v) =>
builder
.addAssignmentsBuilder()
- .setKey(expr(k).expr)
- .setValue(v.expr)
+ .setKey(toExpr(expr(k)))
+ .setValue(toExpr(v))
}
Expression
.newBuilder()
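A hedged usage sketch (table and column names are hypothetical); the merge condition, the per-action conditions, and each assignment below are the `Column`s that now flow through `toExpr` when the `MergeIntoTableCommand` proto is assembled:

    import org.apache.spark.sql.functions.col

    spark.table("source")
      .mergeInto("target", col("target.id") === col("source.id"))
      .whenMatched(col("source.deleted") === true)
      .delete()
      .whenMatched()
      .update(Map("value" -> col("source.value")))
      .whenNotMatched()
      .insertAll()
      .merge()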
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
index 00dc1fb6906f7..ac361047bbd08 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
@@ -23,6 +23,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.{functions, Column, Encoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
/**
@@ -44,14 +45,13 @@ class RelationalGroupedDataset private[sql] (
pivot: Option[proto.Aggregate.Pivot] = None,
groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None)
extends sql.RelationalGroupedDataset {
- import df.sparkSession.RichColumn
protected def toDF(aggExprs: Seq[Column]): DataFrame = {
df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder =>
val aggBuilder = builder.getAggregateBuilder
.setInput(df.plan.getRoot)
- groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
- aggExprs.foreach(c => aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder)))
+ groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(toExpr(c)))
+ aggExprs.foreach(c => aggBuilder.addAggregateExpressions(toTypedExpr(c, df.encoder)))
groupType match {
case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
@@ -152,10 +152,13 @@ class RelationalGroupedDataset private[sql] (
groupType match {
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
val valueExprs = values.map {
- case c: Column if c.expr.hasLiteral => c.expr.getLiteral
- case c: Column if !c.expr.hasLiteral =>
- throw new IllegalArgumentException("values only accept literal Column")
- case v => functions.lit(v).expr.getLiteral
+ case c: Column =>
+ val e = toExpr(c)
+ if (!e.hasLiteral) {
+ throw new IllegalArgumentException("values only accept literal Column")
+ }
+ e.getLiteral
+ case v => toExpr(functions.lit(v)).getLiteral
}
new RelationalGroupedDataset(
df,
@@ -164,7 +167,7 @@ class RelationalGroupedDataset private[sql] (
Some(
proto.Aggregate.Pivot
.newBuilder()
- .setCol(pivotColumn.expr)
+ .setCol(toExpr(pivotColumn))
.addAllValues(valueExprs.asJava)
.build()))
case _ =>
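A hedged usage sketch (the DataFrame and columns are hypothetical); pivot values must resolve to literal expressions, and plain values are wrapped with `functions.lit` before conversion, so any non-literal `Column` now fails with the `IllegalArgumentException` above:

    import org.apache.spark.sql.functions.{col, sum}

    // OK: plain values and literal Columns as pivot values
    df.groupBy(col("year")).pivot(col("course"), Seq("Java", "Scala")).agg(sum(col("earnings")))
    // Not OK: a non-literal Column such as col("other") as a pivot value
    // df.groupBy(col("year")).pivot(col("course"), Seq(col("other"))).agg(...)  // IllegalArgumentException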
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index f7998cf60ecac..032ab670dab0e 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -43,11 +43,10 @@ import org.apache.spark.sql.{Column, Encoder, ExperimentalMethods, Observation,
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder}
-import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -213,7 +212,7 @@ class SparkSession private[sql] (
val sqlCommand = proto.SqlCommand
.newBuilder()
.setSql(sqlText)
- .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)
+ .addAllPosArguments(args.map(a => toLiteral(a)).toImmutableArraySeq.asJava)
.build()
sql(sqlCommand)
}
@@ -228,7 +227,7 @@ class SparkSession private[sql] (
val sqlCommand = proto.SqlCommand
.newBuilder()
.setSql(sqlText)
- .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava)
+ .putAllNamedArguments(args.asScala.map { case (k, v) => (k, toLiteral(v)) }.asJava)
.build()
sql(sqlCommand)
}
@@ -653,11 +652,6 @@ class SparkSession private[sql] (
}
override private[sql] def isUsable: Boolean = client.isSessionValid
-
- implicit class RichColumn(c: Column) {
- def expr: proto.Expression = toExpr(c)
- def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
- }
}
// The minimal builder needed to create a spark session.
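A hedged usage sketch (assuming an existing session `spark`); both parameterized forms below now convert their argument values with `ColumnNodeToProtoConverter.toLiteral` instead of building a full `lit(...).expr` Column:

    // positional parameters
    spark.sql("SELECT * FROM range(10) WHERE id > ?", Array(5))
    // named parameters
    spark.sql("SELECT * FROM range(10) WHERE id > :lower", Map("lower" -> 5))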
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
index f44ec5b2d5046..f08b6e709f13a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
@@ -24,7 +24,7 @@ import org.apache.spark.connect.proto.Expression
import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST}
import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING}
import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType}
-import org.apache.spark.sql.{Column, Encoder}
+import org.apache.spark.sql.{functions, Column, Encoder}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
@@ -37,6 +37,8 @@ import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode
object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
def toExpr(column: Column): proto.Expression = apply(column.node, None)
+ def toLiteral(v: Any): proto.Expression = apply(functions.lit(v).node, None)
+
def toTypedExpr[I](column: Column, encoder: Encoder[I]): proto.Expression = {
apply(column.node, Option(encoder))
}
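For reference, a minimal sketch of the new `toLiteral` entry point next to the existing converter (values are arbitrary):

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.connect.ColumnNodeToProtoConverter
    import org.apache.spark.sql.functions.col

    val asExpr: proto.Expression = ColumnNodeToProtoConverter.toExpr(col("a") + 1)
    val asLiteral: proto.Expression = ColumnNodeToProtoConverter.toLiteral("hello") // functions.lit("hello").node internally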
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.json b/sql/connect/common/src/test/resources/query-tests/queries/hint.json
index 2ac930c0a3a71..2348d0f847157 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/hint.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/hint.json
@@ -22,13 +22,13 @@
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.connect.Dataset",
- "methodName": "~~trimmed~anonfun~~",
+ "methodName": "hint",
"fileName": "Dataset.scala"
}, {
"classLoaderName": "app",
- "declaringClass": "org.apache.spark.sql.connect.SparkSession",
- "methodName": "newDataset",
- "fileName": "SparkSession.scala"
+ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
}]
}
}
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin
index 06459ee5b765c..ce7c63b57c47b 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin differ
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 814a28e24f522..1d83a46a278f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -29,11 +29,13 @@ import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
import org.apache.avro.util.Utf8
+import org.apache.spark.SparkRuntimeException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._
@@ -282,11 +284,20 @@ private[sql] class AvroSerializer(
}.toArray.unzip
val numFields = catalystStruct.length
+ val avroFields = avroStruct.getFields()
+ val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable)
row: InternalRow =>
val result = new Record(avroStruct)
var i = 0
while (i < numFields) {
if (row.isNullAt(i)) {
+ if (!isSchemaNullable(i)) {
+ throw new SparkRuntimeException(
+ errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
+ messageParameters = Map(
+ "name" -> toSQLId(avroFields.get(i).name),
+ "dataType" -> avroFields.get(i).schema().toString))
+ }
result.put(avroIndices(i), null)
} else {
result.put(avroIndices(i), fieldConverters(i).apply(row, i))
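A hedged end-to-end sketch of the new check (paths, column names, and the session are hypothetical); writing a null into a field whose Avro schema is non-nullable now fails with AVRO_CANNOT_WRITE_NULL_FIELD, and the fix suggested by the error message is to declare the field as a union with "null":

    // Assuming an existing `spark` session with its implicits in scope.
    import spark.implicits._

    val nonNullSchema =
      """{"type":"record","name":"r","fields":[{"name":"id","type":"long"}]}"""
    Seq(Option(1L), None).toDF("id")
      .write.format("avro")
      .option("avroSchema", nonNullSchema)   // "id" is non-null -> SparkRuntimeException for the None row
      .save("/tmp/avro-out")

    // Workaround: make the field nullable in the supplied schema.
    val nullableSchema =
      """{"type":"record","name":"r","fields":[{"name":"id","type":["long","null"]}]}"""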
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index f33e64c859fb3..e818cc915951b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -207,7 +207,7 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe
val offsetClause = dialect.getOffsetClause(offset)
options.prepareQuery +
- s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
+ s"SELECT $columnList FROM ${options.tableOrQuery}" +
s" $whereClause $groupByClause $orderByClause $offsetClause $limitClause"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 77d0891ce338c..33fb93b168f9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -250,7 +250,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
options.prepareQuery +
s"SELECT $limitClause $columnList FROM ${options.tableOrQuery}" +
- s" $tableSampleClause $whereClause $groupByClause $orderByClause"
+ s" $whereClause $groupByClause $orderByClause"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index d11bf14be6546..e153d38a00cfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -344,9 +344,9 @@ class CompoundBodyExec(
/**
* Executable node for IfElseStatement.
* @param conditions Collection of executable conditions. First condition corresponds to IF clause,
- * while others (if any) correspond to following ELSE IF clauses.
+ * while others (if any) correspond to the following ELSEIF clauses.

* @param conditionalBodies Collection of executable bodies that have a corresponding condition,
-* in IF or ELSE IF branches.
+* in IF or ELSEIF branches.
* @param elseBody Body that is executed if none of the conditions are met,
* i.e. ELSE branch.
* @param session Spark session that SQL script is executed within.
@@ -380,7 +380,7 @@ class IfElseStatementExec(
} else {
clauseIdx += 1
if (clauseIdx < conditionsCount) {
- // There are ELSE IF clauses remaining.
+ // There are ELSEIF clauses remaining.
state = IfElseState.Condition
curr = Some(conditions(clauseIdx))
} else if (elseBody.isDefined) {
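A hedged sketch of the ELSEIF form these nodes execute (assuming SQL scripting is enabled in the session); only the first branch whose condition evaluates to true runs:

    spark.sql(
      """
        |BEGIN
        |  IF 1 = 2 THEN
        |    SELECT 42;
        |  ELSEIF 1 = 1 THEN
        |    SELECT 43;
        |  ELSE
        |    SELECT 44;
        |  END IF;
        |END
        |""".stripMargin)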
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out
index 521b0afe19264..84a35c270b698 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out
@@ -104,6 +104,7 @@ DO false
DOUBLE false
DROP false
ELSE true
+ELSEIF false
END true
ESCAPE true
ESCAPED false
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index 4d702588ad2b3..49ea8bba3e174 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -104,6 +104,7 @@ DO false
DOUBLE false
DROP false
ELSE false
+ELSEIF false
END false
ESCAPE false
ESCAPED false
diff --git a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out
index 4d702588ad2b3..49ea8bba3e174 100644
--- a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out
@@ -104,6 +104,7 @@ DO false
DOUBLE false
DROP false
ELSE false
+ELSEIF false
END false
ESCAPE false
ESCAPED false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index db56da80fd4af..a1d83ee665088 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -1107,7 +1107,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
.withLimit(123)
.build()
.trim() ==
- "SELECT a,b FROM test FETCH FIRST 123 ROWS ONLY")
+ "SELECT a,b FROM test FETCH FIRST 123 ROWS ONLY")
}
test("table exists query by jdbc dialect") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index 5b5285ea13275..503d38d61c7ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -218,14 +218,14 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(commands, expected)
}
- test("if else if going in else if") {
+ test("if elseif going in elseif") {
val commands =
"""
|BEGIN
| IF 1=2
| THEN
| SELECT 42;
- | ELSE IF 1=1
+ | ELSEIF 1=1
| THEN
| SELECT 43;
| ELSE
@@ -253,14 +253,14 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(commands, expected)
}
- test("if else if going in else") {
+ test("if elseif going in else") {
val commands =
"""
|BEGIN
| IF 1=2
| THEN
| SELECT 42;
- | ELSE IF 1=3
+ | ELSEIF 1=3
| THEN
| SELECT 43;
| ELSE
@@ -292,7 +292,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
}
}
- test("if else if with count") {
+ test("if elseif with count") {
withTable("t") {
val commands =
"""
@@ -302,7 +302,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
| INSERT INTO t VALUES (1, 'a', 1.0);
| IF (SELECT COUNT(*) > 2 FROM t) THEN
| SELECT 42;
- | ELSE IF (SELECT COUNT(*) > 1 FROM t) THEN
+ | ELSEIF (SELECT COUNT(*) > 1 FROM t) THEN
| SELECT 43;
| ELSE
| SELECT 44;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 601548a2e6bd6..5f149d15c6e8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -370,14 +370,14 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(commands, expected)
}
- test("if else if going in else if") {
+ test("if elseif going in elseif") {
val commands =
"""
|BEGIN
| IF 1=2
| THEN
| SELECT 42;
- | ELSE IF 1=1
+ | ELSEIF 1=1
| THEN
| SELECT 43;
| ELSE
@@ -407,14 +407,14 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(commands, expected)
}
- test("if else if going in else") {
+ test("if elseif going in else") {
val commands =
"""
|BEGIN
| IF 1=2
| THEN
| SELECT 42;
- | ELSE IF 1=3
+ | ELSEIF 1=3
| THEN
| SELECT 43;
| ELSE
@@ -448,7 +448,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
}
}
- test("if else if with count") {
+ test("if elseif with count") {
withTable("t") {
val commands =
"""
@@ -458,7 +458,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| INSERT INTO t VALUES (1, 'a', 1.0);
| IF (SELECT COUNT(*) > 2 FROM t) THEN
| SELECT 42;
- | ELSE IF (SELECT COUNT(*) > 1 FROM t) THEN
+ | ELSEIF (SELECT COUNT(*) > 1 FROM t) THEN
| SELECT 43;
| ELSE
| SELECT 44;
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 135b84cd01f85..e57fa5a235420 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -148,6 +148,16 @@
byte-buddy-agent
test
+
+ ${hive.group}
+ hive-llap-common
+ ${hive.llap.scope}
+
+
+ ${hive.group}
+ hive-llap-client
+ ${hive.llap.scope}
+
net.sf.jpam
jpam
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 254eda69e86e8..ec65886fb2c98 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSIVE,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE")
+ assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,ELSEIF,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSIVE,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE")
// scalastyle:on line.size.limit
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index 3d89f31e1965b..402d3c4ae7d92 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -55,7 +55,7 @@ private[hive] trait HiveClient {
/**
* Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will
- * result in one string.
+ * result in one string. This should be used only in a testing environment.
*/
def runSqlHive(sql: String): Seq[String]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 74f938e181793..3e7e81d25d943 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -49,6 +49,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkThrowable}
import org.apache.spark.deploy.SparkHadoopUtil.SOURCE_SPARK
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys._
+import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, NoSuchDatabaseException, NoSuchPartitionException, NoSuchPartitionsException, NoSuchTableException, PartitionsAlreadyExistException}
@@ -858,8 +859,10 @@ private[hive] class HiveClientImpl(
/**
* Runs the specified SQL query using Hive.
+ * This should be used only in a testing environment.
*/
override def runSqlHive(sql: String): Seq[String] = {
+ assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true")
val maxResults = 100000
val results = runHive(sql, maxResults)
// It is very confusing when you only get back some of the results...
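A hedged illustration (the command text is arbitrary and `hiveClient` is a hypothetical HiveClientImpl instance): outside a test JVM, where `Utils.isTesting` is false because `spark.testing` is unset, the call below now fails the added assertion before reaching Hive:

    hiveClient.runSqlHive("SET hive.exec.dynamic.partition.mode=nonstrict")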