diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index a5854d96a4d1a..6e2e5709bbd18 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -28,7 +28,7 @@ on: description: 'list of branches to publish (JSON)' required: true # keep in sync with default value of strategy matrix 'branch' - default: '["master", "branch-3.5"]' + default: '["master", "branch-4.0", "branch-3.5"]' jobs: publish-snapshot: @@ -38,7 +38,7 @@ jobs: fail-fast: false matrix: # keep in sync with default value of workflow_dispatch input 'branch' - branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5"]' ) }} + branch: ${{ fromJSON( inputs.branch || '["master", "branch-4.0", "branch-3.5"]' ) }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 diff --git a/LICENSE-binary b/LICENSE-binary index 2892c4b0ecce8..d2eea83525caf 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -296,7 +296,6 @@ jakarta.inject:jakarta.inject-api jakarta.validation:jakarta.validation-api javax.jdo:jdo-api joda-time:joda-time -net.java.dev.jna:jna net.sf.opencsv:opencsv net.sf.supercsv:super-csv net.sf.jpam:jpam diff --git a/assembly/pom.xml b/assembly/pom.xml index df0740e6c6949..525f3b2569bb3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -323,6 +323,7 @@ hive-provided provided + provided provided diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3bac6638f7060..92da13df5ff13 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -117,6 +117,13 @@ }, "sqlState" : "42604" }, + "AVRO_CANNOT_WRITE_NULL_FIELD" : { + "message" : [ + "Cannot write null value for field defined as non-null Avro data type .", + "To allow null value for this field, specify its avro schema as a union type with \"null\" using `avroSchema` option." 
+ ], + "sqlState" : "22004" + }, "AVRO_INCOMPATIBLE_READ_TYPE" : { "message" : [ "Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which would lead to an incorrect answer.", diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index c3a1af68d1c82..b53b89ad68c5d 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -549,6 +549,7 @@ private[spark] object LogKeys { case object NUM_ROWS extends LogKey case object NUM_RULE_OF_RUNS extends LogKey case object NUM_SEQUENCES extends LogKey + case object NUM_SKIPPED extends LogKey case object NUM_SLOTS extends LogKey case object NUM_SPILLS extends LogKey case object NUM_SPILL_WRITERS extends LogKey diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 4b60cb20f0732..110c5f0934286 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -17,6 +17,7 @@ package org.apache.spark.internal +import scala.concurrent.duration._ import scala.jdk.CollectionConverters._ import org.apache.logging.log4j.{CloseableThreadContext, Level, LogManager} @@ -27,6 +28,7 @@ import org.apache.logging.log4j.core.filter.AbstractFilter import org.slf4j.{Logger, LoggerFactory} import org.apache.spark.internal.Logging.SparkShellLoggingFilter +import org.apache.spark.internal.LogKeys import org.apache.spark.util.SparkClassUtils /** @@ -531,3 +533,158 @@ private[spark] object Logging { override def isStopped: Boolean = status == LifeCycle.State.STOPPED } } + +/** + * A thread-safe token bucket-based throttler implementation with nanosecond accuracy. + * + * Each instance must be shared across all scopes it should throttle. + * For global throttling that means either by extending this class in an `object` or + * by creating the instance as a field of an `object`. + * + * @param bucketSize This corresponds to the largest possible burst without throttling, + * in number of executions. + * @param tokenRecoveryInterval Time between two tokens being added back to the bucket. + * This is reciprocal of the long-term average unthrottled rate. + * + * Example: With a bucket size of 100 and a recovery interval of 1s, we could log up to 100 events + * in under a second without throttling, but at that point the bucket is exhausted and we only + * regain the ability to log more events at 1 event per second. If we log less than 1 event/s + * the bucket will slowly refill until it's back at 100. + * Either way, we can always log at least 1 event/s. + */ +class LogThrottler( + val bucketSize: Int = 100, + val tokenRecoveryInterval: FiniteDuration = 1.second, + val timeSource: NanoTimeTimeSource = SystemNanoTimeSource) extends Logging { + + private var remainingTokens = bucketSize + private var nextRecovery: DeadlineWithTimeSource = + DeadlineWithTimeSource.now(timeSource) + tokenRecoveryInterval + private var numSkipped: Long = 0 + + /** + * Run `thunk` as long as there are tokens remaining in the bucket, + * otherwise skip and remember number of skips. + * + * The argument to `thunk` is how many previous invocations have been skipped since the last time + * an invocation actually ran. 
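+ *
+ * For example, a typical call site looks like this (illustrative sketch; the throttler
+ * instance is assumed to be shared, e.g. held in a companion object):
+ * {{{
+ *   throttler.throttled { numSkipped =>
+ *     logWarning(s"Something noteworthy happened ($numSkipped similar messages were skipped).")
+ *   }
+ * }}}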
+ * + * Note: This method is `synchronized`, so it is concurrency safe. + * However, that also means no heavy-lifting should be done as part of this + * if the throttler is shared between concurrent threads. + * This also means that the synchronized block of the `thunk` that *does* execute will still + * hold up concurrent `thunk`s that will actually get rejected once they hold the lock. + * This is fine at low concurrency/low recovery rates. But if we need this to be more efficient at + * some point, we will need to decouple the check from the `thunk` execution. + */ + def throttled(thunk: Long => Unit): Unit = this.synchronized { + tryRecoverTokens() + if (remainingTokens > 0) { + thunk(numSkipped) + numSkipped = 0 + remainingTokens -= 1 + } else { + numSkipped += 1L + } + } + + /** + * Same as [[throttled]] but turns the number of skipped invocations into a logging message + * that can be appended to item being logged in `thunk`. + */ + def throttledWithSkippedLogMessage(thunk: MessageWithContext => Unit): Unit = { + this.throttled { numSkipped => + val skippedStr = if (numSkipped != 0L) { + log"[${MDC(LogKeys.NUM_SKIPPED, numSkipped)} similar messages were skipped.]" + } else { + log"" + } + thunk(skippedStr) + } + } + + /** + * Try to recover tokens, if the rate allows. + * + * Only call from within a `this.synchronized` block! + */ + private[spark] def tryRecoverTokens(): Unit = { + try { + // Doing it one-by-one is a bit inefficient for long periods, but it's easy to avoid jumps + // and rounding errors this way. The inefficiency shouldn't matter as long as the bucketSize + // isn't huge. + while (remainingTokens < bucketSize && nextRecovery.isOverdue()) { + remainingTokens += 1 + nextRecovery += tokenRecoveryInterval + } + + val currentTime = DeadlineWithTimeSource.now(timeSource) + if (remainingTokens == bucketSize && + (currentTime - nextRecovery) > tokenRecoveryInterval) { + // Reset the recovery time, so we don't accumulate infinite recovery while nothing is + // going on. + nextRecovery = currentTime + tokenRecoveryInterval + } + } catch { + case _: IllegalArgumentException => + // Adding FiniteDuration throws IllegalArgumentException instead of wrapping on overflow. + // Given that this happens every ~300 years, we can afford some non-linearity here, + // rather than taking the effort to properly work around that. + nextRecovery = DeadlineWithTimeSource(Duration(-Long.MaxValue, NANOSECONDS), timeSource) + } + } + + /** + * Resets throttler state to initial state. + * Visible for testing. + */ + def reset(): Unit = this.synchronized { + remainingTokens = bucketSize + nextRecovery = DeadlineWithTimeSource.now(timeSource) + tokenRecoveryInterval + numSkipped = 0 + } +} + +/** + * This is essentially the same as Scala's [[Deadline]], + * just with a custom source of nanoTime so it can actually be tested properly. + */ +case class DeadlineWithTimeSource( + time: FiniteDuration, + timeSource: NanoTimeTimeSource = SystemNanoTimeSource) { + // Only implemented the methods LogThrottler actually needs for now. + + /** + * Return a deadline advanced (i.e., moved into the future) by the given duration. + */ + def +(other: FiniteDuration): DeadlineWithTimeSource = copy(time = time + other) + + /** + * Calculate time difference between this and the other deadline, where the result is directed + * (i.e., may be negative). 
+ */ + def -(other: DeadlineWithTimeSource): FiniteDuration = time - other.time + + /** + * Determine whether the deadline lies in the past at the point where this method is called. + */ + def isOverdue(): Boolean = (time.toNanos - timeSource.nanoTime()) <= 0 +} + +object DeadlineWithTimeSource { + /** + * Construct a deadline due exactly at the point where this method is called. Useful for then + * advancing it to obtain a future deadline, or for sampling the current time exactly once and + * then comparing it to multiple deadlines (using subtraction). + */ + def now(timeSource: NanoTimeTimeSource = SystemNanoTimeSource): DeadlineWithTimeSource = + DeadlineWithTimeSource(Duration(timeSource.nanoTime(), NANOSECONDS), timeSource) +} + +/** Generalisation of [[System.nanoTime()]]. */ +private[spark] trait NanoTimeTimeSource { + def nanoTime(): Long +} +private[spark] object SystemNanoTimeSource extends NanoTimeTimeSource { + override def nanoTime(): Long = System.nanoTime() +} diff --git a/common/utils/src/test/scala/org/apache/spark/util/LogThrottlingSuite.scala b/common/utils/src/test/scala/org/apache/spark/util/LogThrottlingSuite.scala new file mode 100644 index 0000000000000..47d7f85093a38 --- /dev/null +++ b/common/utils/src/test/scala/org/apache/spark/util/LogThrottlingSuite.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.concurrent.duration._ + +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite + +import org.apache.spark.internal.{DeadlineWithTimeSource, Logging, LogThrottler, NanoTimeTimeSource} + + +class LogThrottlingSuite + extends AnyFunSuite // scalastyle:ignore funsuite + with Logging { + + // Make sure that the helper works right. + test("time control") { + val nanoTimeControl = new MockedNanoTime + assert(nanoTimeControl.nanoTime() === 0L) + assert(nanoTimeControl.nanoTime() === 0L) + + nanoTimeControl.advance(112L.nanos) + assert(nanoTimeControl.nanoTime() === 112L) + assert(nanoTimeControl.nanoTime() === 112L) + } + + test("deadline with time control") { + val nanoTimeControl = new MockedNanoTime + assert(DeadlineWithTimeSource.now(nanoTimeControl).isOverdue()) + val deadline = DeadlineWithTimeSource.now(nanoTimeControl) + 5.nanos + assert(!deadline.isOverdue()) + nanoTimeControl.advance(5.nanos) + assert(deadline.isOverdue()) + nanoTimeControl.advance(5.nanos) + assert(deadline.isOverdue()) + // Check addition. 
+ assert(deadline + 0.nanos === deadline) + val increasedDeadline = deadline + 10.nanos + assert(!increasedDeadline.isOverdue()) + nanoTimeControl.advance(5.nanos) + assert(increasedDeadline.isOverdue()) + // Ensure that wrapping keeps throwing this exact exception, since we rely on it in + // LogThrottler.tryRecoverTokens + assertThrows[IllegalArgumentException] { + deadline + Long.MaxValue.nanos + } + // Check difference and ordering. + assert(deadline - deadline === 0.nanos) + assert(increasedDeadline - deadline === 10.nanos) + assert(increasedDeadline - deadline > 9.nanos) + assert(increasedDeadline - deadline < 11.nanos) + assert(deadline - increasedDeadline === -10.nanos) + assert(deadline - increasedDeadline < -9.nanos) + assert(deadline - increasedDeadline > -11.nanos) + } + + test("unthrottled, no burst") { + val nanotTimeControl = new MockedNanoTime + val throttler = new LogThrottler( + bucketSize = 1, + tokenRecoveryInterval = 5.nanos, + timeSource = nanotTimeControl) + val numInvocations = 100 + var timesExecuted = 0 + for (i <- 0 until numInvocations) { + throttler.throttled { skipped => + assert(skipped === 0L) + timesExecuted += 1 + } + nanotTimeControl.advance(5.nanos) + } + assert(timesExecuted === numInvocations) + } + + test("unthrottled, burst") { + val nanotTimeControl = new MockedNanoTime + val throttler = new LogThrottler( + bucketSize = 100, + tokenRecoveryInterval = 1000000.nanos, // Just to make it obvious that it's a large number. + timeSource = nanotTimeControl) + val numInvocations = 100 + var timesExecuted = 0 + for (_ <- 0 until numInvocations) { + throttler.throttled { skipped => + assert(skipped === 0L) + timesExecuted += 1 + } + nanotTimeControl.advance(5.nanos) + } + assert(timesExecuted === numInvocations) + } + + test("throttled, no burst") { + val nanoTimeControl = new MockedNanoTime + val throttler = new LogThrottler( + bucketSize = 1, + tokenRecoveryInterval = 5.nanos, + timeSource = nanoTimeControl) + val numInvocations = 100 + var timesExecuted = 0 + for (i <- 0 until numInvocations) { + throttler.throttled { skipped => + if (timesExecuted == 0) { + assert(skipped === 0L) + } else { + assert(skipped === 4L) + } + timesExecuted += 1 + } + nanoTimeControl.advance(1.nanos) + } + assert(timesExecuted === numInvocations / 5) + } + + test("throttled, single burst") { + val nanoTimeControl = new MockedNanoTime + val throttler = new LogThrottler( + bucketSize = 5, + tokenRecoveryInterval = 10.nanos, + timeSource = nanoTimeControl) + val numInvocations = 100 + var timesExecuted = 0 + for (i <- 0 until numInvocations) { + throttler.throttled { skipped => + if (i < 5) { + // First burst + assert(skipped === 0L) + } else if (i == 10) { + // First token recovery + assert(skipped === 5L) + } else { + // All other token recoveries + assert(skipped === 9L) + } + timesExecuted += 1 + } + nanoTimeControl.advance(1.nano) + } + // A burst of 5 and then 1 every 10ns/invocations. 
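+    // Executions happen at i = 0..4 (the initial burst) and then at i = 10, 20, ..., 90
+    // (one per recovered token), i.e. 5 + 9 = 14 in total.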
+ assert(timesExecuted === 5 + (numInvocations - 10) / 10) + } + + test("throttled, bursty") { + val nanoTimeControl = new MockedNanoTime + val throttler = new LogThrottler( + bucketSize = 5, + tokenRecoveryInterval = 10.nanos, + timeSource = nanoTimeControl) + val numBursts = 10 + val numInvocationsPerBurst = 10 + var timesExecuted = 0 + for (burst <- 0 until numBursts) { + for (i <- 0 until numInvocationsPerBurst) { + throttler.throttled { skipped => + if (i == 0 && burst != 0) { + // first after recovery + assert(skipped === 5L) + } else { + // either first burst, or post-recovery on every other burst. + assert(skipped === 0L) + } + timesExecuted += 1 + } + nanoTimeControl.advance(1.nano) + } + nanoTimeControl.advance(100.nanos) + } + // Bursts of 5. + assert(timesExecuted === 5 * numBursts) + } + + test("wraparound") { + val nanoTimeControl = new MockedNanoTime + val throttler = new LogThrottler( + bucketSize = 1, + tokenRecoveryInterval = 100.nanos, + timeSource = nanoTimeControl) + def executeThrottled(expectedSkipped: Long = 0L): Boolean = { + var executed = false + throttler.throttled { skipped => + assert(skipped === expectedSkipped) + executed = true + } + executed + } + assert(executeThrottled()) + assert(!executeThrottled()) + + // Move to 2 ns before wrapping. + nanoTimeControl.advance((Long.MaxValue - 1L).nanos) + assert(executeThrottled(expectedSkipped = 1L)) + assert(!executeThrottled()) + + nanoTimeControl.advance(1.nano) + assert(!executeThrottled()) + + // Wrapping + nanoTimeControl.advance(1.nano) + assert(!executeThrottled()) + + // Recover + nanoTimeControl.advance(100.nanos) + assert(executeThrottled(expectedSkipped = 3L)) + } +} + +/** + * Use a mocked object to replace calls to `System.nanoTime()` with a custom value that can be + * controlled by calling `advance(nanos)` on an instance of this class. 
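+ *
+ * A minimal usage sketch, mirroring the tests above:
+ * {{{
+ *   val clock = new MockedNanoTime
+ *   val throttler = new LogThrottler(
+ *     bucketSize = 1, tokenRecoveryInterval = 5.nanos, timeSource = clock)
+ *   clock.advance(5.nanos) // the next throttled() call can now recover a token
+ * }}}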
+ */ +class MockedNanoTime extends NanoTimeTimeSource { + private var currentTimeNs: Long = 0L + + override def nanoTime(): Long = currentTimeNs + + def advance(time: FiniteDuration): Unit = { + currentTimeNs += time.toNanos + } +} diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 8342ca4e84275..4ddf6503d99ec 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -25,7 +25,7 @@ import java.util.UUID import scala.jdk.CollectionConverters._ -import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, SchemaFormatter} +import org.apache.avro.{Schema, SchemaBuilder, SchemaFormatter} import org.apache.avro.Schema.{Field, Type} import org.apache.avro.Schema.Type._ import org.apache.avro.file.{DataFileReader, DataFileWriter} @@ -33,7 +33,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils -import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException} +import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkRuntimeException, SparkThrowable, SparkUpgradeException} import org.apache.spark.TestUtils.assertExceptionMsg import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.IntervalData @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone import org.apache.spark.sql.execution.{FormattedMode, SparkPlan} import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.LegacyBehaviorPolicy import org.apache.spark.sql.internal.LegacyBehaviorPolicy._ import org.apache.spark.sql.internal.SQLConf @@ -100,6 +100,14 @@ abstract class AvroSuite SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema) } + private def getRootCause(ex: Throwable): Throwable = { + var rootCause = ex + while (rootCause.getCause != null) { + rootCause = rootCause.getCause + } + rootCause + } + // Check whether an Avro schema of union type is converted to SQL in an expected way, when the // stable ID option is on. 
// @@ -1317,7 +1325,16 @@ abstract class AvroSuite dfWithNull.write.format("avro") .option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}") } - assertExceptionMsg[AvroTypeException](e1, "value null is not a SuitEnumType") + + val expectedDatatype = "{\"type\":\"enum\",\"name\":\"SuitEnumType\"," + + "\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}" + + checkError( + getRootCause(e1).asInstanceOf[SparkThrowable], + condition = "AVRO_CANNOT_WRITE_NULL_FIELD", + parameters = Map( + "name" -> "`Suit`", + "dataType" -> expectedDatatype)) // Writing df containing data not in the enum will throw an exception val e2 = intercept[SparkException] { @@ -1332,6 +1349,50 @@ abstract class AvroSuite } } + test("to_avro nested struct schema nullability mismatch") { + Seq((true, false), (false, true)).foreach { + case (innerNull, outerNull) => + val innerSchema = StructType(Seq(StructField("field1", IntegerType, innerNull))) + val outerSchema = StructType(Seq(StructField("innerStruct", innerSchema, outerNull))) + val nestedSchema = StructType(Seq(StructField("outerStruct", outerSchema, false))) + + val rowWithNull = if (innerNull) Row(Row(null)) else Row(null) + val data = Seq(Row(Row(Row(1))), Row(rowWithNull), Row(Row(Row(3)))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), nestedSchema) + + val avroTypeStruct = s"""{ + | "type": "record", + | "name": "outerStruct", + | "fields": [ + | { + | "name": "innerStruct", + | "type": { + | "type": "record", + | "name": "innerStruct", + | "fields": [ + | {"name": "field1", "type": "int"} + | ] + | } + | } + | ] + |} + """.stripMargin // nullability mismatch for innerStruct + + val expectedErrorName = if (outerNull) "`innerStruct`" else "`field1`" + val expectedErrorSchema = if (outerNull) "{\"type\":\"record\",\"name\":\"innerStruct\"" + + ",\"fields\":[{\"name\":\"field1\",\"type\":\"int\"}]}" else "\"int\"" + + checkError( + exception = intercept[SparkRuntimeException] { + df.select(avro.functions.to_avro($"outerStruct", avroTypeStruct)).collect() + }, + condition = "AVRO_CANNOT_WRITE_NULL_FIELD", + parameters = Map( + "name" -> expectedErrorName, + "dataType" -> expectedErrorSchema)) + } + } + test("support user provided avro schema for writing nullable fixed type") { withTempPath { tempDir => val avroSchema = @@ -1517,9 +1578,12 @@ abstract class AvroSuite .save(s"$tempDir/${UUID.randomUUID()}") } assert(ex.getCondition == "TASK_WRITE_FAILED") - assert(ex.getCause.isInstanceOf[java.lang.NullPointerException]) - assert(ex.getCause.getMessage.contains( - "null value for (non-nullable) string at test_schema.Name")) + checkError( + ex.getCause.asInstanceOf[SparkThrowable], + condition = "AVRO_CANNOT_WRITE_NULL_FIELD", + parameters = Map( + "name" -> "`Name`", + "dataType" -> "\"string\"")) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index fd7efb1efb764..04637c1b55631 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -171,7 +171,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD // scalastyle:off assert(getExternalEngineQuery(df.queryExecution.executedPlan) == - """SELECT 
"dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """ + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """ ) // scalastyle:on df.collect() @@ -186,7 +186,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD // scalastyle:off assert(getExternalEngineQuery(df.queryExecution.executedPlan) == - """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """ + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """ ) // scalastyle:on df.collect() @@ -203,7 +203,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD // scalastyle:off assert(getExternalEngineQuery(df.queryExecution.executedPlan) == - """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """ + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """ ) // scalastyle:on df.collect() @@ -220,7 +220,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD // scalastyle:off assert(getExternalEngineQuery(df.queryExecution.executedPlan) == - """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """ + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """ ) // scalastyle:on df.collect() diff --git a/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala b/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala index 3a813f4d8b53c..a7968b6f2b022 100644 --- a/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala +++ b/core/src/test/scala/org/apache/spark/LocalRootDirsTest.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.util.UUID -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} trait LocalRootDirsTest extends SparkFunSuite with LocalSparkContext { @@ -46,7 +46,13 @@ trait LocalRootDirsTest extends SparkFunSuite with LocalSparkContext { override def afterEach(): Unit = { try { - Utils.deleteRecursively(tempDir) + // SPARK-51030: Only perform manual cleanup of the `tempDir` if it has + // not been registered for cleanup via a shutdown hook, to avoid potential + // IOException due to race conditions during multithreaded cleanup. 
+ if (!ShutdownHookManager.hasShutdownDeleteDir(tempDir) && + !ShutdownHookManager.hasRootAsShutdownDeleteDir(tempDir)) { + Utils.deleteRecursively(tempDir) + } Utils.clearLocalRootDirs() } finally { super.afterEach() diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 27a53e8205201..e3171116a3e14 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -112,7 +112,7 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { bytes.toByteArray() }.collect()(0) - assert(content != encrypted) + assert(!content.getBytes(UTF_8).sameElements(encrypted)) val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted), sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get) diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala deleted file mode 100644 index c76c97d071418..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ui - -import scala.xml.Node - -import jakarta.servlet.http.HttpServletRequest -import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} - -import org.apache.spark._ -import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} -import org.apache.spark.internal.config.Status._ -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.scheduler._ -import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} -import org.apache.spark.ui.jobs.{StagePage, StagesTab} - -class StagePageSuite extends SparkFunSuite with LocalSparkContext { - - private val peakExecutionMemory = 10 - - test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") { - val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) - val statusStore = AppStatusStore.createLiveStore(conf) - try { - val stageData = new StageData( - status = StageStatus.ACTIVE, - stageId = 1, - attemptId = 1, - numTasks = 1, - numActiveTasks = 1, - numCompleteTasks = 1, - numFailedTasks = 1, - numKilledTasks = 1, - numCompletedIndices = 1, - - submissionTime = None, - firstTaskLaunchedTime = None, - completionTime = None, - failureReason = None, - - executorDeserializeTime = 1L, - executorDeserializeCpuTime = 1L, - executorRunTime = 1L, - executorCpuTime = 1L, - resultSize = 1L, - jvmGcTime = 1L, - resultSerializationTime = 1L, - memoryBytesSpilled = 1L, - diskBytesSpilled = 1L, - peakExecutionMemory = 1L, - inputBytes = 1L, - inputRecords = 1L, - outputBytes = 1L, - outputRecords = 1L, - shuffleRemoteBlocksFetched = 1L, - shuffleLocalBlocksFetched = 1L, - shuffleFetchWaitTime = 1L, - shuffleRemoteBytesRead = 1L, - shuffleRemoteBytesReadToDisk = 1L, - shuffleLocalBytesRead = 1L, - shuffleReadBytes = 1L, - shuffleReadRecords = 1L, - shuffleCorruptMergedBlockChunks = 1L, - shuffleMergedFetchFallbackCount = 1L, - shuffleMergedRemoteBlocksFetched = 1L, - shuffleMergedLocalBlocksFetched = 1L, - shuffleMergedRemoteChunksFetched = 1L, - shuffleMergedLocalChunksFetched = 1L, - shuffleMergedRemoteBytesRead = 1L, - shuffleMergedLocalBytesRead = 1L, - shuffleRemoteReqsDuration = 1L, - shuffleMergedRemoteReqsDuration = 1L, - shuffleWriteBytes = 1L, - shuffleWriteTime = 1L, - shuffleWriteRecords = 1L, - - name = "stage1", - description = Some("description"), - details = "detail", - schedulingPool = "pool1", - - rddIds = Seq(1), - accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), - tasks = None, - executorSummary = None, - speculationSummary = None, - killedTasksSummary = Map.empty, - ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, - peakExecutorMetrics = None, - taskMetricsDistributions = None, - executorMetricsDistributions = None, - isShufflePushEnabled = false, - shuffleMergersCount = 0 - ) - } finally { - statusStore.close() - } - } - - /** - * Render a stage page started with the given conf and return the HTML. - * This also runs a dummy stage to populate the page with useful content. 
- */ - private def renderStagePage(): Seq[Node] = { - val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) - val statusStore = AppStatusStore.createLiveStore(conf) - val listener = statusStore.listener.get - - try { - val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) - when(tab.store).thenReturn(statusStore) - - val request = mock(classOf[HttpServletRequest]) - when(tab.conf).thenReturn(conf) - when(tab.appName).thenReturn("testing") - when(tab.headerTabs).thenReturn(Seq.empty) - when(request.getParameter("id")).thenReturn("0") - when(request.getParameter("attempt")).thenReturn("0") - val page = new StagePage(tab, statusStore) - - // Simulate a stage in job progress listener - val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details", - resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) - // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness - (1 to 2).foreach { - taskId => - val taskInfo = new TaskInfo(taskId, taskId, 0, taskId, 0, - "0", "localhost", TaskLocality.ANY, false) - listener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) - listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) - val taskMetrics = TaskMetrics.empty - val executorMetrics = new ExecutorMetrics - taskMetrics.incPeakExecutionMemory(peakExecutionMemory) - listener.onTaskEnd(SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, - executorMetrics, taskMetrics)) - } - listener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) - page.render(request) - } finally { - statusStore.close() - } - } - -} diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 96d5f9d477143..c4228bc543ebb 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -86,7 +86,6 @@ hive-cli/2.3.10//hive-cli-2.3.10.jar hive-common/2.3.10//hive-common-2.3.10.jar hive-exec/2.3.10/core/hive-exec-2.3.10-core.jar hive-jdbc/2.3.10//hive-jdbc-2.3.10.jar -hive-llap-common/2.3.10//hive-llap-common-2.3.10.jar hive-metastore/2.3.10//hive-metastore-2.3.10.jar hive-serde/2.3.10//hive-serde-2.3.10.jar hive-service-rpc/4.0.0//hive-service-rpc-4.0.0.jar @@ -121,7 +120,7 @@ jakarta.validation-api/3.0.2//jakarta.validation-api-3.0.2.jar jakarta.ws.rs-api/3.0.0//jakarta.ws.rs-api-3.0.0.jar jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar janino/3.1.9//janino-3.1.9.jar -java-diff-utils/4.12//java-diff-utils-4.12.jar +java-diff-utils/4.15//java-diff-utils-4.15.jar java-xmlbuilder/1.2//java-xmlbuilder-1.2.jar javassist/3.30.2-GA//javassist-3.30.2-GA.jar javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar @@ -145,8 +144,7 @@ jjwt-api/0.12.6//jjwt-api-0.12.6.jar jjwt-impl/0.12.6//jjwt-impl-0.12.6.jar jjwt-jackson/0.12.6//jjwt-jackson-0.12.6.jar jline/2.14.6//jline-2.14.6.jar -jline/3.26.3//jline-3.26.3.jar -jna/5.14.0//jna-5.14.0.jar +jline/3.27.1/jdk8/jline-3.27.1-jdk8.jar joda-time/2.13.0//joda-time-2.13.0.jar jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar @@ -254,11 +252,11 @@ py4j/0.10.9.9//py4j-0.10.9.9.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar rocksdbjni/9.8.4//rocksdbjni-9.8.4.jar scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar -scala-compiler/2.13.15//scala-compiler-2.13.15.jar -scala-library/2.13.15//scala-library-2.13.15.jar +scala-compiler/2.13.16//scala-compiler-2.13.16.jar +scala-library/2.13.16//scala-library-2.13.16.jar 
scala-parallel-collections_2.13/1.2.0//scala-parallel-collections_2.13-1.2.0.jar scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar -scala-reflect/2.13.15//scala-reflect-2.13.15.jar +scala-reflect/2.13.16//scala-reflect-2.13.16.jar scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar slf4j-api/2.0.16//slf4j-api-2.0.16.jar snakeyaml-engine/2.8//snakeyaml-engine-2.8.jar diff --git a/docs/_config.yml b/docs/_config.yml index 67c4237392acd..7a28a7c76099d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -22,7 +22,7 @@ include: SPARK_VERSION: 4.1.0-SNAPSHOT SPARK_VERSION_SHORT: 4.1.0 SCALA_BINARY_VERSION: "2.13" -SCALA_VERSION: "2.13.15" +SCALA_VERSION: "2.13.16" SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark # Before a new release, we should: diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 254c54a414a7e..f459a88d8e148 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -31,6 +31,7 @@ license: | - Since Spark 4.0, any read of SQL tables takes into consideration the SQL configs `spark.sql.files.ignoreCorruptFiles`/`spark.sql.files.ignoreMissingFiles` instead of the core config `spark.files.ignoreCorruptFiles`/`spark.files.ignoreMissingFiles`. - Since Spark 4.0, when reading SQL tables hits `org.apache.hadoop.security.AccessControlException` and `org.apache.hadoop.hdfs.BlockMissingException`, the exception will be thrown and fail the task, even if `spark.sql.files.ignoreCorruptFiles` is set to `true`. - Since Spark 4.0, `spark.sql.hive.metastore` drops the support of Hive prior to 2.0.0 as they require JDK 8 that Spark does not support anymore. Users should migrate to higher versions. +- Since Spark 4.0, Spark removes `hive-llap-common` dependency. To restore the previous behavior, add `hive-llap-common` jar to the class path. - Since Spark 4.0, `spark.sql.parquet.compression.codec` drops the support of codec name `lz4raw`, please use `lz4_raw` instead. - Since Spark 4.0, when overflowing during casting timestamp to byte/short/int under non-ansi mode, Spark will return null instead a wrapping value. - Since Spark 4.0, the `encode()` and `decode()` functions support only the following charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'. To restore the previous behavior when the function accepts charsets of the current JDK used by Spark, set `spark.sql.legacy.javaCharsets` to `true`. diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 45b8e9a4dcea7..7c7e2a6909574 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -503,6 +503,7 @@ Below is a list of all the keywords in Spark SQL. |DOUBLE|non-reserved|non-reserved|reserved| |DROP|non-reserved|non-reserved|reserved| |ELSE|reserved|non-reserved|reserved| +|ELSEIF|non-reserved|non-reserved|non-reserved| |END|reserved|non-reserved|reserved| |ESCAPE|reserved|non-reserved|reserved| |ESCAPED|non-reserved|non-reserved|non-reserved| diff --git a/pom.xml b/pom.xml index d83d9ceaaaec5..94e2b8650bac5 100644 --- a/pom.xml +++ b/pom.xml @@ -169,7 +169,7 @@ 3.2.2 4.4 - 2.13.15 + 2.13.16 2.13 2.2.0 4.9.2 @@ -230,7 +230,7 @@ and ./python/packaging/connect/setup.py too. 
--> 18.1.0 - 3.0.0 + 3.0.1 0.12.6 @@ -274,7 +274,7 @@ compile compile compile - compile + test compile compile compile @@ -335,7 +335,7 @@ 12.8.1.jre11 23.6.0.24.10 2.7.1 - 3.21.0 + 3.22.0 20.00.00.39 ${project.version} @@ -2255,7 +2255,7 @@ ${hive.group} hive-llap-common ${hive.version} - ${hive.deps.scope} + ${hive.llap.scope} ${hive.group} @@ -2657,14 +2657,6 @@ ${java.version} test provided - - - org.jline.terminal.impl.ffm.* - diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 9b132f5d693fd..34299bdb7740c 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -15,6 +15,7 @@ # limitations under the License. # import os +import platform import tempfile import unittest from typing import Callable, Union @@ -508,6 +509,9 @@ def write(self, iterator): ): df.write.format("test").mode("append").saveAsTable("test_table") + @unittest.skipIf( + "pypy" in platform.python_implementation().lower(), "cannot run in environment pypy" + ) def test_data_source_segfault(self): import ctypes diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index ef029adfb476b..95a81236fbf52 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -15,6 +15,7 @@ # limitations under the License. # import os +import platform import shutil import tempfile import unittest @@ -2761,6 +2762,9 @@ def eval(self, n): res = self.spark.sql("select i, to_json(v['v1']) from test_udtf_struct(8)") assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n in range(8)]) + @unittest.skipIf( + "pypy" in platform.python_implementation().lower(), "cannot run in environment pypy" + ) def test_udtf_segfault(self): for enabled, expected in [ (True, "Segmentation fault"), diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index dafeed48aef11..360854d81e384 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -218,6 +218,7 @@ DO: 'DO'; DOUBLE: 'DOUBLE'; DROP: 'DROP'; ELSE: 'ELSE'; +ELSEIF: 'ELSEIF'; END: 'END'; ESCAPE: 'ESCAPE'; ESCAPED: 'ESCAPED'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 8e6edac2108a6..9b438a667a3e5 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -85,7 +85,7 @@ whileStatement ifElseStatement : IF booleanExpression THEN conditionalBodies+=compoundBody - (ELSE IF booleanExpression THEN conditionalBodies+=compoundBody)* + (ELSEIF booleanExpression THEN conditionalBodies+=compoundBody)* (ELSE elseBody=compoundBody)? 
END IF ; @@ -1642,6 +1642,7 @@ ansiNonReserved | DO | DOUBLE | DROP + | ELSEIF | ESCAPED | EVOLUTION | EXCHANGE @@ -1987,6 +1988,7 @@ nonReserved | DOUBLE | DROP | ELSE + | ELSEIF | END | ESCAPE | ESCAPED diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index 9fbfb9e679e58..23de9c222724b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -127,11 +127,13 @@ object CurrentOrigin { } } - private val sparkCodePattern = Pattern.compile("(org\\.apache\\.spark\\.sql\\." + - "(?:(classic|connect)\\.)?" + - "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)" + - "(?:|\\..*|\\$.*))" + - "|(scala\\.collection\\..*)") + private val sparkCodePattern = Pattern.compile( + "(org\\.apache\\.spark\\.sql\\." + + "(?:(classic|connect)\\.)?" + + "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder" + + "|SparkSession|ColumnNodeToProtoConverter)" + + "(?:|\\..*|\\$.*))" + + "|(scala\\.collection\\..*)") private def sparkCode(ste: StackTraceElement): Boolean = { sparkCodePattern.matcher(ste.getClassName).matches() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index ad00a5216b4c9..b3bd86149ba91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -81,9 +81,9 @@ case class CompoundBody( /** * Logical operator for IF ELSE statement. * @param conditions Collection of conditions. First condition corresponds to IF clause, - * while others (if any) correspond to following ELSE IF clauses. + * while others (if any) correspond to following ELSEIF clauses. * @param conditionalBodies Collection of bodies that have a corresponding condition, - * in IF or ELSE IF branches. + * in IF or ELSEIF branches. * @param elseBody Body that is executed if none of the conditions are met, * i.e. ELSE branch. 
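+ * For example:
+ * {{{
+ *   IF 1 = 1 THEN
+ *     SELECT 1;
+ *   ELSEIF 2 = 2 THEN
+ *     SELECT 2;
+ *   ELSE
+ *     SELECT 3;
+ *   END IF;
+ * }}}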
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index e129c6dbba052..40ba7809e5cee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -501,12 +501,12 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .getText == "SELECT 2") } - test("if else if") { + test("if elseif") { val sqlScriptText = """BEGIN |IF 1 = 1 THEN | SELECT 1; - |ELSE IF 2 = 2 THEN + |ELSEIF 2 = 2 THEN | SELECT 2; |ELSE | SELECT 3; @@ -541,14 +541,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .getText == "SELECT 3") } - test("if multi else if") { + test("if multi elseif") { val sqlScriptText = """BEGIN |IF 1 = 1 THEN | SELECT 1; - |ELSE IF 2 = 2 THEN + |ELSEIF 2 = 2 THEN | SELECT 2; - |ELSE IF 3 = 3 THEN + |ELSEIF 3 = 3 THEN | SELECT 3; |END IF; |END @@ -584,6 +584,87 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .getText == "SELECT 3") } + test("if - multi elseif - else nested") { + val sqlScriptText = + """BEGIN + |IF 1 = 1 THEN + | SELECT 1; + |ELSEIF 2 = 2 THEN + | SELECT 2; + |ELSE + | IF 3 = 3 THEN + | SELECT 3; + | ELSEIF 4 = 4 THEN + | SELECT 4; + | ELSE + | IF 5 = 5 THEN + | SELECT 5; + | ELSE + | SELECT 6; + | END IF; + | END IF; + |END IF; + |END + """.stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[IfElseStatement]) + + val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement] + assert(ifStmt.conditions.length == 2) + assert(ifStmt.conditionalBodies.length == 2) + assert(ifStmt.elseBody.nonEmpty) + + assert(ifStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditions.head.getText == "1 = 1") + + assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(ifStmt.conditions(1).isInstanceOf[SingleStatement]) + assert(ifStmt.conditions(1).getText == "2 = 2") + + assert(ifStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 2") + + assert(ifStmt.elseBody.get.collection.head.isInstanceOf[IfElseStatement]) + val nestedIf_1 = ifStmt.elseBody.get.collection.head.asInstanceOf[IfElseStatement] + + assert(nestedIf_1.conditions.length == 2) + assert(nestedIf_1.conditionalBodies.length == 2) + assert(nestedIf_1.elseBody.nonEmpty) + + + assert(nestedIf_1.conditions.head.isInstanceOf[SingleStatement]) + assert(nestedIf_1.conditions.head.getText == "3 = 3") + + assert(nestedIf_1.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(nestedIf_1.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 3") + + assert(nestedIf_1.conditions(1).isInstanceOf[SingleStatement]) + assert(nestedIf_1.conditions(1).getText == "4 = 4") + + assert(nestedIf_1.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement]) + assert(nestedIf_1.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 4") + + 
assert(nestedIf_1.elseBody.get.collection.head.isInstanceOf[IfElseStatement]) + val nestedIf_2 = nestedIf_1.elseBody.get.collection.head.asInstanceOf[IfElseStatement] + + assert(nestedIf_2.conditions.length == 1) + assert(nestedIf_2.conditionalBodies.length == 1) + assert(nestedIf_2.elseBody.nonEmpty) + + assert(nestedIf_2.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 5") + + assert(nestedIf_2.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 6") + } + test("if nested") { val sqlScriptText = """ diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala index 8f6c6ef07b3df..7b79387fbfde9 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala @@ -23,8 +23,8 @@ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement import org.apache.spark.sql +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral import org.apache.spark.sql.connect.ConnectConversions._ -import org.apache.spark.sql.functions /** * Functionality for working with missing data in `DataFrame`s. @@ -33,7 +33,6 @@ import org.apache.spark.sql.functions */ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) extends sql.DataFrameNaFunctions { - import sparkSession.RichColumn override protected def drop(minNonNulls: Option[Int]): DataFrame = buildDropDataFrame(None, minNonNulls) @@ -103,7 +102,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: sparkSession.newDataFrame { builder => val fillNaBuilder = builder.getFillNaBuilder.setInput(root) values.map { case (colName, replaceValue) => - fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral) + fillNaBuilder.addCols(colName).addValues(toLiteral(replaceValue).getLiteral) } } } @@ -143,8 +142,8 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: replacementMap.map { case (oldValue, newValue) => Replacement .newBuilder() - .setOldValue(functions.lit(oldValue).expr.getLiteral) - .setNewValue(functions.lit(newValue).expr.getLiteral) + .setOldValue(toLiteral(oldValue).getLiteral) + .setNewValue(toLiteral(newValue).getLiteral) .build() } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala index f3c3f82a233ae..a510afc716a77 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala @@ -23,9 +23,9 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder -import 
org.apache.spark.sql.functions.lit /** * Statistic functions for `DataFrame`s. @@ -120,20 +120,19 @@ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) /** @inheritdoc */ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { - import sparkSession.RichColumn require( fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") sparkSession.newDataFrame { builder => val sampleByBuilder = builder.getSampleByBuilder .setInput(root) - .setCol(col.expr) + .setCol(toExpr(col)) .setSeed(seed) fractions.foreach { case (k, v) => sampleByBuilder.addFractions( StatSampleBy.Fraction .newBuilder() - .setStratum(lit(k).expr.getLiteral) + .setStratum(toLiteral(k).getLiteral) .setFraction(v)) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala index 42cf2cdfad58a..06d339487bfb8 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.connect.proto import org.apache.spark.sql import org.apache.spark.sql.Column +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr /** * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 @@ -33,7 +34,6 @@ import org.apache.spark.sql.Column @Experimental final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) extends sql.DataFrameWriterV2[T] { - import ds.sparkSession.RichColumn private val builder = proto.WriteOperationV2 .newBuilder() @@ -73,7 +73,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) /** @inheritdoc */ @scala.annotation.varargs override def partitionedBy(column: Column, columns: Column*): this.type = { - builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava) + builder.addAllPartitioningColumns((column +: columns).map(toExpr).asJava) this } @@ -106,7 +106,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) /** @inheritdoc */ def overwrite(condition: Column): Unit = { - builder.setOverwriteCondition(condition.expr) + builder.setOverwriteCondition(toExpr(condition)) executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala index 36003283a3369..419ac3b7f74ae 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter} @@ -140,7 +141,6 @@ class Dataset[T] private[sql] ( @DeveloperApi val plan: proto.Plan, val 
encoder: Encoder[T]) extends sql.Dataset[T] { - import sparkSession.RichColumn // Make sure we don't forget to set plan id. assert(plan.getRoot.getCommon.hasPlanId) @@ -336,7 +336,7 @@ class Dataset[T] private[sql] ( buildJoin(right, Seq(joinExprs)) { builder => builder .setJoinType(toJoinType(joinType)) - .setJoinCondition(joinExprs.expr) + .setJoinCondition(toExpr(joinExprs)) } } @@ -375,7 +375,7 @@ class Dataset[T] private[sql] ( .setLeft(plan.getRoot) .setRight(other.plan.getRoot) .setJoinType(joinTypeValue) - .setJoinCondition(condition.expr) + .setJoinCondition(toExpr(condition)) .setJoinDataType(joinBuilder.getJoinDataTypeBuilder .setIsLeftStruct(this.agnosticEncoder.isStruct) .setIsRightStruct(other.agnosticEncoder.isStruct)) @@ -396,7 +396,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataFrame(joinExprs.toSeq) { builder => val lateralJoinBuilder = builder.getLateralJoinBuilder lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot) - joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr)) + joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(toExpr(c))) lateralJoinBuilder.setJoinType(joinTypeValue) } } @@ -440,7 +440,7 @@ class Dataset[T] private[sql] ( builder.getHintBuilder .setInput(plan.getRoot) .setName(name) - .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava) + .addAllParameters(parameters.map(p => toLiteral(p)).asJava) } private def getPlanId: Option[Long] = @@ -486,7 +486,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataset(encoder) { builder => builder.getProjectBuilder .setInput(plan.getRoot) - .addExpressions(col.typedExpr(this.encoder)) + .addExpressions(toTypedExpr(col, this.encoder)) } } @@ -504,14 +504,14 @@ class Dataset[T] private[sql] ( sparkSession.newDataset(encoder, cols) { builder => builder.getProjectBuilder .setInput(plan.getRoot) - .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava) + .addAllExpressions(cols.map(c => toTypedExpr(c, this.encoder)).asJava) } } /** @inheritdoc */ def filter(condition: Column): Dataset[T] = { sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder => - builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr) + builder.getFilterBuilder.setInput(plan.getRoot).setCondition(toExpr(condition)) } } @@ -523,12 +523,12 @@ class Dataset[T] private[sql] ( sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { builder => val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) - .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava) + .addAllIds(ids.toImmutableArraySeq.map(toExpr).asJava) .setVariableColumnName(variableColumnName) .setValueColumnName(valueColumnName) valuesOption.foreach { values => unpivot.getValuesBuilder - .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava) + .addAllValues(values.toImmutableArraySeq.map(toExpr).asJava) } } } @@ -537,7 +537,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataFrame(indices) { builder => val transpose = builder.getTransposeBuilder.setInput(plan.getRoot) indices.foreach { indexColumn => - transpose.addIndexColumns(indexColumn.expr) + transpose.addIndexColumns(toExpr(indexColumn)) } } @@ -553,7 +553,7 @@ class Dataset[T] private[sql] ( function = func, inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil, outputEncoder = agnosticEncoder) - val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr + val reduceExpr = toExpr(Column.fn("reduce", udf.apply(col("*"), col("*")))) val result = sparkSession .newDataset(agnosticEncoder) 
{ builder => @@ -590,7 +590,7 @@ class Dataset[T] private[sql] ( val groupingSetMsgs = groupingSets.map { groupingSet => val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder() for (groupCol <- groupingSet) { - groupingSetMsg.addGroupingSet(groupCol.expr) + groupingSetMsg.addGroupingSet(toExpr(groupCol)) } groupingSetMsg.build() } @@ -779,7 +779,7 @@ class Dataset[T] private[sql] ( s"The size of column names: ${names.size} isn't equal to " + s"the size of columns: ${values.size}") val aliases = values.zip(names).map { case (value, name) => - value.name(name).expr.getAlias + toExpr(value.name(name)).getAlias } sparkSession.newDataFrame(values) { builder => builder.getWithColumnsBuilder @@ -812,7 +812,7 @@ class Dataset[T] private[sql] ( def withMetadata(columnName: String, metadata: Metadata): DataFrame = { val newAlias = proto.Expression.Alias .newBuilder() - .setExpr(col(columnName).expr) + .setExpr(toExpr(col(columnName))) .addName(columnName) .setMetadata(metadata.json) sparkSession.newDataFrame { builder => @@ -845,7 +845,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataFrame(cols) { builder => builder.getDropBuilder .setInput(plan.getRoot) - .addAllColumns(cols.map(_.expr).asJava) + .addAllColumns(cols.map(toExpr).asJava) } } @@ -915,7 +915,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataset[T](agnosticEncoder) { builder => builder.getFilterBuilder .setInput(plan.getRoot) - .setCondition(udf.apply(col("*")).expr) + .setCondition(toExpr(udf.apply(col("*")))) } } @@ -944,7 +944,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataset(outputEncoder) { builder => builder.getMapPartitionsBuilder .setInput(plan.getRoot) - .setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction) + .setFunc(toExpr(udf.apply(col("*"))).getCommonInlineUserDefinedFunction) } } @@ -1020,7 +1020,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder => val repartitionBuilder = builder.getRepartitionByExpressionBuilder .setInput(plan.getRoot) - .addAllPartitionExprs(partitionExprs.map(_.expr).asJava) + .addAllPartitionExprs(partitionExprs.map(toExpr).asJava) numPartitions.foreach(repartitionBuilder.setNumPartitions) } } @@ -1036,7 +1036,7 @@ class Dataset[T] private[sql] ( // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. // However, we don't want to complicate the semantics of this API method. // Instead, let's give users a friendly error message, pointing them to the new method. 
- val sortOrders = partitionExprs.filter(_.expr.hasSortOrder) + val sortOrders = partitionExprs.filter(e => toExpr(e).hasSortOrder) if (sortOrders.nonEmpty) { throw new IllegalArgumentException( s"Invalid partitionExprs specified: $sortOrders\n" + @@ -1050,7 +1050,7 @@ class Dataset[T] private[sql] ( partitionExprs: Seq[Column]): Dataset[T] = { require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") val sortExprs = partitionExprs.map { - case e if e.expr.hasSortOrder => e + case e if toExpr(e).hasSortOrder => e case e => e.asc } buildRepartitionByExpression(numPartitions, sortExprs) @@ -1158,7 +1158,7 @@ class Dataset[T] private[sql] ( builder.getCollectMetricsBuilder .setInput(plan.getRoot) .setName(name) - .addAllMetrics((expr +: exprs).map(_.expr).asJava) + .addAllMetrics((expr +: exprs).map(toExpr).asJava) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala index c984582ed6ae1..dc494649b397c 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql import org.apache.spark.sql.{Column, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} -import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils} import org.apache.spark.sql.expressions.SparkUserDefinedFunction @@ -394,7 +394,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( private val valueMapFunc: Option[IV => V], private val keysFunc: () => Dataset[IK]) extends KeyValueGroupedDataset[K, V] { - import sparkSession.RichColumn override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = { new KeyValueGroupedDatasetImpl[L, V, IK, IV]( @@ -436,7 +435,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( sparkSession.newDataset[U](outputEncoder) { builder => builder.getGroupMapBuilder .setInput(plan.getRoot) - .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava) + .addAllSortingExpressions(sortExprs.map(toExpr).asJava) .addAllGroupingExpressions(groupingExprs) .setFunc(getUdf(nf, outputEncoder)(ivEncoder)) } @@ -453,10 +452,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( builder.getCoGroupMapBuilder .setInput(plan.getRoot) .addAllInputGroupingExpressions(groupingExprs) - .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava) + .addAllInputSortingExpressions(thisSortExprs.map(toExpr).asJava) .setOther(otherImpl.plan.getRoot) .addAllOtherGroupingExpressions(otherImpl.groupingExprs) - .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava) + .addAllOtherSortingExpressions(otherSortExprs.map(toExpr).asJava) .setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder)) } } @@ -469,7 +468,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( .setInput(plan.getRoot) .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) .addAllGroupingExpressions(groupingExprs) - 
.addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava) + .addAllAggregateExpressions(columns.map(c => toTypedExpr(c, vEncoder)).asJava) } } @@ -534,7 +533,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( function = nf, inputEncoders = inputEncoders, outputEncoder = outputEncoder) - udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction + toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction } private def getUdf[U: Encoder, S: Encoder]( @@ -549,7 +548,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( function = nf, inputEncoders = inputEncoders, outputEncoder = outputEncoder) - udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction + toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction } /** diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala index c245a8644a3cb..66354e63ca8af 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala @@ -24,6 +24,7 @@ import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCo import org.apache.spark.connect.proto.MergeAction.ActionType._ import org.apache.spark.sql import org.apache.spark.sql.Column +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr import org.apache.spark.sql.functions.expr /** @@ -44,13 +45,12 @@ import org.apache.spark.sql.functions.expr @Experimental class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column) extends sql.MergeIntoWriter[T] { - import ds.sparkSession.RichColumn private val builder = MergeIntoTableCommand .newBuilder() .setTargetTableName(table) .setSourceTablePlan(ds.plan.getRoot) - .setMergeCondition(on.expr) + .setMergeCondition(toExpr(on)) /** * Executes the merge operation. 
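For context on the MergeIntoWriter changes here: the merge condition set above, and the per-action conditions and assignments in the hunk that follows, are now serialized through ColumnNodeToProtoConverter.toExpr rather than the removed RichColumn implicit. A minimal sketch of how this writer is typically driven from user code, assuming a SparkSession named spark and tables named target and source that are not part of this patch:

    import org.apache.spark.sql.functions.col

    // Illustrative only: each Column below is converted via toExpr when merge()
    // assembles the MergeIntoTableCommand proto.
    spark.table("source")
      .mergeInto("target", col("target.id") === col("source.id"))
      .whenMatched(col("source.deleted")).delete()
      .whenMatched().updateAll()
      .whenNotMatched().insertAll()
      .merge()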
@@ -121,12 +121,12 @@ class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column condition: Option[Column], assignments: Map[String, Column] = Map.empty): Expression = { val builder = proto.MergeAction.newBuilder().setActionType(actionType) - condition.foreach(c => builder.setCondition(c.expr)) + condition.foreach(c => builder.setCondition(toExpr(c))) assignments.foreach { case (k, v) => builder .addAssignmentsBuilder() - .setKey(expr(k).expr) - .setValue(v.expr) + .setKey(toExpr(expr(k))) + .setValue(toExpr(v)) } Expression .newBuilder() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala index 00dc1fb6906f7..ac361047bbd08 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala @@ -23,6 +23,7 @@ import org.apache.spark.connect.proto import org.apache.spark.sql import org.apache.spark.sql.{functions, Column, Encoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} import org.apache.spark.sql.connect.ConnectConversions._ /** @@ -44,14 +45,13 @@ class RelationalGroupedDataset private[sql] ( pivot: Option[proto.Aggregate.Pivot] = None, groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) extends sql.RelationalGroupedDataset { - import df.sparkSession.RichColumn protected def toDF(aggExprs: Seq[Column]): DataFrame = { df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder => val aggBuilder = builder.getAggregateBuilder .setInput(df.plan.getRoot) - groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr)) - aggExprs.foreach(c => aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder))) + groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(toExpr(c))) + aggExprs.foreach(c => aggBuilder.addAggregateExpressions(toTypedExpr(c, df.encoder))) groupType match { case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP => @@ -152,10 +152,13 @@ class RelationalGroupedDataset private[sql] ( groupType match { case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => val valueExprs = values.map { - case c: Column if c.expr.hasLiteral => c.expr.getLiteral - case c: Column if !c.expr.hasLiteral => - throw new IllegalArgumentException("values only accept literal Column") - case v => functions.lit(v).expr.getLiteral + case c: Column => + val e = toExpr(c) + if (!e.hasLiteral) { + throw new IllegalArgumentException("values only accept literal Column") + } + e.getLiteral + case v => toExpr(functions.lit(v)).getLiteral } new RelationalGroupedDataset( df, @@ -164,7 +167,7 @@ class RelationalGroupedDataset private[sql] ( Some( proto.Aggregate.Pivot .newBuilder() - .setCol(pivotColumn.expr) + .setCol(toExpr(pivotColumn)) .addAllValues(valueExprs.asJava) .build())) case _ => diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala index f7998cf60ecac..032ab670dab0e 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala @@ -43,11 +43,10 @@ import org.apache.spark.sql.{Column, Encoder, ExperimentalMethods, 
Observation, import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer -import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -213,7 +212,7 @@ class SparkSession private[sql] ( val sqlCommand = proto.SqlCommand .newBuilder() .setSql(sqlText) - .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava) + .addAllPosArguments(args.map(a => toLiteral(a)).toImmutableArraySeq.asJava) .build() sql(sqlCommand) } @@ -228,7 +227,7 @@ class SparkSession private[sql] ( val sqlCommand = proto.SqlCommand .newBuilder() .setSql(sqlText) - .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava) + .putAllNamedArguments(args.asScala.map { case (k, v) => (k, toLiteral(v)) }.asJava) .build() sql(sqlCommand) } @@ -653,11 +652,6 @@ class SparkSession private[sql] ( } override private[sql] def isUsable: Boolean = client.isSessionValid - - implicit class RichColumn(c: Column) { - def expr: proto.Expression = toExpr(c) - def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e) - } } // The minimal builder needed to create a spark session. 
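The SparkSession hunk above removes the RichColumn implicit class, and the columnNodeSupport.scala hunk just below adds the toLiteral helper; call sites in Dataset, KeyValueGroupedDataset, MergeIntoWriter, and RelationalGroupedDataset now invoke the converter object directly. A minimal sketch of the resulting call pattern, assuming it lives inside the connect client module since the converter is internal API; only the three converter methods come from this patch, the wrapper functions are illustrative:

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.{Column, Encoder}
    import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}

    // Before: import sparkSession.RichColumn; then condition.expr, c.typedExpr(enc), lit(v).expr
    // After:  explicit conversions through the shared converter object, no implicit in scope.
    def conditionProto(condition: Column): proto.Expression = toExpr(condition)
    def aggregateProto[T](c: Column, enc: Encoder[T]): proto.Expression = toTypedExpr(c, enc)
    def positionalArgProto(v: Any): proto.Expression = toLiteral(v)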
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala index f44ec5b2d5046..f08b6e709f13a 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala @@ -24,7 +24,7 @@ import org.apache.spark.connect.proto.Expression import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST} import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING} import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType} -import org.apache.spark.sql.{Column, Encoder} +import org.apache.spark.sql.{functions, Column, Encoder} import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder @@ -37,6 +37,8 @@ import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { def toExpr(column: Column): proto.Expression = apply(column.node, None) + def toLiteral(v: Any): proto.Expression = apply(functions.lit(v).node, None) + def toTypedExpr[I](column: Column, encoder: Encoder[I]): proto.Expression = { apply(column.node, Option(encoder)) } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.json b/sql/connect/common/src/test/resources/query-tests/queries/hint.json index 2ac930c0a3a71..2348d0f847157 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/hint.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/hint.json @@ -22,13 +22,13 @@ "stackTrace": [{ "classLoaderName": "app", "declaringClass": "org.apache.spark.sql.connect.Dataset", - "methodName": "~~trimmed~anonfun~~", + "methodName": "hint", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.connect.SparkSession", - "methodName": "newDataset", - "fileName": "SparkSession.scala" + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" }] } } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin index 06459ee5b765c..ce7c63b57c47b 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin differ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 814a28e24f522..1d83a46a278f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -29,11 +29,13 @@ import org.apache.avro.Schema.Type._ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} import org.apache.avro.util.Utf8 +import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField} import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ @@ -282,11 +284,20 @@ private[sql] class AvroSerializer( }.toArray.unzip val numFields = catalystStruct.length + val avroFields = avroStruct.getFields() + val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable) row: InternalRow => val result = new Record(avroStruct) var i = 0 while (i < numFields) { if (row.isNullAt(i)) { + if (!isSchemaNullable(i)) { + throw new SparkRuntimeException( + errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD", + messageParameters = Map( + "name" -> toSQLId(avroFields.get(i).name), + "dataType" -> avroFields.get(i).schema().toString)) + } result.put(avroIndices(i), null) } else { result.put(avroIndices(i), fieldConverters(i).apply(row, i)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index f33e64c859fb3..e818cc915951b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -207,7 +207,7 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe val offsetClause = dialect.getOffsetClause(offset) options.prepareQuery + - s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" + + s"SELECT $columnList FROM ${options.tableOrQuery}" + s" $whereClause $groupByClause $orderByClause $offsetClause $limitClause" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 77d0891ce338c..33fb93b168f9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -250,7 +250,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr options.prepareQuery + s"SELECT $limitClause $columnList FROM ${options.tableOrQuery}" + - s" $tableSampleClause $whereClause $groupByClause $orderByClause" + s" $whereClause $groupByClause $orderByClause" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index d11bf14be6546..e153d38a00cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -344,9 +344,9 @@ class CompoundBodyExec( /** * Executable node for IfElseStatement. * @param conditions Collection of executable conditions. First condition corresponds to IF clause, - * while others (if any) correspond to following ELSE IF clauses. + * while others (if any) correspond to following ELSEIF clauses. * @param conditionalBodies Collection of executable bodies that have a corresponding condition, -* in IF or ELSE IF branches. +* in IF or ELSEIF branches. * @param elseBody Body that is executed if none of the conditions are met, * i.e. ELSE branch. * @param session Spark session that SQL script is executed within. 
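The AvroSerializer change above raises AVRO_CANNOT_WRITE_NULL_FIELD instead of silently writing null into a field whose Avro type is non-null. The remedy the new error message points at is to declare such a field as a union with "null" through the avroSchema option; a minimal sketch, where the record name, field names, output path, and the DataFrame df are illustrative:

    // "comment" accepts nulls because its type is a union with "null"; a null "id" would
    // still trigger AVRO_CANNOT_WRITE_NULL_FIELD because "long" alone is non-nullable.
    val avroSchema =
      """{
        |  "type": "record",
        |  "name": "topLevelRecord",
        |  "fields": [
        |    {"name": "id", "type": "long"},
        |    {"name": "comment", "type": ["null", "string"], "default": null}
        |  ]
        |}""".stripMargin

    df.write
      .format("avro")
      .option("avroSchema", avroSchema)
      .save("/tmp/avro_out")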
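The scripting hunks around this point, together with the keyword-list and test updates further down, rename the conditional branch keyword from ELSE IF to ELSEIF. A minimal sketch of the resulting script shape, in the same triple-quoted style the suites below use; the literals are arbitrary and END IF is assumed to close the statement as in the full test bodies, which are truncated here:

    val script =
      """
        |BEGIN
        |  IF 1 = 2 THEN
        |    SELECT 42;
        |  ELSEIF 1 = 1 THEN
        |    SELECT 43;
        |  ELSE
        |    SELECT 44;
        |  END IF;
        |END
        |""".stripMargin
    // The ELSEIF branch is the one taken here, so the script's result is 43.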
@@ -380,7 +380,7 @@ class IfElseStatementExec( } else { clauseIdx += 1 if (clauseIdx < conditionsCount) { - // There are ELSE IF clauses remaining. + // There are ELSEIF clauses remaining. state = IfElseState.Condition curr = Some(conditions(clauseIdx)) } else if (elseBody.isDefined) { diff --git a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out index 521b0afe19264..84a35c270b698 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out @@ -104,6 +104,7 @@ DO false DOUBLE false DROP false ELSE true +ELSEIF false END true ESCAPE true ESCAPED false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 4d702588ad2b3..49ea8bba3e174 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -104,6 +104,7 @@ DO false DOUBLE false DROP false ELSE false +ELSEIF false END false ESCAPE false ESCAPED false diff --git a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out index 4d702588ad2b3..49ea8bba3e174 100644 --- a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out @@ -104,6 +104,7 @@ DO false DOUBLE false DROP false ELSE false +ELSEIF false END false ESCAPE false ESCAPED false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index db56da80fd4af..a1d83ee665088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1107,7 +1107,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession { .withLimit(123) .build() .trim() == - "SELECT a,b FROM test FETCH FIRST 123 ROWS ONLY") + "SELECT a,b FROM test FETCH FIRST 123 ROWS ONLY") } test("table exists query by jdbc dialect") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 5b5285ea13275..503d38d61c7ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -218,14 +218,14 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } - test("if else if going in else if") { + test("if elseif going in elseif") { val commands = """ |BEGIN | IF 1=2 | THEN | SELECT 42; - | ELSE IF 1=1 + | ELSEIF 1=1 | THEN | SELECT 43; | ELSE @@ -253,14 +253,14 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } - test("if else if going in else") { + test("if elseif going in else") { val commands = """ |BEGIN | IF 1=2 | THEN | SELECT 42; - | ELSE IF 1=3 + | ELSEIF 1=3 | THEN | SELECT 43; | ELSE @@ -292,7 +292,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { } } - test("if else if with count") { + test("if elseif with count") { withTable("t") { val commands = """ @@ 
-302,7 +302,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | INSERT INTO t VALUES (1, 'a', 1.0); | IF (SELECT COUNT(*) > 2 FROM t) THEN | SELECT 42; - | ELSE IF (SELECT COUNT(*) > 1 FROM t) THEN + | ELSEIF (SELECT COUNT(*) > 1 FROM t) THEN | SELECT 43; | ELSE | SELECT 44; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 601548a2e6bd6..5f149d15c6e8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -370,14 +370,14 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } - test("if else if going in else if") { + test("if elseif going in elseif") { val commands = """ |BEGIN | IF 1=2 | THEN | SELECT 42; - | ELSE IF 1=1 + | ELSEIF 1=1 | THEN | SELECT 43; | ELSE @@ -407,14 +407,14 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } - test("if else if going in else") { + test("if elseif going in else") { val commands = """ |BEGIN | IF 1=2 | THEN | SELECT 42; - | ELSE IF 1=3 + | ELSEIF 1=3 | THEN | SELECT 43; | ELSE @@ -448,7 +448,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("if else if with count") { + test("if elseif with count") { withTable("t") { val commands = """ @@ -458,7 +458,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | INSERT INTO t VALUES (1, 'a', 1.0); | IF (SELECT COUNT(*) > 2 FROM t) THEN | SELECT 42; - | ELSE IF (SELECT COUNT(*) > 1 FROM t) THEN + | ELSEIF (SELECT COUNT(*) > 1 FROM t) THEN | SELECT 43; | ELSE | SELECT 44; diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 135b84cd01f85..e57fa5a235420 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -148,6 +148,16 @@ byte-buddy-agent test + + ${hive.group} + hive-llap-common + ${hive.llap.scope} + + + ${hive.group} + hive-llap-client + ${hive.llap.scope} + net.sf.jpam jpam diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 254eda69e86e8..ec65886fb2c98 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSIVE,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,ELSEIF,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSIVE,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 3d89f31e1965b..402d3c4ae7d92 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -55,7 +55,7 @@ private[hive] trait HiveClient { /** * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will - * result in one string. + * result in one string. This should be used only in testing environment. 
*/ def runSqlHive(sql: String): Seq[String] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 74f938e181793..3e7e81d25d943 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -49,6 +49,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkThrowable} import org.apache.spark.deploy.SparkHadoopUtil.SOURCE_SPARK import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, NoSuchDatabaseException, NoSuchPartitionException, NoSuchPartitionsException, NoSuchTableException, PartitionsAlreadyExistException} @@ -858,8 +859,10 @@ private[hive] class HiveClientImpl( /** * Runs the specified SQL query using Hive. + * This should be used only in a testing environment. */ override def runSqlHive(sql: String): Seq[String] = { + assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true") val maxResults = 100000 val results = runHive(sql, maxResults) // It is very confusing when you only get back some of the results...
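The HiveClientImpl hunk above gates runSqlHive behind Utils.isTesting. The same guard can protect other test-only entry points; a minimal sketch, noting that the helper name is made up and that Utils and IS_TESTING are Spark-internal, so this only compiles inside the Spark codebase:

    import org.apache.spark.internal.config.Tests.IS_TESTING
    import org.apache.spark.util.Utils

    private[spark] object TestOnlyHelpers {
      // Fails fast unless spark.testing is set, mirroring the assertion added to runSqlHive.
      def debugDump(): Seq[String] = {
        assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true")
        Seq.empty
      }
    }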