From 988c71457354b0a443471f501cef544a85b1a76a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 23 Sep 2016 12:17:59 -0700 Subject: [PATCH 01/96] [SPARK-17643] Remove comparable requirement from Offset For some sources, it is difficult to provide a global ordering based only on the data in the offset. Since we don't use comparison for correctness, let's remove it. Author: Michael Armbrust Closes #15207 from marmbrus/removeComparable. --- .../execution/streaming/CompositeOffset.scala | 30 -------------- .../sql/execution/streaming/LongOffset.scala | 6 --- .../sql/execution/streaming/Offset.scala | 19 ++------- .../execution/streaming/StreamExecution.scala | 9 +++-- .../spark/sql/streaming/OffsetSuite.scala | 39 ------------------- 5 files changed, 9 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala index 729c8462fed65..ebc6ee8184902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -23,36 +23,6 @@ package org.apache.spark.sql.execution.streaming * vector clock that must progress linearly forward. */ case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { - /** - * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, - * or greater than the specified object. 
- */ - override def compareTo(other: Offset): Int = other match { - case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => - val comparisons = offsets.zip(otherComposite.offsets).map { - case (Some(a), Some(b)) => a compareTo b - case (None, None) => 0 - case (None, _) => -1 - case (_, None) => 1 - } - val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet - nonZeroSigns.size match { - case 0 => 0 // if both empty or only 0s - case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) - case _ => // there are both 1s and -1s - throw new IllegalArgumentException( - s"Invalid comparison between non-linear histories: $this <=> $other") - } - case _ => - throw new IllegalArgumentException(s"Cannot compare $this <=> $other") - } - - private def sign(num: Int): Int = num match { - case i if i < 0 => -1 - case i if i == 0 => 0 - case i if i > 0 => 1 - } - /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of * sources. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index bb176408d8f59..c5e8827777792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -22,12 +22,6 @@ package org.apache.spark.sql.execution.streaming */ case class LongOffset(offset: Long) extends Offset { - override def compareTo(other: Offset): Int = other match { - case l: LongOffset => offset.compareTo(l.offset) - case _ => - throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") - } - def +(increment: Long): LongOffset = new LongOffset(offset + increment) def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala index 2cc012840dcaa..1f52abf277581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -19,19 +19,8 @@ package org.apache.spark.sql.execution.streaming /** * An offset is a monotonically increasing metric used to track progress in the computation of a - * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent - * with `equals` and `hashcode`. + * stream. Since offsets are retrieved from a [[Source]] by a single thread, we know the global + * ordering of two [[Offset]] instances. We do assume that if two offsets are `equal` then no + * new data has arrived. */ -trait Offset extends Serializable { - - /** - * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, - * or greater than the specified object. 
- */ - def compareTo(other: Offset): Int - - def >(other: Offset): Boolean = compareTo(other) > 0 - def <(other: Offset): Boolean = compareTo(other) < 0 - def <=(other: Offset): Boolean = compareTo(other) <= 0 - def >=(other: Offset): Boolean = compareTo(other) >= 0 -} +trait Offset extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 220f77dc24ce0..9825f19b86a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -259,7 +259,7 @@ class StreamExecution( case (source, available) => committedOffsets .get(source) - .map(committed => committed < available) + .map(committed => committed != available) .getOrElse(true) } } @@ -318,7 +318,8 @@ class StreamExecution( // Request unprocessed data from all sources. val newData = availableOffsets.flatMap { - case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => + case (source, available) + if committedOffsets.get(source).map(_ != available).getOrElse(true) => val current = committedOffsets.get(source) val batch = source.getBatch(current, available) logDebug(s"Retrieving data from $source: $current -> $available") @@ -404,10 +405,10 @@ class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is indented for use primarily when writing tests. 
*/ - def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala index 9590af4e7737d..b65a987770304 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -24,44 +24,12 @@ trait OffsetSuite extends SparkFunSuite { /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ def compare(one: Offset, two: Offset): Unit = { test(s"comparison $one <=> $two") { - assert(one < two) - assert(one <= two) - assert(one <= one) - assert(two > one) - assert(two >= one) - assert(one >= one) assert(one == one) assert(two == two) assert(one != two) assert(two != one) } } - - /** Creates test to check that non-equality comparisons throw exception. 
*/ - def compareInvalid(one: Offset, two: Offset): Unit = { - test(s"invalid comparison $one <=> $two") { - intercept[IllegalArgumentException] { - assert(one < two) - } - - intercept[IllegalArgumentException] { - assert(one <= two) - } - - intercept[IllegalArgumentException] { - assert(one > two) - } - - intercept[IllegalArgumentException] { - assert(one >= two) - } - - assert(!(one == two)) - assert(!(two == one)) - assert(one != two) - assert(two != one) - } - } } class LongOffsetSuite extends OffsetSuite { @@ -79,10 +47,6 @@ class CompositeOffsetSuite extends OffsetSuite { one = CompositeOffset(None :: Nil), two = CompositeOffset(Some(LongOffset(2)) :: Nil)) - compareInvalid( // sizes must be same - one = CompositeOffset(Nil), - two = CompositeOffset(Some(LongOffset(2)) :: Nil)) - compare( one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) @@ -91,8 +55,5 @@ class CompositeOffsetSuite extends OffsetSuite { one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) - compareInvalid( - one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent - two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) } From 90a30f46349182b6fc9d4123090c4712fdb425be Mon Sep 17 00:00:00 2001 From: jisookim Date: Fri, 23 Sep 2016 13:43:47 -0700 Subject: [PATCH 02/96] [SPARK-12221] add cpu time to metrics Currently task metrics don't support executor CPU time, so there's no way to calculate how much CPU time a stage/task took from History Server metrics. This PR enables reporting CPU time. Author: jisookim Closes #10212 from jisookim0513/add-cpu-time-metric. 
--- .../apache/spark/InternalAccumulator.scala | 2 + .../org/apache/spark/executor/Executor.scala | 15 +++ .../apache/spark/executor/TaskMetrics.scala | 18 ++++ .../apache/spark/scheduler/ResultTask.scala | 8 ++ .../spark/scheduler/ShuffleMapTask.scala | 8 ++ .../org/apache/spark/scheduler/Task.scala | 2 + .../status/api/v1/AllStagesResource.scala | 5 + .../org/apache/spark/status/api/v1/api.scala | 5 + .../spark/ui/jobs/JobProgressListener.scala | 4 + .../org/apache/spark/ui/jobs/UIData.scala | 5 + .../org/apache/spark/util/JsonProtocol.scala | 10 ++ .../complete_stage_list_json_expectation.json | 3 + .../failed_stage_list_json_expectation.json | 1 + .../one_stage_attempt_json_expectation.json | 17 +++ .../one_stage_json_expectation.json | 17 +++ .../stage_list_json_expectation.json | 4 + ...ist_with_accumulable_json_expectation.json | 1 + .../stage_task_list_expectation.json | 40 +++++++ ...multi_attempt_app_json_1__expectation.json | 16 +++ ...multi_attempt_app_json_2__expectation.json | 16 +++ ...k_list_w__offset___length_expectation.json | 100 ++++++++++++++++++ ...stage_task_list_w__sortBy_expectation.json | 40 +++++++ ...tBy_short_names___runtime_expectation.json | 40 +++++++ ...rtBy_short_names__runtime_expectation.json | 40 +++++++ ...mmary_w__custom_quantiles_expectation.json | 2 + ...sk_summary_w_shuffle_read_expectation.json | 2 + ...k_summary_w_shuffle_write_expectation.json | 2 + ...age_with_accumulable_json_expectation.json | 17 +++ .../apache/spark/util/JsonProtocolSuite.scala | 69 ++++++++---- project/MimaExcludes.scala | 4 + 30 files changed, 492 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 0b494c146fa1b..82d3098e2e055 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -31,7 +31,9 @@ private[spark] object 
InternalAccumulator { // Names of internal task level metrics val EXECUTOR_DESERIALIZE_TIME = METRICS_PREFIX + "executorDeserializeTime" + val EXECUTOR_DESERIALIZE_CPU_TIME = METRICS_PREFIX + "executorDeserializeCpuTime" val EXECUTOR_RUN_TIME = METRICS_PREFIX + "executorRunTime" + val EXECUTOR_CPU_TIME = METRICS_PREFIX + "executorCpuTime" val RESULT_SIZE = METRICS_PREFIX + "resultSize" val JVM_GC_TIME = METRICS_PREFIX + "jvmGCTime" val RESULT_SERIALIZATION_TIME = METRICS_PREFIX + "resultSerializationTime" diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 668ec41153086..9501dd9cd8e93 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -232,13 +232,18 @@ private[spark] class Executor( } override def run(): Unit = { + val threadMXBean = ManagementFactory.getThreadMXBean val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var taskStart: Long = 0 + var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() try { @@ -269,6 +274,9 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() + taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L var threwException = true val value = try { val res = task.run( @@ -302,6 +310,9 @@ private[spark] class Executor( } } val taskFinish = System.currentTimeMillis() + val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L // If the task has been killed, let's fail it. if (task.killed) { @@ -317,8 +328,12 @@ private[spark] class Executor( // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. task.metrics.setExecutorDeserializeTime( (taskStart - deserializeStartTime) + task.executorDeserializeTime) + task.metrics.setExecutorDeserializeCpuTime( + (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorCpuTime( + (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 52a349919e336..2956768c16417 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -47,7 +47,9 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, Accumulat class TaskMetrics private[spark] () extends Serializable { // Each metric is internally represented as an accumulator private val _executorDeserializeTime = new LongAccumulator + private val _executorDeserializeCpuTime = new LongAccumulator private val _executorRunTime 
= new LongAccumulator + private val _executorCpuTime = new LongAccumulator private val _resultSize = new LongAccumulator private val _jvmGCTime = new LongAccumulator private val _resultSerializationTime = new LongAccumulator @@ -61,11 +63,22 @@ class TaskMetrics private[spark] () extends Serializable { */ def executorDeserializeTime: Long = _executorDeserializeTime.sum + /** + * CPU Time taken on the executor to deserialize this task in nanoseconds. + */ + def executorDeserializeCpuTime: Long = _executorDeserializeCpuTime.sum + /** * Time the executor spends actually running the task (including fetching shuffle data). */ def executorRunTime: Long = _executorRunTime.sum + /** + * CPU Time the executor spends actually running the task + * (including fetching shuffle data) in nanoseconds. + */ + def executorCpuTime: Long = _executorCpuTime.sum + /** * The number of bytes this task transmitted back to the driver as the TaskResult. */ @@ -111,7 +124,10 @@ class TaskMetrics private[spark] () extends Serializable { // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) + private[spark] def setExecutorDeserializeCpuTime(v: Long): Unit = + _executorDeserializeCpuTime.setValue(v) private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v) + private[spark] def setExecutorCpuTime(v: Long): Unit = _executorCpuTime.setValue(v) private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v) private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v) private[spark] def setResultSerializationTime(v: Long): Unit = @@ -188,7 +204,9 @@ class TaskMetrics private[spark] () extends Serializable { import InternalAccumulator._ @transient private[spark] lazy val nameToAccums = LinkedHashMap( EXECUTOR_DESERIALIZE_TIME -> _executorDeserializeTime, + EXECUTOR_DESERIALIZE_CPU_TIME -> _executorDeserializeCpuTime, EXECUTOR_RUN_TIME -> _executorRunTime, + EXECUTOR_CPU_TIME -> 
_executorCpuTime, RESULT_SIZE -> _resultSize, JVM_GC_TIME -> _jvmGCTime, RESULT_SERIALIZATION_TIME -> _resultSerializationTime, diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 75c6018e214d8..609f10aee940d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io._ +import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.Properties @@ -61,11 +62,18 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. + val threadMXBean = ManagementFactory.getThreadMXBean val deserializeStartTime = System.currentTimeMillis() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime + _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime + } else 0L func(context, rdd.iterator(partition, context)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 84b3e5ba6c1f3..448fe02084e0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.lang.management.ManagementFactory import 
java.nio.ByteBuffer import java.util.Properties @@ -66,11 +67,18 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. + val threadMXBean = ManagementFactory.getThreadMXBean val deserializeStartTime = System.currentTimeMillis() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime + _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime + } else 0L var writer: ShuffleWriter[Any, Any] = null try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index ea9dc3988d934..48daa344f3c88 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -139,6 +139,7 @@ private[spark] abstract class Task[T]( @volatile @transient private var _killed = false protected var _executorDeserializeTime: Long = 0 + protected var _executorDeserializeCpuTime: Long = 0 /** * Whether the task has been killed. @@ -149,6 +150,7 @@ private[spark] abstract class Task[T]( * Returns the amount of time spent deserializing the RDD and function to be run. */ def executorDeserializeTime: Long = _executorDeserializeTime + def executorDeserializeCpuTime: Long = _executorDeserializeCpuTime /** * Collect the latest values of accumulators used in this task. 
If the task failed, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 7d63a8f734f0e..acb7c23079681 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -101,6 +101,7 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + executorCpuTime = stageUiData.executorCpuTime, submissionTime = stageInfo.submissionTime.map(new Date(_)), firstTaskLaunchedTime, completionTime = stageInfo.completionTime.map(new Date(_)), @@ -220,7 +221,9 @@ private[v1] object AllStagesResource { new TaskMetricDistributions( quantiles = quantiles, executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), + executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), executorRunTime = metricQuantiles(_.executorRunTime), + executorCpuTime = metricQuantiles(_.executorCpuTime), resultSize = metricQuantiles(_.resultSize), jvmGcTime = metricQuantiles(_.jvmGCTime), resultSerializationTime = metricQuantiles(_.resultSerializationTime), @@ -241,7 +244,9 @@ private[v1] object AllStagesResource { def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { new TaskMetrics( executorDeserializeTime = internal.executorDeserializeTime, + executorDeserializeCpuTime = internal.executorDeserializeCpuTime, executorRunTime = internal.executorRunTime, + executorCpuTime = internal.executorCpuTime, resultSize = internal.resultSize, jvmGcTime = internal.jvmGCTime, resultSerializationTime = internal.resultSerializationTime, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 32e332a9adb9d..44a929b310384 100644 --- 
a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -128,6 +128,7 @@ class StageData private[spark]( val numFailedTasks: Int, val executorRunTime: Long, + val executorCpuTime: Long, val submissionTime: Option[Date], val firstTaskLaunchedTime: Option[Date], val completionTime: Option[Date], @@ -166,7 +167,9 @@ class TaskData private[spark]( class TaskMetrics private[spark]( val executorDeserializeTime: Long, + val executorDeserializeCpuTime: Long, val executorRunTime: Long, + val executorCpuTime: Long, val resultSize: Long, val jvmGcTime: Long, val resultSerializationTime: Long, @@ -202,7 +205,9 @@ class TaskMetricDistributions private[spark]( val quantiles: IndexedSeq[Double], val executorDeserializeTime: IndexedSeq[Double], + val executorDeserializeCpuTime: IndexedSeq[Double], val executorRunTime: IndexedSeq[Double], + val executorCpuTime: IndexedSeq[Double], val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index d3a4f9d3223a7..83dc5d874589e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -503,6 +503,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val timeDelta = taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) stageData.executorRunTime += timeDelta + + val cpuTimeDelta = + taskMetrics.executorCpuTime - oldMetrics.map(_.executorCpuTime).getOrElse(0L) + stageData.executorCpuTime += cpuTimeDelta } override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index c729f03b3c383..f4a04609c4c69 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -80,6 +80,7 @@ private[spark] object UIData { var numKilledTasks: Int = _ var executorRunTime: Long = _ + var executorCpuTime: Long = _ var inputBytes: Long = _ var inputRecords: Long = _ @@ -137,7 +138,9 @@ private[spark] object UIData { metrics.map { m => TaskMetricsUIData( executorDeserializeTime = m.executorDeserializeTime, + executorDeserializeCpuTime = m.executorDeserializeCpuTime, executorRunTime = m.executorRunTime, + executorCpuTime = m.executorCpuTime, resultSize = m.resultSize, jvmGCTime = m.jvmGCTime, resultSerializationTime = m.resultSerializationTime, @@ -179,7 +182,9 @@ private[spark] object UIData { case class TaskMetricsUIData( executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, executorRunTime: Long, + executorCpuTime: Long, resultSize: Long, jvmGCTime: Long, resultSerializationTime: Long, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 41d947c4428ad..f4fa7b4061640 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -348,7 +348,9 @@ private[spark] object JsonProtocol { ("Status" -> blockStatusToJson(status)) }) ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ + ("Executor Deserialize CPU Time" -> taskMetrics.executorDeserializeCpuTime) ~ ("Executor Run Time" -> taskMetrics.executorRunTime) ~ + ("Executor CPU Time" -> taskMetrics.executorCpuTime) ~ ("Result Size" -> taskMetrics.resultSize) ~ ("JVM GC Time" -> taskMetrics.jvmGCTime) ~ ("Result Serialization Time" -> taskMetrics.resultSerializationTime) ~ @@ -759,7 +761,15 @@ private[spark] object JsonProtocol { return metrics } 
metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long]) + metrics.setExecutorDeserializeCpuTime((json \ "Executor Deserialize CPU Time") match { + case JNothing => 0 + case x => x.extract[Long] + }) metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long]) + metrics.setExecutorCpuTime((json \ "Executor CPU Time") match { + case JNothing => 0 + case x => x.extract[Long] + }) metrics.setResultSize((json \ "Result Size").extract[Long]) metrics.setJvmGCTime((json \ "JVM GC Time").extract[Long]) metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long]) diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 8f8067f86d57f..25c4fff77e0ad 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", "completionTime" : "2015-02-03T16:43:07.226GMT", @@ -31,6 +32,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -56,6 +58,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", "completionTime" : "2015-02-03T16:43:04.819GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index 08b692eda8028..b86ba1e65de12 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:06.296GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", "completionTime" : "2015-02-03T16:43:06.347GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 477a2fec8b69b..0084339d24642 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -36,7 +37,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -77,7 +80,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -118,7 +123,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, 
"jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -159,7 +166,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -200,7 +209,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -241,7 +252,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 436, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 0, @@ -282,7 +295,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -323,7 +338,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 388e51f77a24d..63fe3b2f958e5 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -36,7 +37,9 @@ 
"accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -77,7 +80,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -118,7 +123,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -159,7 +166,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -200,7 +209,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -241,7 +252,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 436, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 0, @@ -282,7 +295,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -323,7 +338,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 
1, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 5b957ed549556..6509df1508b30 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", "completionTime" : "2015-02-03T16:43:07.226GMT", @@ -31,6 +32,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -56,6 +58,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", "completionTime" : "2015-02-03T16:43:04.819GMT", @@ -81,6 +84,7 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:06.296GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", "completionTime" : "2015-02-03T16:43:06.347GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index afa425f8c27bb..8496863a93469 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, 
"numFailedTasks" : 0, "executorRunTime" : 120, + "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", "completionTime" : "2015-03-16T19:25:36.579GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index 8e09aabbad7c9..e0661c464179d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 348, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 30, + 
"executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 29, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 351, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 80, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 73, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 75, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 1dbf72b42a926..8492f19ab7a5f 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -15,7 +15,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -60,7 +62,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -105,7 +109,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -150,7 +156,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -195,7 +203,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -240,7 +250,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -285,7 +297,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -330,7 +344,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, 
"resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 483492282dd64..4de4c501a43ad 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -15,7 +15,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -60,7 +62,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -105,7 +109,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -150,7 +156,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -195,7 +203,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -240,7 +250,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -285,7 +297,9 @@ } ], "taskMetrics" : { 
"executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -330,7 +344,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 624f2bb16df48..d2eceeb3f97a9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 73, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 75, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + 
"executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 88, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 93, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 65, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 43, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 49, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 38, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 32, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 29, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 39, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -810,7 +850,9 @@ "accumulatorUpdates" : [ ], 
"taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 34, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -850,7 +892,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 36, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 24, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -890,7 +934,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -930,7 +976,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 43, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -970,7 +1018,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 27, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1010,7 +1060,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 35, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1050,7 +1102,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 29, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1090,7 +1144,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 32, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1130,7 +1186,9 @@ 
"accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 31, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1170,7 +1228,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1210,7 +1270,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 14, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1250,7 +1312,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1290,7 +1354,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1330,7 +1396,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1370,7 +1438,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1410,7 +1480,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 19, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 
0, @@ -1450,7 +1522,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 31, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 6, "resultSerializationTime" : 0, @@ -1490,7 +1564,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1530,7 +1606,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 24, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 6, "resultSerializationTime" : 0, @@ -1570,7 +1648,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 7, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 23, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 6, "resultSerializationTime" : 0, @@ -1610,7 +1690,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1650,7 +1732,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1690,7 +1774,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1730,7 +1816,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, 
"resultSerializationTime" : 0, @@ -1770,7 +1858,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1810,7 +1900,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 21, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1850,7 +1942,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 20, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1890,7 +1984,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1930,7 +2026,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1970,7 +2068,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 96d86b7278ff1..f42c3a4ee5c38 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -10,7 
+10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 29, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 351, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 30, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 348, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, 
"resultSerializationTime" : 2, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 93, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 88, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, 
"jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 80, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 96d86b7278ff1..f42c3a4ee5c38 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 29, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 351, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + 
"executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 30, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 348, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 93, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, 
"executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 88, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + 
"executorDeserializeCpuTime" : 0, "executorRunTime" : 80, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index e0e9e8140c717..db60ccccbf8c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 14, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], 
"taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 20, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ 
"accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 76d1553bc8f77..5dcbc890438b2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -1,7 +1,9 @@ { "quantiles" : [ 0.01, 0.5, 0.99 ], "executorDeserializeTime" : [ 1.0, 3.0, 36.0 ], + "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0 ], "executorRunTime" : [ 16.0, 28.0, 351.0 ], + "executorCpuTime" : [ 0.0, 0.0, 0.0], "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index 7baffc5df0b0f..6d230ac653776 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -1,7 +1,9 @@ { "quantiles" : [ 0.05, 0.25, 0.5, 0.75, 0.95 ], "executorDeserializeTime" : [ 1.0, 2.0, 2.0, 2.0, 3.0 ], + "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "executorRunTime" : [ 30.0, 74.0, 75.0, 76.0, 79.0 ], + "executorCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 
0.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index f8c4b7c128733..aea0f5413d8b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -1,7 +1,9 @@ { "quantiles" : [ 0.05, 0.25, 0.5, 0.75, 0.95 ], "executorDeserializeTime" : [ 2.0, 2.0, 3.0, 7.0, 31.0 ], + "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "executorRunTime" : [ 16.0, 18.0, 28.0, 49.0, 349.0 ], + "executorCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index ce008bf40967d..aaeef1f2f582c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", "completionTime" : "2015-03-16T19:25:36.579GMT", @@ -45,7 +46,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -91,7 +94,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, 
"executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -137,7 +142,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -183,7 +190,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -229,7 +238,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -275,7 +286,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -321,7 +334,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -367,7 +382,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 00314abf49fd4..d5146d70ebaa3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -606,6 +606,9 @@ private[spark] object JsonProtocolSuite extends Assertions { private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { 
assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) + assert(metrics1.executorDeserializeCpuTime === metrics2.executorDeserializeCpuTime) + assert(metrics1.executorRunTime === metrics2.executorRunTime) + assert(metrics1.executorCpuTime === metrics2.executorCpuTime) assert(metrics1.resultSize === metrics2.resultSize) assert(metrics1.jvmGCTime === metrics2.jvmGCTime) assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime) @@ -816,8 +819,11 @@ private[spark] object JsonProtocolSuite extends Assertions { hasOutput: Boolean, hasRecords: Boolean = true) = { val t = TaskMetrics.empty + // Set CPU times same as wall times for testing purpose t.setExecutorDeserializeTime(a) + t.setExecutorDeserializeCpuTime(a) t.setExecutorRunTime(b) + t.setExecutorCpuTime(b) t.setResultSize(c) t.setJvmGCTime(d) t.setResultSerializationTime(a + b) @@ -1097,7 +1103,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Task Metrics": { | "Executor Deserialize Time": 300, + | "Executor Deserialize CPU Time": 300, | "Executor Run Time": 400, + | "Executor CPU Time": 400, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -1195,7 +1203,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Task Metrics": { | "Executor Deserialize Time": 300, + | "Executor Deserialize CPU Time": 300, | "Executor Run Time": 400, + | "Executor CPU Time": 400, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -1293,7 +1303,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Task Metrics": { | "Executor Deserialize Time": 300, + | "Executor Deserialize CPU Time": 300, | "Executor Run Time": 400, + | "Executor CPU Time": 400, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -1785,55 +1797,70 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 1, + | "Name": 
"$EXECUTOR_DESERIALIZE_CPU_TIME", + | "Update": 300, + | "Internal": true, + | "Count Failed Values": true + | }, + | + | { + | "ID": 2, | "Name": "$EXECUTOR_RUN_TIME", | "Update": 400, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 2, + | "ID": 3, + | "Name": "$EXECUTOR_CPU_TIME", + | "Update": 400, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 4, | "Name": "$RESULT_SIZE", | "Update": 500, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 3, + | "ID": 5, | "Name": "$JVM_GC_TIME", | "Update": 600, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 4, + | "ID": 6, | "Name": "$RESULT_SERIALIZATION_TIME", | "Update": 700, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 5, + | "ID": 7, | "Name": "$MEMORY_BYTES_SPILLED", | "Update": 800, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 6, + | "ID": 8, | "Name": "$DISK_BYTES_SPILLED", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 7, + | "ID": 9, | "Name": "$PEAK_EXECUTION_MEMORY", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 8, + | "ID": 10, | "Name": "$UPDATED_BLOCK_STATUSES", | "Update": [ | { @@ -1854,98 +1881,98 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Count Failed Values": true | }, | { - | "ID": 9, + | "ID": 11, | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 10, + | "ID": 12, | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 11, + | "ID": 13, | "Name": "${shuffleRead.REMOTE_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 12, + | "ID": 14, | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 13, 
+ | "ID": 15, | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 14, + | "ID": 16, | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 15, + | "ID": 17, | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 16, + | "ID": 18, | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 17, + | "ID": 19, | "Name": "${shuffleWrite.WRITE_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 18, + | "ID": 20, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 19, + | "ID": 21, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 20, + | "ID": 22, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 21, + | "ID": 23, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 22, + | "ID": 24, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b6f64e5a703ca..8024fbd21bbfc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -823,6 +823,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") + ) ++ Seq( + // [SPARK-12221] Add CPU time to metrics + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") ) } From 7c382524a959a2bc9b3d2fca44f6f0b41aba4e3c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 23 Sep 2016 14:35:18 -0700 Subject: [PATCH 03/96] [SPARK-17651][SPARKR] Set R package version number along with mvn ## What changes were proposed in this pull request? This PR sets the R package version while tagging releases. Note that since R doesn't accept `-SNAPSHOT` in version number field, we remove that while setting the next version ## How was this patch tested? Tested manually by running locally Author: Shivaram Venkataraman Closes #15223 from shivaram/sparkr-version-change. --- dev/create-release/release-tag.sh | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index d404939d1caee..b7e5100ca7408 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -60,12 +60,27 @@ git config user.email $GIT_EMAIL # Create release version $MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +# Set the release version in R/pkg/DESCRIPTION +sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION +# Set the release version in docs +sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +# Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers +R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` +sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' 
R/pkg/DESCRIPTION + +# Update docs with next version +sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +# Use R version for short version +sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing development version $NEXT_VERSION" # Push changes From f3fe55439e4c865c26502487a1bccf255da33f4a Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 24 Sep 2016 08:06:41 +0100 Subject: [PATCH 04/96] [SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null string array ## What changes were proposed in this pull request? To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram ## How was this patch tested? Jenkins tests. Author: Sean Owen Closes #15179 from srowen/SPARK-10835. --- .../apache/spark/ml/feature/Word2Vec.scala | 3 ++- .../spark/ml/feature/Word2VecSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 14c05123c62ed..d53f3df514dff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 0b441f8b80810..613cc3d60b227 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } From 248916f5589155c0c3e93c3874781f17b08d598d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 24 Sep 2016 08:15:55 +0100 Subject: [PATCH 05/96] [SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0 ## What changes were proposed in this pull request? Match ProbabilisticClassifier.thresholds requirements to R randomForest cutoff, requiring all > 0, except that at most one value may be 0 ## How was this patch tested? 
Jenkins tests plus new test cases Author: Sean Owen Closes #15149 from srowen/SPARK-17057. --- .../classification/LogisticRegression.scala | 5 +-- .../ProbabilisticClassifier.scala | 20 +++++------ .../ml/param/shared/SharedParamsCodeGen.scala | 8 +++-- .../spark/ml/param/shared/sharedParams.scala | 4 +-- .../ProbabilisticClassifierSuite.scala | 35 +++++++++++++++---- .../ml/param/_shared_params_code_gen.py | 5 +-- python/pyspark/ml/param/shared.py | 4 +-- 7 files changed, 52 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 343d50c790e85..5ab63d1de95d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * Set thresholds in multiclass (or binary) classification to adjust the probability of - * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * predicting each class. Array must have length equal to the number of classes, with values > 0, + * excepting that at most one value may be 0. * The class with largest value p/t is predicted, where p is the original probability of that - * class and t is the class' threshold. + * class and t is the class's threshold. * * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. 
* If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1b6e77542cc80..e89da6ff8bdd7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[ if (!isDefined(thresholds)) { probability.argmax } else { - val thresholds: Array[Double] = getThresholds - val probabilities = probability.toArray + val thresholds = getThresholds var argMax = 0 var max = Double.NegativeInfinity var i = 0 val probabilitySize = probability.size while (i < probabilitySize) { - if (thresholds(i) == 0.0) { - max = Double.PositiveInfinity + // Thresholds are all > 0, excepting that at most one may be 0. + // The single class whose threshold is 0, if any, will always be predicted + // ('scaled' = +Infinity). However in the case that this class also has + // 0 probability, the class will not be selected ('scaled' is NaN). 
+ val scaled = probability(i) / thresholds(i) + if (scaled > max) { + max = scaled argMax = i - } else { - val scaled = probabilities(i) / thresholds(i) - if (scaled > max) { - max = scaled - argMax = i - } } i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 480b03d0f35c4..c94b8b4e9dfda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + + " Array must have length equal to the number of classes, with values > 0" + + " excepting that at most one value may be 0." 
+ " The class with largest value p/t is predicted, where p is the original probability" + - " of that class and t is the class' threshold", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), + " of that class and t is the class's threshold", + isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1", + finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 9125d9e19bf09..fa4530927e8b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. * @group param */ - final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0)) + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1) /** @group getParam */ def getThresholds: Array[Double] = $(thresholds) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index b3bd2b3e57b36..172c64aab9d3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel( rawPrediction } - def friendlyPredict(input: Vector): Double = { - predict(input) + def friendlyPredict(values: Double*): Double = { + predict(Vectors.dense(values.toArray)) } } @@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel( class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { - val thresholds = Array(0.5, 0.2) val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - .setThresholds(thresholds) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + .setThresholds(Array(0.5, 0.2)) + assert(testModel.friendlyPredict(1.0, 1.0) === 1.0) + assert(testModel.friendlyPredict(1.0, 0.2) === 0.0) } 
test("test thresholding not required") { val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + assert(testModel.friendlyPredict(1.0, 2.0) === 1.0) + } + + test("test tiebreak") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(Array(0.4, 0.4)) + assert(testModel.friendlyPredict(0.6, 0.6) === 0.0) + } + + test("test one zero threshold") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(Array(0.0, 0.1)) + assert(testModel.friendlyPredict(1.0, 10.0) === 0.0) + assert(testModel.friendlyPredict(0.0, 10.0) === 1.0) + } + + test("bad thresholds") { + intercept[IllegalArgumentException] { + new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0)) + } + intercept[IllegalArgumentException] { + new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) + } } } diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 4f4328bcadc6f..929591236d688 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -139,8 +139,9 @@ def get$Name(self): "model.", "True", "TypeConverters.toBoolean"), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + - "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None, + "values > 0, excepting that at most one value may be 0. " + + "The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class's threshold.", None, "TypeConverters.toListFloat"), ("weightCol", "weight column name. 
If this is not set or empty, we treat " + "all instance weights as 1.0.", None, "TypeConverters.toString"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 24af07afc7d5c..cc596936d82f6 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -469,10 +469,10 @@ def getStandardization(self): class HasThresholds(Params): """ - Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. """ - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat) + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat) def __init__(self): super(HasThresholds, self).__init__() From 7945daed12542587d51ece8f07e5c828b40db14a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 24 Sep 2016 01:03:11 -0700 Subject: [PATCH 06/96] [MINOR][SPARKR] Add sparkr-vignettes.html to gitignore. ## What changes were proposed in this pull request? Add ```sparkr-vignettes.html``` to ```.gitignore```. ## How was this patch tested? No need test. Author: Yanbo Liang Closes #15215 from yanboliang/ignore. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index cfa8ad05f7da1..39d17e1793f77 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ R-unit-tests.log R/unit-tests.out R/cran-check.out +R/pkg/vignettes/sparkr-vignettes.html build/*.jar build/apache-maven* build/scala* From de333d121da4cb80d45819cbcf8b4246e48ec4d0 Mon Sep 17 00:00:00 2001 From: xin wu Date: Sun, 25 Sep 2016 16:46:12 -0700 Subject: [PATCH 07/96] [SPARK-17551][SQL] Add DataFrame API for null ordering ## What changes were proposed in this pull request? This pull request adds Scala/Java DataFrame API for null ordering (NULLS FIRST | LAST). Also did some minor clean up for related code (e.g. incorrect indentation), and renamed "orderby-nulls-ordering.sql" to be consistent with existing test files. ## How was this patch tested? Added a new test case in DataFrameSuite. Author: petermaxlee Author: Xin Wu Closes #15123 from petermaxlee/SPARK-17551. 
--- .../sql/catalyst/expressions/SortOrder.scala | 28 ++------ .../codegen/GenerateOrdering.scala | 16 ++--- .../scala/org/apache/spark/sql/Column.scala | 64 ++++++++++++++++++- .../org/apache/spark/sql/functions.scala | 51 ++++++++++++++- ...dering.sql => order-by-nulls-ordering.sql} | 0 ...ql.out => order-by-nulls-ordering.sql.out} | 0 .../org/apache/spark/sql/DataFrameSuite.scala | 18 ++++++ 7 files changed, 144 insertions(+), 33 deletions(-) rename sql/core/src/test/resources/sql-tests/inputs/{orderby-nulls-ordering.sql => order-by-nulls-ordering.sql} (100%) rename sql/core/src/test/resources/sql-tests/results/{orderby-nulls-ordering.sql.out => order-by-nulls-ordering.sql.out} (100%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d015125baccaf..3bebd552ef51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -54,10 +54,7 @@ case object NullsLast extends NullOrdering{ * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder( - child: Expression, - direction: SortDirection, - nullOrdering: NullOrdering) +case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. 
*/ @@ -94,17 +91,9 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val nullValue = child.child.dataType match { case BooleanType | DateType | TimestampType | _: IntegralType => - if (nullAsSmallest) { - Long.MinValue - } else { - Long.MaxValue - } + if (nullAsSmallest) Long.MinValue else Long.MaxValue case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - if (nullAsSmallest) { - Long.MinValue - } else { - Long.MaxValue - } + if (nullAsSmallest) Long.MinValue else Long.MaxValue case _: DecimalType => if (nullAsSmallest) { DoublePrefixComparator.computePrefix(Double.NegativeInfinity) @@ -112,16 +101,13 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { DoublePrefixComparator.computePrefix(Double.NaN) } case _ => - if (nullAsSmallest) { - 0L - } else { - -1L - } + if (nullAsSmallest) 0L else -1L } - private def nullAsSmallest: Boolean = (child.isAscending && child.nullOrdering == NullsFirst) || + private def nullAsSmallest: Boolean = { + (child.isAscending && child.nullOrdering == NullsFirst) || (!child.isAscending && child.nullOrdering == NullsLast) - + } override def eval(input: InternalRow): Any = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index e7df95e1142ca..f1c30ef6c7fb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -100,16 +100,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR // Nothing } else if ($isNullA) { return ${ - order.nullOrdering match { - case NullsFirst => "-1" - case NullsLast => "1" - }}; + order.nullOrdering match { + case NullsFirst => "-1" + case NullsLast => 
"1" + }}; } else if ($isNullB) { return ${ - order.nullOrdering match { - case NullsFirst => "1" - case NullsLast => "-1" - }}; + order.nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + }}; } else { int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 844ca7a8e99ca..63da501f18cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1007,7 +1007,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** * Returns an ordering used in sorting. * {{{ - * // Scala: sort a DataFrame by age column in descending order. + * // Scala * df.sort(df("age").desc) * * // Java @@ -1020,7 +1020,37 @@ class Column(protected[sql] val expr: Expression) extends Logging { def desc: Column = withExpr { SortOrder(expr, Descending) } /** - * Returns an ordering used in sorting. + * Returns a descending ordering used in sorting, where null values appear before non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in descending order and null values appearing first. + * df.sort(df("age").desc_nulls_first) + * + * // Java + * df.sort(df.col("age").desc_nulls_first()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) } + + /** + * Returns a descending ordering used in sorting, where null values appear after non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in descending order and null values appearing last. 
+ * df.sort(df("age").desc_nulls_last) + * + * // Java + * df.sort(df.col("age").desc_nulls_last()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) } + + /** + * Returns an ascending ordering used in sorting. * {{{ * // Scala: sort a DataFrame by age column in ascending order. * df.sort(df("age").asc) @@ -1034,6 +1064,36 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def asc: Column = withExpr { SortOrder(expr, Ascending) } + /** + * Returns an ascending ordering used in sorting, where null values appear before non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. + * df.sort(df("age").asc_nulls_first) + * + * // Java + * df.sort(df.col("age").asc_nulls_first()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) } + + /** + * Returns an ascending ordering used in sorting, where null values appear after non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in ascending order and null values appearing last. + * df.sort(df("age").asc_nulls_last) + * + * // Java + * df.sort(df.col("age").asc_nulls_last()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) } + /** * Prints the expression to the console for debugging purpose. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 960c87f60e624..47bf41a2da813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -109,7 +109,6 @@ object functions { /** * Returns a sort expression based on ascending order of the column. 
* {{{ - * // Sort by dept in ascending order, and then age in descending order. * df.sort(asc("dept"), desc("age")) * }}} * @@ -118,10 +117,33 @@ object functions { */ def asc(columnName: String): Column = Column(columnName).asc + /** + * Returns a sort expression based on ascending order of the column, + * and null values appear before non-null values. + * {{{ + * df.sort(asc_nulls_first("dept"), desc("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first + + /** + * Returns a sort expression based on ascending order of the column, + * and null values appear after non-null values. + * {{{ + * df.sort(asc_nulls_last("dept"), desc("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def asc_nulls_last(columnName: String): Column = Column(columnName).asc_nulls_last + /** * Returns a sort expression based on the descending order of the column. * {{{ - * // Sort by dept in ascending order, and then age in descending order. * df.sort(asc("dept"), desc("age")) * }}} * @@ -130,6 +152,31 @@ object functions { */ def desc(columnName: String): Column = Column(columnName).desc + /** + * Returns a sort expression based on the descending order of the column, + * and null values appear before non-null values. + * {{{ + * df.sort(asc("dept"), desc_nulls_first("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first + + /** + * Returns a sort expression based on the descending order of the column, + * and null values appear after non-null values. 
+ * {{{ + * df.sort(asc("dept"), desc_nulls_last("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last + + ////////////////////////////////////////////////////////////////////////////////////////////// // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql similarity index 100% rename from sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql rename to sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql diff --git a/sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-nulls-ordering.sql.out similarity index 100% rename from sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out rename to sql/core/src/test/resources/sql-tests/results/order-by-nulls-ordering.sql.out diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2c60a7dd9209b..16cc368208485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -326,6 +326,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(6)) } + test("sorting with null ordering") { + val data = Seq[java.lang.Integer](2, 1, null).toDF("key") + + checkAnswer(data.orderBy('key.asc), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy('key.asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil) + 
checkAnswer(data.orderBy('key.asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil) + checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil) + + checkAnswer(data.orderBy('key.desc), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy('key.desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil) + checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil) + checkAnswer(data.orderBy('key.desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil) + } + test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), From 59d87d24079bc633e63ce032f0a5ddd18a3b02cb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 25 Sep 2016 22:57:31 -0700 Subject: [PATCH 08/96] [SPARK-17650] malformed url's throw exceptions before bricking Executors ## What changes were proposed in this pull request? When a malformed URL was sent to Executors through `sc.addJar` and `sc.addFile`, the executors become unusable, because they constantly throw `MalformedURLException`s and can never acknowledge that the file or jar is just bad input. This PR tries to fix that problem by making sure MalformedURLs can never be submitted through `sc.addJar` and `sc.addFile`. Another solution would be to blacklist bad files and jars on Executors. Maybe fail the first time, and then ignore the second time (but print a warning message). ## How was this patch tested? Unit tests in SparkContextSuite Author: Burak Yavuz Closes #15224 from brkyvz/SPARK-17650. 
--- .../scala/org/apache/spark/SparkContext.scala | 16 ++++++++------ .../scala/org/apache/spark/util/Utils.scala | 20 +++++++++++++++++ .../org/apache/spark/SparkContextSuite.scala | 22 +++++++++++++++++++ 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f58037e100989..4694790c72cd8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.URI +import java.net.{MalformedURLException, URI} import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -36,18 +36,15 @@ import com.google.common.collect.MapMaker import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, - FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, - TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import 
org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, - WholeTextFileInputFormat} +import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec @@ -1452,6 +1449,9 @@ class SparkContext(config: SparkConf) extends Logging { throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + "turned on.") } + } else { + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) } val key = if (!isLocal && scheme == "file") { @@ -1711,6 +1711,8 @@ class SparkContext(config: SparkConf) extends Logging { key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 09896c4e2f502..e09666c6103c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -697,6 +697,26 @@ private[spark] object Utils extends Logging { } } + /** + * Validate that a given URI is actually a valid URL as well. 
+ * @param uri The URI to validate + */ + @throws[MalformedURLException]("when the URI is an invalid URL") + def validateURL(uri: URI): Unit = { + Option(uri.getScheme).getOrElse("file") match { + case "http" | "https" | "ftp" => + try { + uri.toURL + } catch { + case e: MalformedURLException => + val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.") + ex.initCause(e) + throw ex + } + case _ => // will not be turned into a URL anyway + } + } + /** * Get the path of a temporary directory. Spark's local directories can be configured through * multiple settings, which are used with the following precedence: diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index f8d143dc610cb..c451c596b069a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.MalformedURLException import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -173,6 +174,27 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-17650: malformed url's throw exceptions before bricking Executors") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + Seq("http", "https", "ftp").foreach { scheme => + val badURL = s"$scheme://user:pwd/path" + val e1 = intercept[MalformedURLException] { + sc.addFile(badURL) + } + assert(e1.getMessage.contains(badURL)) + val e2 = intercept[MalformedURLException] { + sc.addJar(badURL) + } + assert(e2.getMessage.contains(badURL)) + assert(sc.addedFiles.isEmpty) + assert(sc.addedJars.isEmpty) + } + } finally { + sc.stop() + } + } + test("addFile recursive works") { val pluto = Utils.createTempDir() val neptune = Utils.createTempDir(pluto.getAbsolutePath) From ac65139be96dbf87402b9a85729a93afd3c6ff17 Mon Sep 17 
00:00:00 2001 From: Yanbo Liang Date: Mon, 26 Sep 2016 09:45:33 +0100 Subject: [PATCH 09/96] [SPARK-17017][FOLLOW-UP][ML] Refactor of ChiSqSelector and add ML Python API. ## What changes were proposed in this pull request? #14597 modified ```ChiSqSelector``` to support the ```fpr``` selector type; however, it left some issues that need to be addressed: * We should allow users to set the selector type explicitly rather than switching it implicitly through different setter functions, since the outcome then depends on the order in which the setters are called. For example, if users set both ```numTopFeatures``` and ```percentile```, it will train a ```kbest``` or ```percentile``` model depending on the order of the calls (the one set last wins). This confuses users, so we should allow them to set the selector type explicitly. We handle similar issues in other places in the ML code base, such as ```GeneralizedLinearRegression``` and ```LogisticRegression```. * Meanwhile, if more than one parameter besides ```alpha``` can be set for the ```fpr``` model, we cannot handle it elegantly in the existing framework. The same applies to the ```kbest``` and ```percentile``` models. Setting the selector type explicitly solves this issue as well. * If users are allowed to set the selector type explicitly, we should handle param interactions: for example, if users set ```selectorType = percentile``` and ```alpha = 0.1```, we should notify them that the parameter ```alpha``` will have no effect. We should handle complex parameter interaction checks in ```transformSchema```. (FYI #11620) * We should use lower-case selector type names to follow the MLlib convention. * Add ML Python API. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #15214 from yanboliang/spark-17017. 
--- .../spark/ml/feature/ChiSqSelector.scala | 86 ++++++++++--------- .../mllib/api/python/PythonMLLibAPI.scala | 38 +++----- .../spark/mllib/feature/ChiSqSelector.scala | 51 ++++++----- .../spark/ml/feature/ChiSqSelectorSuite.scala | 27 ++++-- .../mllib/feature/ChiSqSelectorSuite.scala | 2 +- python/pyspark/ml/feature.py | 71 +++++++++++++-- python/pyspark/mllib/feature.py | 59 ++++++------- 7 files changed, 206 insertions(+), 128 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 0c6a37bab0aad..9c131a41850cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.feature.ChiSqSelectorType +import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector} import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.rdd.RDD @@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params /** * Number of features that selector will select (ordered by statistic value descending). If the * number of features is less than numTopFeatures, then this will select all features. + * Only applicable when selectorType = "kbest". * The default value of numTopFeatures is 50. + * * @group param */ final val numTopFeatures = new IntParam(this, "numTopFeatures", @@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getNumTopFeatures: Int = $(numTopFeatures) + /** + * Percentile of features that selector will select, ordered by statistics value descending. 
+ * Only applicable when selectorType = "percentile". + * Default value is 0.1. + */ final val percentile = new DoubleParam(this, "percentile", "Percentile of features that selector will select, ordered by statistics value descending.", ParamValidators.inRange(0, 1)) @@ -64,8 +71,12 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getPercentile: Double = $(percentile) - final val alpha = new DoubleParam(this, "alpha", - "The highest p-value for features to be kept.", + /** + * The highest p-value for features to be kept. + * Only applicable when selectorType = "fpr". + * Default value is 0.05. + */ + final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) setDefault(alpha -> 0.05) @@ -73,29 +84,27 @@ private[feature] trait ChiSqSelectorParams extends Params def getAlpha: Double = $(alpha) /** - * The ChiSqSelector supports KBest, Percentile, FPR selection, - * which is the same as ChiSqSelectorType defined in MLLIB. - * when call setNumTopFeatures, the selectorType is set to KBest - * when call setPercentile, the selectorType is set to Percentile - * when call setAlpha, the selectorType is set to FPR + * The selector type of the ChisqSelector. + * Supported options: "kbest" (default), "percentile" and "fpr". */ final val selectorType = new Param[String](this, "selectorType", - "ChiSqSelector Type: KBest, Percentile, FPR") - setDefault(selectorType -> ChiSqSelectorType.KBest.toString) + "The selector type of the ChisqSelector. 
" + + "Supported options: kbest (default), percentile and fpr.", + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) + setDefault(selectorType -> OldChiSqSelector.KBest) /** @group getParam */ - def getChiSqSelectorType: String = $(selectorType) + def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - * `KBest` chooses the `k` top features according to a chi-squared test. - * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `FPR` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `KBest`, the default number of top features is 50. - * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. + * `kbest` chooses the `k` top features according to a chi-squared test. + * `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `fpr` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `kbest`, the default number of top features is 50. 
*/ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -104,24 +113,21 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) + /** @group setParam */ @Since("1.6.0") - def setNumTopFeatures(value: Int): this.type = { - set(selectorType, ChiSqSelectorType.KBest.toString) - set(numTopFeatures, value) - } + def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + /** @group setParam */ @Since("2.1.0") - def setPercentile(value: Double): this.type = { - set(selectorType, ChiSqSelectorType.Percentile.toString) - set(percentile, value) - } + def setPercentile(value: Double): this.type = set(percentile, value) + /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = { - set(selectorType, ChiSqSelectorType.FPR.toString) - set(alpha, value) - } + def setAlpha(value: Double): this.type = set(alpha, value) /** @group setParam */ @Since("1.6.0") @@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str case Row(label: Double, features: Vector) => OldLabeledPoint(label, OldVectors.fromML(features)) } - var selector = new feature.ChiSqSelector() - ChiSqSelectorType.withName($(selectorType)) match { - case ChiSqSelectorType.KBest => - selector.setNumTopFeatures($(numTopFeatures)) - case ChiSqSelectorType.Percentile => - selector.setPercentile($(percentile)) - case ChiSqSelectorType.FPR => - selector.setAlpha($(alpha)) - case errorType => - throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") - } + val selector = new feature.ChiSqSelector() + .setSelectorType($(selectorType)) + .setNumTopFeatures($(numTopFeatures)) + .setPercentile($(percentile)) + .setAlpha($(alpha)) val model = selector.fit(input) 
copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { + val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) + otherPairs.foreach { case (_, paramName: String) => + if (isSet(getParam(paramName))) { + logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") + } + } SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 5cffbf0892888..904000f50d0a2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable { } /** - * Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a + * Java stub for ChiSqSelector.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. */ - def fitChiSqSelectorKBest(numTopFeatures: Int, - data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd) - } - - /** - * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. 
- */ - def fitChiSqSelectorPercentile(percentile: Double, - data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setPercentile(percentile).fit(data.rdd) - } - - /** - * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. - */ - def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setAlpha(alpha).fit(data.rdd) + def fitChiSqSelector( + selectorType: String, + numTopFeatures: Int, + percentile: Double, + alpha: Double, + data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { + new ChiSqSelector() + .setSelectorType(selectorType) + .setNumTopFeatures(numTopFeatures) + .setPercentile(percentile) + .setAlpha(alpha) + .fit(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f68a017184b21..0f7c6e8bc04bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} -@Since("2.1.0") -private[spark] object ChiSqSelectorType extends Enumeration { - type SelectorType = Value - val KBest, Percentile, FPR = Value -} - /** * Chi Squared selector model. * @@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - * `KBest` chooses the `k` top features according to a chi-squared test. 
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `FPR` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `KBest`, the default number of top features is 50. - * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. + * `kbest` chooses the `k` top features according to a chi-squared test. + * `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `fpr` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `kbest`, the default number of top features is 50. */ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 var alpha: Double = 0.05 - var selectorType = ChiSqSelectorType.KBest + var selectorType = ChiSqSelector.KBest /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -192,7 +185,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = { numTopFeatures = value - selectorType = ChiSqSelectorType.KBest this } @@ -200,7 +192,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def setPercentile(value: Double): this.type = { require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]") percentile = value - selectorType = ChiSqSelectorType.Percentile this } @@ -208,12 +199,13 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def setAlpha(value: Double): this.type = { require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") alpha = value - selectorType = ChiSqSelectorType.FPR this } @Since("2.1.0") - def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { + def setSelectorType(value: String): this.type = { + 
require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + s"ChiSqSelector Type: $value was not supported.") selectorType = value this } @@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val chiSqTestResult = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } val features = selectorType match { - case ChiSqSelectorType.KBest => chiSqTestResult + case ChiSqSelector.KBest => chiSqTestResult .take(numTopFeatures) - case ChiSqSelectorType.Percentile => chiSqTestResult + case ChiSqSelector.Percentile => chiSqTestResult .take((chiSqTestResult.length * percentile).toInt) - case ChiSqSelectorType.FPR => chiSqTestResult + case ChiSqSelector.FPR => chiSqTestResult .filter{ case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") @@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } +@Since("2.1.0") +object ChiSqSelector { + + /** String name for `kbest` selector type. */ + private[spark] val KBest: String = "kbest" + + /** String name for `percentile` selector type. */ + private[spark] val Percentile: String = "percentile" + + /** String name for `fpr` selector type. */ + private[spark] val FPR: String = "fpr" + + /** Set of selector type and param pairs that ChiSqSelector supports. */ + private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", + Percentile -> "percentile", FPR -> "alpha") + + /** Set of selector types that ChiSqSelector supports. 
*/ + private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e0293dbc4b0b2..6b56e4200250c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext .toDF("label", "data", "preFilteredData") val selector = new ChiSqSelector() + .setSelectorType("kbest") .setNumTopFeatures(1) .setFeaturesCol("data") .setLabelCol("label") @@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(vec1 ~== vec2 absTol 1e-1) } - selector.setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + + val preFilteredData2 = Seq( + Vectors.dense(8.0, 7.0), + Vectors.dense(0.0, 9.0), + Vectors.dense(0.0, 9.0), + Vectors.dense(8.0, 9.0) + ) + val df2 = sc.parallelize(data.zip(preFilteredData2)) + .map(x => (x._1.label, x._1.features, x._2)) + .toDF("label", "data", "preFilteredData") + + selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } } test("ChiSqSelector read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 
e181a544f7159..ec23a4aa7364d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData) + val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) }.collect().toSet diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c45434f1a57ca..12a13849dc9bc 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2586,39 +2586,68 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja .. versionadded:: 2.0.0 """ + selectorType = Param(Params._dummy(), "selectorType", + "The selector type of the ChisqSelector. " + + "Supported options: kbest (default), percentile and fpr.", + typeConverter=TypeConverters.toString) + numTopFeatures = \ Param(Params._dummy(), "numTopFeatures", "Number of features that selector will select, ordered by statistics value " + "descending. 
If the number of features is < numTopFeatures, then this will select " + "all features.", typeConverter=TypeConverters.toInt) + percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " + + "will select, ordered by statistics value descending.", + typeConverter=TypeConverters.toFloat) + + alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.", + typeConverter=TypeConverters.toFloat) + @keyword_only - def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05): """ - __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05) """ super(ChiSqSelector, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) - self._setDefault(numTopFeatures=50) + self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="labels"): + labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05): """ - setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ - labelCol="labels") + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05) Sets params for this ChiSqSelector. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("2.1.0") + def setSelectorType(self, value): + """ + Sets the value of :py:attr:`selectorType`. 
+ """ + return self._set(selectorType=value) + + @since("2.1.0") + def getSelectorType(self): + """ + Gets the value of selectorType or its default value. + """ + return self.getOrDefault(self.selectorType) + @since("2.0.0") def setNumTopFeatures(self, value): """ Sets the value of :py:attr:`numTopFeatures`. + Only applicable when selectorType = "kbest". """ return self._set(numTopFeatures=value) @@ -2629,6 +2658,36 @@ def getNumTopFeatures(self): """ return self.getOrDefault(self.numTopFeatures) + @since("2.1.0") + def setPercentile(self, value): + """ + Sets the value of :py:attr:`percentile`. + Only applicable when selectorType = "percentile". + """ + return self._set(percentile=value) + + @since("2.1.0") + def getPercentile(self): + """ + Gets the value of percentile or its default value. + """ + return self.getOrDefault(self.percentile) + + @since("2.1.0") + def setAlpha(self, value): + """ + Sets the value of :py:attr:`alpha`. + Only applicable when selectorType = "fpr". + """ + return self._set(alpha=value) + + @since("2.1.0") + def getAlpha(self): + """ + Gets the value of alpha or its default value. + """ + return self.getOrDefault(self.alpha) + def _create_model(self, java_model): return ChiSqSelectorModel(java_model) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 077c11370eb3f..4aea81840a162 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -271,22 +271,14 @@ def transform(self, vector): return JavaVectorTransformer.transform(self, vector) -class ChiSqSelectorType: - """ - This class defines the selector types of Chi Square Selector. - """ - KBest, Percentile, FPR = range(3) - - class ChiSqSelector(object): """ Creates a ChiSquared feature selector. The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - `KBest` chooses the `k` top features according to a chi-squared test. 
- `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - `FPR` chooses all features whose false positive rate meets some threshold. - By default, the selection method is `KBest`, the default number of top features is 50. - User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + `kbest` chooses the `k` top features according to a chi-squared test. + `percentile` is similar but chooses a fraction of all features instead of a fixed number. + `fpr` chooses all features whose false positive rate meets some threshold. + By default, the selection method is `kbest`, the default number of top features is 50. >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), @@ -299,7 +291,8 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) - >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data)) + >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( + ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) @@ -310,41 +303,52 @@ class ChiSqSelector(object): ... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]), ... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0]) ... ] - >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data)) + >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data)) >>> model.transform(DenseVector([1.0,2.0,3.0,4.0])) DenseVector([4.0]) .. 
versionadded:: 1.4.0 """ - def __init__(self, numTopFeatures=50): + def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05): self.numTopFeatures = numTopFeatures - self.selectorType = ChiSqSelectorType.KBest + self.selectorType = selectorType + self.percentile = percentile + self.alpha = alpha @since('2.1.0') def setNumTopFeatures(self, numTopFeatures): """ - set numTopFeature for feature selection by number of top features + set numTopFeature for feature selection by number of top features. + Only applicable when selectorType = "kbest". """ self.numTopFeatures = int(numTopFeatures) - self.selectorType = ChiSqSelectorType.KBest return self @since('2.1.0') def setPercentile(self, percentile): """ - set percentile [0.0, 1.0] for feature selection by percentile + set percentile [0.0, 1.0] for feature selection by percentile. + Only applicable when selectorType = "percentile". """ self.percentile = float(percentile) - self.selectorType = ChiSqSelectorType.Percentile return self @since('2.1.0') def setAlpha(self, alpha): """ - set alpha [0.0, 1.0] for feature selection by FPR + set alpha [0.0, 1.0] for feature selection by FPR. + Only applicable when selectorType = "fpr". """ self.alpha = float(alpha) - self.selectorType = ChiSqSelectorType.FPR + return self + + @since('2.1.0') + def setSelectorType(self, selectorType): + """ + set the selector type of the ChisqSelector. + Supported options: "kbest" (default), "percentile" and "fpr". + """ + self.selectorType = str(selectorType) return self @since('1.4.0') @@ -357,15 +361,8 @@ def fit(self, data): treated as categorical for each distinct value. Apply feature discretizer before using this function. 
""" - if self.selectorType == ChiSqSelectorType.KBest: - jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data) - elif self.selectorType == ChiSqSelectorType.Percentile: - jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data) - elif self.selectorType == ChiSqSelectorType.FPR: - jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data) - else: - raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and" - " FPR(2), the current value is: %s" % self.selectorType) + jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures, + self.percentile, self.alpha, data) return ChiSqSelectorModel(jmodel) From 50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77 Mon Sep 17 00:00:00 2001 From: Justin Pihony Date: Mon, 26 Sep 2016 09:54:22 +0100 Subject: [PATCH 10/96] [SPARK-14525][SQL] Make DataFrameWrite.save work for jdbc ## What changes were proposed in this pull request? This change modifies the implementation of DataFrameWriter.save such that it works with jdbc, and the call to jdbc merely delegates to save. ## How was this patch tested? This was tested via unit tests in the JDBCWriteSuite, of which I added one new test to cover this scenario. ## Additional details rxin This seems to have been most recently touched by you and was also commented on in the JIRA. This contribution is my original work and I license the work to the project under the project's open source license. Author: Justin Pihony Author: Justin Pihony Closes #12601 from JustinPihony/jdbc_reconciliation. 
--- docs/sql-programming-guide.md | 6 +- .../sql/JavaSQLDataSourceExample.java | 21 ++++ examples/src/main/python/sql/datasource.py | 19 ++++ examples/src/main/r/RSparkSQLExample.R | 4 + .../examples/sql/SQLDataSourceExample.scala | 22 +++++ .../apache/spark/sql/DataFrameWriter.scala | 59 +----------- .../datasources/jdbc/JDBCOptions.scala | 11 ++- .../jdbc/JdbcRelationProvider.scala | 95 ++++++++++++++++--- .../spark/sql/jdbc/JDBCWriteSuite.scala | 82 ++++++++++++++++ 9 files changed, 246 insertions(+), 73 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4ac5fae566abe..71bdd19c16dbb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1100,9 +1100,13 @@ CREATE TEMPORARY VIEW jdbcTable USING org.apache.spark.sql.jdbc OPTIONS ( url "jdbc:postgresql:dbserver", - dbtable "schema.tablename" + dbtable "schema.tablename", + user 'username', + password 'password' ) +INSERT INTO TABLE jdbcTable +SELECT * FROM resultTable {% endhighlight %} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index f9087e059385e..1860594e8e547 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; // $example off:schema_merging$ +import java.util.Properties; // $example on:basic_parquet_example$ import org.apache.spark.api.java.JavaRDD; @@ -235,6 +236,8 @@ private static void runJsonDatasetExample(SparkSession spark) { private static void runJdbcDatasetExample(SparkSession spark) { // $example on:jdbc_dataset$ + // Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods + // Loading data from a JDBC source Dataset jdbcDF = spark.read() .format("jdbc") .option("url", 
"jdbc:postgresql:dbserver") @@ -242,6 +245,24 @@ private static void runJdbcDatasetExample(SparkSession spark) { .option("user", "username") .option("password", "password") .load(); + + Properties connectionProperties = new Properties(); + connectionProperties.put("user", "username"); + connectionProperties.put("password", "password"); + Dataset jdbcDF2 = spark.read() + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); + + // Saving data to a JDBC source + jdbcDF.write() + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .save(); + + jdbcDF2.write() + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); // $example off:jdbc_dataset$ } } diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index b36c901d2b403..e9aa9d9ac2583 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -143,6 +143,8 @@ def json_dataset_example(spark): def jdbc_dataset_example(spark): # $example on:jdbc_dataset$ + # Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods + # Loading data from a JDBC source jdbcDF = spark.read \ .format("jdbc") \ .option("url", "jdbc:postgresql:dbserver") \ @@ -150,6 +152,23 @@ def jdbc_dataset_example(spark): .option("user", "username") \ .option("password", "password") \ .load() + + jdbcDF2 = spark.read \ + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", "password": "password"}) + + # Saving data to a JDBC source + jdbcDF.write \ + .format("jdbc") \ + .option("url", "jdbc:postgresql:dbserver") \ + .option("dbtable", "schema.tablename") \ + .option("user", "username") \ + .option("password", "password") \ + .save() + + jdbcDF2.write \ + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", 
"password": "password"}) # $example off:jdbc_dataset$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 4e0267a03851b..373a36dba14f0 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -204,7 +204,11 @@ results <- collect(sql("FROM src SELECT key, value")) # $example on:jdbc_dataset$ +# Loading data from a JDBC source df <- read.jdbc("jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password") + +# Saving data to a JDBC source +write.jdbc(df, "jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password") # $example off:jdbc_dataset$ # Stop the SparkSession now diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index dc3915a4882b0..66f7cb1b53f48 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.examples.sql +import java.util.Properties + import org.apache.spark.sql.SparkSession object SQLDataSourceExample { @@ -148,6 +150,8 @@ object SQLDataSourceExample { private def runJdbcDatasetExample(spark: SparkSession): Unit = { // $example on:jdbc_dataset$ + // Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods + // Loading data from a JDBC source val jdbcDF = spark.read .format("jdbc") .option("url", "jdbc:postgresql:dbserver") @@ -155,6 +159,24 @@ object SQLDataSourceExample { .option("user", "username") .option("password", "password") .load() + + val connectionProperties = new Properties() + connectionProperties.put("user", "username") + connectionProperties.put("password", "password") + val jdbcDF2 = spark.read + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) 
+ + // Saving data to a JDBC source + jdbcDF.write + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .save() + + jdbcDF2.write + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) // $example off:jdbc_dataset$ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 64d3422cb4b54..7374a8e045035 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -425,62 +425,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { assertNotPartitioned("jdbc") assertNotBucketed("jdbc") - - // to add required options like URL and dbtable - val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table) - val jdbcOptions = new JDBCOptions(params) - val jdbcUrl = jdbcOptions.url - val jdbcTable = jdbcOptions.table - - val props = new Properties() - extraOptions.foreach { case (key, value) => - props.put(key, value) - } // connectionProperties should override settings in extraOptions - props.putAll(connectionProperties) - val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)() - - try { - var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable) - - if (mode == SaveMode.Ignore && tableExists) { - return - } - - if (mode == SaveMode.ErrorIfExists && tableExists) { - sys.error(s"Table $jdbcTable already exists.") - } - - if (mode == SaveMode.Overwrite && tableExists) { - if (jdbcOptions.isTruncate && - JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) { - JdbcUtils.truncateTable(conn, jdbcTable) - } else { - JdbcUtils.dropTable(conn, jdbcTable) - tableExists = false - } - } - - // Create the table if the table didn't 
exist. - if (!tableExists) { - val schema = JdbcUtils.schemaString(df, jdbcUrl) - // To allow certain options to append when create a new table, which can be - // table_options or partition_options. - // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" - val createtblOptions = jdbcOptions.createTableOptions - val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions" - val statement = conn.createStatement - try { - statement.executeUpdate(sql) - } finally { - statement.close() - } - } - } finally { - conn.close() - } - - JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props) + this.extraOptions = this.extraOptions ++ (connectionProperties.asScala) + // explicit url and dbtable should override all + this.extraOptions += ("url" -> url, "dbtable" -> table) + format("jdbc").save() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 1db090eaf9c9e..bcf65e53afa73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,10 +27,12 @@ class JDBCOptions( // ------------------------------------------------------------ // Required parameters // ------------------------------------------------------------ + require(parameters.isDefinedAt("url"), "Option 'url' is required.") + require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.") // a JDBC URL - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val url = parameters("url") // name of table - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val table = parameters("dbtable") // ------------------------------------------------------------ // Optional parameter list @@ -44,6 +46,11 @@ class JDBCOptions( // the number 
of partitions val numPartitions = parameters.getOrElse("numPartitions", null) + require(partitionColumn == null || + (lowerBound != null && upperBound != null && numPartitions != null), + "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + + " and 'numPartitions' are required.") + // ------------------------------------------------------------ // The options for DataFrameWriter // ------------------------------------------------------------ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 106ed1d440102..ae04af2479c8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -19,37 +19,102 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} +import scala.collection.JavaConverters.mapAsJavaMapConverter -class JdbcRelationProvider extends RelationProvider with DataSourceRegister { +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} + +class JdbcRelationProvider extends CreatableRelationProvider + with RelationProvider with DataSourceRegister { override def shortName(): String = "jdbc" - /** Returns a new base relation with the given parameters. 
*/ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) - if (jdbcOptions.partitionColumn != null - && (jdbcOptions.lowerBound == null - || jdbcOptions.upperBound == null - || jdbcOptions.numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } + val partitionColumn = jdbcOptions.partitionColumn + val lowerBound = jdbcOptions.lowerBound + val upperBound = jdbcOptions.upperBound + val numPartitions = jdbcOptions.numPartitions - val partitionInfo = if (jdbcOptions.partitionColumn == null) { + val partitionInfo = if (partitionColumn == null) { null } else { JDBCPartitioningInfo( - jdbcOptions.partitionColumn, - jdbcOptions.lowerBound.toLong, - jdbcOptions.upperBound.toLong, - jdbcOptions.numPartitions.toInt) + partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) val properties = new Properties() // Additional properties that we will pass to getConnection parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } + + /* + * The following structure applies to this code: + * | tableExists | !tableExists + *------------------------------------------------------------------------------------ + * Ignore | BaseRelation | CreateTable, saveTable, BaseRelation + * ErrorIfExists | ERROR | CreateTable, saveTable, BaseRelation + * Overwrite* | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation + * | saveTable, BaseRelation | + * Append | saveTable, BaseRelation | CreateTable, saveTable, BaseRelation + * + * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate + */ + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val jdbcOptions = new 
JDBCOptions(parameters) + val url = jdbcOptions.url + val table = jdbcOptions.table + + val props = new Properties() + props.putAll(parameters.asJava) + val conn = JdbcUtils.createConnectionFactory(url, props)() + + try { + val tableExists = JdbcUtils.tableExists(conn, url, table) + + val (doCreate, doSave) = (mode, tableExists) match { + case (SaveMode.Ignore, true) => (false, false) + case (SaveMode.ErrorIfExists, true) => throw new AnalysisException( + s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.") + case (SaveMode.Overwrite, true) => + if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) { + JdbcUtils.truncateTable(conn, table) + (false, true) + } else { + JdbcUtils.dropTable(conn, table) + (true, true) + } + case (SaveMode.Append, true) => (false, true) + case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," + + " for handling existing tables.") + case (_, false) => (true, true) + } + + if (doCreate) { + val schema = JdbcUtils.schemaString(data, url) + // To allow certain options to append when create a new table, which can be + // table_options or partition_options. 
+ // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" + val createtblOptions = jdbcOptions.createTableOptions + val sql = s"CREATE TABLE $table ($schema) $createtblOptions" + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } + } + if (doSave) JdbcUtils.saveTable(data, url, table, props) + } finally { + conn.close() + } + + createRelation(sqlContext, parameters) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ff3309874f2e1..506971362f867 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties +import scala.collection.JavaConverters.propertiesAsScalaMapConverter + import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException @@ -208,4 +210,84 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } + + test("save works for format(\"jdbc\") if url and dbtable are set") { + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + df.write.format("jdbc") + .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST")) + .save() + + assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length) + } + + test("save API with SaveMode.Overwrite") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.format("jdbc") + .option("url", 
url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + df2.write.mode(SaveMode.Overwrite).format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length) + } + + test("save errors if url is not specified") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'url' is required")) + } + + test("save errors if dbtable is not specified") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'dbtable' is required")) + } + + test("save errors if wrong user/password combination") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[org.h2.jdbc.JdbcSQLException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .save() + }.getMessage + assert(e.contains("Wrong user name or password")) + } + + test("save errors if partitionColumn and numPartitions and bounds not set") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[java.lang.IllegalArgumentException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .option("partitionColumn", "foo") + .save() + }.getMessage + assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + + " and 'numPartitions' are required.")) + } } From f234b7cd795dd9baa3feff541c211b4daf39ccc6 Mon Sep 17 
00:00:00 2001 From: hyukjinkwon Date: Mon, 26 Sep 2016 04:19:39 -0700 Subject: [PATCH 11/96] [SPARK-16356][ML] Add testImplicits for ML unit tests and promote toDF() ## What changes were proposed in this pull request? This was suggested in https://github.com/apache/spark/commit/101663f1ae222a919fc40510aa4f2bad22d1be6f#commitcomment-17114968. This PR adds `testImplicits` to `MLlibTestSparkContext` so that some implicits such as `toDF()` can be used across ml tests. This PR also changes all the usages of `spark.createDataFrame( ... )` to `toDF()` where applicable in ml tests in Scala. ## How was this patch tested? Existing tests should work. Author: hyukjinkwon Closes #14035 from HyukjinKwon/minor-ml-test. --- .../org/apache/spark/ml/PipelineSuite.scala | 13 +- .../ml/classification/ClassifierSuite.scala | 16 +-- .../DecisionTreeClassifierSuite.scala | 3 +- .../classification/GBTClassifierSuite.scala | 6 +- .../LogisticRegressionSuite.scala | 43 +++--- .../MultilayerPerceptronClassifierSuite.scala | 26 ++-- .../ml/classification/NaiveBayesSuite.scala | 20 +-- .../ml/classification/OneVsRestSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 6 +- .../BinaryClassificationEvaluatorSuite.scala | 14 +- .../evaluation/RegressionEvaluatorSuite.scala | 8 +- .../spark/ml/feature/BinarizerSuite.scala | 16 +-- .../spark/ml/feature/BucketizerSuite.scala | 15 +-- .../spark/ml/feature/ChiSqSelectorSuite.scala | 3 +- .../ml/feature/CountVectorizerSuite.scala | 30 +++-- .../apache/spark/ml/feature/DCTSuite.scala | 10 +- .../spark/ml/feature/HashingTFSuite.scala | 10 +- .../apache/spark/ml/feature/IDFSuite.scala | 6 +- .../spark/ml/feature/InteractionSuite.scala | 53 ++++---- .../spark/ml/feature/MaxAbsScalerSuite.scala | 5 +- .../spark/ml/feature/MinMaxScalerSuite.scala | 13 +- .../apache/spark/ml/feature/NGramSuite.scala | 35 +++-- .../spark/ml/feature/NormalizerSuite.scala | 4 +- 
.../spark/ml/feature/OneHotEncoderSuite.scala | 10 +- .../apache/spark/ml/feature/PCASuite.scala | 4 +- .../ml/feature/PolynomialExpansionSuite.scala | 11 +- .../spark/ml/feature/RFormulaSuite.scala | 126 ++++++++---------- .../ml/feature/SQLTransformerSuite.scala | 8 +- .../ml/feature/StandardScalerSuite.scala | 12 +- .../ml/feature/StopWordsRemoverSuite.scala | 29 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 32 ++--- .../spark/ml/feature/TokenizerSuite.scala | 17 +-- .../ml/feature/VectorAssemblerSuite.scala | 10 +- .../spark/ml/feature/VectorIndexerSuite.scala | 15 ++- .../AFTSurvivalRegressionSuite.scala | 26 ++-- .../ml/regression/GBTRegressorSuite.scala | 7 +- .../GeneralizedLinearRegressionSuite.scala | 115 ++++++++-------- .../regression/IsotonicRegressionSuite.scala | 14 +- .../ml/regression/LinearRegressionSuite.scala | 62 ++++----- .../tree/impl/GradientBoostedTreesSuite.scala | 6 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 12 +- .../ml/tuning/TrainValidationSplitSuite.scala | 13 +- .../spark/mllib/util/MLUtilsSuite.scala | 18 +-- .../mllib/util/MLlibTestSparkContext.scala | 13 +- 45 files changed, 462 insertions(+), 460 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 3b490cdf56018..6413ca1f8b19e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -36,6 +36,8 @@ import org.apache.spark.sql.types.StructType class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + abstract class MyModel extends Model[MyModel] test("pipeline") { @@ -183,12 +185,11 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("pipeline validateParams") { - val df = spark.createDataFrame( - Seq( - (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), - (2, Vectors.dense(1.0, 0.0, 
4.0), 2.0), - (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), - (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + val df = Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "features", "label") intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 4db5f03fb00b4..de712079329da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -29,12 +29,13 @@ import org.apache.spark.sql.{DataFrame, Dataset} class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { - test("extractLabeledPoints") { - def getTestData(labels: Seq[Double]): DataFrame = { - val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - spark.createDataFrame(data) - } + import testImplicits._ + + private def getTestData(labels: Seq[Double]): DataFrame = { + labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF() + } + test("extractLabeledPoints") { val c = new MockClassifier // Valid dataset val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) @@ -70,11 +71,6 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } test("getNumClasses") { - def getTestData(labels: Seq[Double]): DataFrame = { - val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - spark.createDataFrame(data) - } - val c = new MockClassifier // Valid dataset val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 089d30abb5ef9..c711e7fa9dc67 100644 
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -34,6 +34,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs + import testImplicits._ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ @@ -345,7 +346,7 @@ class DecisionTreeClassifierSuite } test("Fitting without numClasses in metadata") { - val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val dt = new DecisionTreeClassifier().setMaxDepth(1) dt.fit(df) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 8d588ccfd3545..3492709677d4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.Utils class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ import GBTClassifierSuite.compareAPIs // Combinations for estimators, learning rates and subsamplingRate @@ -134,15 +135,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext */ test("Fitting without numClasses in metadata") { - val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) gbt.fit(df) } test("extractLabeledPoints with bad data") { def getTestData(labels: Seq[Double]): DataFrame = { - val data = 
labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - spark.createDataFrame(data) + labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF() } val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2623759f24d91..8451e60144981 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var smallBinaryDataset: Dataset[_] = _ @transient var smallMultinomialDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ @@ -46,8 +48,7 @@ class LogisticRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - smallBinaryDataset = - spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) + smallBinaryDataset = generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42).toDF() smallMultinomialDataset = { val nPoints = 100 @@ -61,7 +62,7 @@ class LogisticRegressionSuite val testData = generateMultinomialLogisticInput( coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - val df = spark.createDataFrame(sc.parallelize(testData, 4)) + val df = sc.parallelize(testData, 4).toDF() df.cache() df } @@ -76,7 +77,7 @@ class LogisticRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - spark.createDataFrame(sc.parallelize(testData, 4)) + sc.parallelize(testData, 4).toDF() } multinomialDataset = { @@ -91,7 +92,7 @@ class LogisticRegressionSuite val testData = generateMultinomialLogisticInput( 
coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - val df = spark.createDataFrame(sc.parallelize(testData, 4)) + val df = sc.parallelize(testData, 4).toDF() df.cache() df } @@ -430,10 +431,10 @@ class LogisticRegressionSuite val model = new LogisticRegressionModel("mLogReg", Matrices.dense(3, 2, Array(0.0, 0.0, 0.0, 1.0, 2.0, 3.0)), Vectors.dense(0.0, 0.0, 0.0), 3, true) - val overFlowData = spark.createDataFrame(Seq( + val overFlowData = Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) - )) + ).toDF() val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() // probabilities are correct when margins have to be adjusted @@ -1795,9 +1796,9 @@ class LogisticRegressionSuite val numPoints = 40 val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark, numClasses, numPoints) - val testData = spark.createDataFrame(Array.tabulate[LabeledPoint](numClasses) { i => + val testData = Array.tabulate[LabeledPoint](numClasses) { i => LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)) - }) + }.toSeq.toDF() val lr = new LogisticRegression().setFamily("binomial").setWeightCol("weight") val model = lr.fit(outlierData) val results = model.transform(testData).select("label", "prediction").collect() @@ -1819,9 +1820,9 @@ class LogisticRegressionSuite val numPoints = 40 val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark, numClasses, numPoints) - val testData = spark.createDataFrame(Array.tabulate[LabeledPoint](numClasses) { i => + val testData = Array.tabulate[LabeledPoint](numClasses) { i => LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)) - }) + }.toSeq.toDF() val mlr = new LogisticRegression().setFamily("multinomial").setWeightCol("weight") val model = mlr.fit(outlierData) val results = model.transform(testData).select("label", "prediction").collect() @@ -1945,11 +1946,10 @@ class LogisticRegressionSuite } 
test("multiclass logistic regression with all labels the same") { - val constantData = spark.createDataFrame(Seq( + val constantData = Seq( LabeledPoint(4.0, Vectors.dense(0.0)), LabeledPoint(4.0, Vectors.dense(1.0)), - LabeledPoint(4.0, Vectors.dense(2.0))) - ) + LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(constantData) val results = model.transform(constantData) @@ -1961,11 +1961,10 @@ class LogisticRegressionSuite } // force the model to be trained with only one class - val constantZeroData = spark.createDataFrame(Seq( + val constantZeroData = Seq( LabeledPoint(0.0, Vectors.dense(0.0)), LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0))) - ) + LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) val resultsZero = modelZeroLabel.transform(constantZeroData) resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { @@ -1990,20 +1989,18 @@ class LogisticRegressionSuite } test("compressed storage") { - val moreClassesThanFeatures = spark.createDataFrame(Seq( + val moreClassesThanFeatures = Seq( LabeledPoint(4.0, Vectors.dense(0.0, 0.0, 0.0)), LabeledPoint(4.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))) - ) + LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(moreClassesThanFeatures) assert(model.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(model.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 4) - val moreFeaturesThanClasses = spark.createDataFrame(Seq( + val moreFeaturesThanClasses = Seq( LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))) - ) + LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() val model2 = 
mlr.fit(moreFeaturesThanClasses) assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(model2.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 3) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index e809dd4092afa..c08cb695806d0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -33,16 +33,18 @@ import org.apache.spark.sql.{Dataset, Row} class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() - dataset = spark.createDataFrame(Seq( - (Vectors.dense(0.0, 0.0), 0.0), - (Vectors.dense(0.0, 1.0), 1.0), - (Vectors.dense(1.0, 0.0), 1.0), - (Vectors.dense(1.0, 1.0), 0.0)) + dataset = Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0) ).toDF("features", "label") } @@ -80,11 +82,11 @@ class MultilayerPerceptronClassifierSuite } test("Test setWeights by training restart") { - val dataFrame = spark.createDataFrame(Seq( + val dataFrame = Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), - (Vectors.dense(1.0, 1.0), 0.0)) + (Vectors.dense(1.0, 1.0), 0.0) ).toDF("features", "label") val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() @@ -114,9 +116,9 @@ class MultilayerPerceptronClassifierSuite val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) // the input seed is somewhat magic, to make this test pass - val rdd = 
sc.parallelize(generateMultinomialLogisticInput( - coefficients, xMean, xVariance, true, nPoints, 1), 2) - val dataFrame = spark.createDataFrame(rdd).toDF("label", "features") + val data = generateMultinomialLogisticInput( + coefficients, xMean, xVariance, true, nPoints, 1).toDS() + val dataFrame = data.toDF("label", "features") val numClasses = 3 val numIterations = 100 val layers = Array[Int](4, 5, 4, numClasses) @@ -137,9 +139,9 @@ class MultilayerPerceptronClassifierSuite .setNumClasses(numClasses) lr.optimizer.setRegParam(0.0) .setNumIterations(numIterations) - val lrModel = lr.run(rdd.map(OldLabeledPoint.fromML)) + val lrModel = lr.run(data.rdd.map(OldLabeledPoint.fromML)) val lrPredictionAndLabels = - lrModel.predict(rdd.map(p => OldVectors.fromML(p.features))).zip(rdd.map(_.label)) + lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 04c010bd13e1e..99099324284dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -35,6 +35,8 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { @@ -47,7 +49,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + dataset = 
generateNaiveBayesInput(pi, theta, 100, 42).toDF() } def validatePrediction(predictionAndLabels: DataFrame): Unit = { @@ -131,16 +133,16 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) - val testDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 42, "multinomial")) + val testDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 17, "multinomial")) + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) @@ -161,16 +163,16 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) - val testDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 45, "bernoulli")) + val testDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 45, "bernoulli").toDF() val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 20, "bernoulli")) + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) diff 
--git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 99dd5854ff649..3f9bcec427399 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ @transient var rdd: RDD[LabeledPoint] = _ @@ -55,7 +57,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) rdd = sc.parallelize(generateMultinomialLogisticInput( coefficients, xMean, xVariance, true, nPoints, 42), 2) - dataset = spark.createDataFrame(rdd) + dataset = rdd.toDF() } test("params") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2e99ee157ae95..44e1585ee514b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -39,6 +39,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ @@ -158,7 +159,7 @@ class RandomForestClassifierSuite } test("Fitting without numClasses in metadata") { - val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val rf = new 
RandomForestClassifier().setMaxDepth(1).setNumTrees(1) rf.fit(df) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index ddfa87555427b..3f39deddf20b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -62,6 +62,8 @@ object LDASuite { class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + val k: Int = 5 val vocabSize: Int = 30 @transient var dataset: Dataset[_] = _ @@ -140,8 +142,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead new LDA().setTopicConcentration(-1.1) } - val dummyDF = spark.createDataFrame(Seq( - (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") + val dummyDF = Seq((1, Vectors.dense(1.0, 2.0))).toDF("id", "features") + // validate parameters lda.transformSchema(dummyDF.schema) lda.setDocConcentration(1.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 9ee3df5eb5e33..ede284712b1c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class BinaryClassificationEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } @@ -42,25 +44,25 @@ class BinaryClassificationEvaluatorSuite val evaluator = new BinaryClassificationEvaluator() .setMetricName("areaUnderPR") - val vectorDF = spark.createDataFrame(Seq( + val vectorDF = Seq( (0d, 
Vectors.dense(12, 2.5)), (1d, Vectors.dense(1, 3)), (0d, Vectors.dense(10, 2)) - )).toDF("label", "rawPrediction") + ).toDF("label", "rawPrediction") assert(evaluator.evaluate(vectorDF) === 1.0) - val doubleDF = spark.createDataFrame(Seq( + val doubleDF = Seq( (0d, 0d), (1d, 1d), (0d, 0d) - )).toDF("label", "rawPrediction") + ).toDF("label", "rawPrediction") assert(evaluator.evaluate(doubleDF) === 1.0) - val stringDF = spark.createDataFrame(Seq( + val stringDF = Seq( (0d, "0d"), (1d, "1d"), (0d, "0d") - )).toDF("label", "rawPrediction") + ).toDF("label", "rawPrediction") val thrown = intercept[IllegalArgumentException] { evaluator.evaluate(stringDF) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 42ff8adf6bd65..c1a156959618e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.mllib.util.TestingUtils._ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new RegressionEvaluator) } @@ -42,9 +44,9 @@ class RegressionEvaluatorSuite * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) * .saveAsTextFile("path") */ - val dataset = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) + val dataset = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1) + .map(_.asML).toDF() /** * Using the following R code to load the data, train the model and evaluate metrics. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 9cb84a6ee9b87..4455d35210878 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.{DataFrame, Row} class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Array[Double] = _ override def beforeAll(): Unit = { @@ -39,8 +41,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame( - data.zip(defaultBinarized)).toDF("feature", "expected") + val dataFrame: DataFrame = data.zip(defaultBinarized).toSeq.toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -55,8 +56,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with setter") { val threshold: Double = 0.2 val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame( - data.zip(thresholdBinarized)).toDF("feature", "expected") + val dataFrame: DataFrame = data.zip(thresholdBinarized).toSeq.toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -71,9 +71,9 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame(Seq( + val dataFrame: DataFrame = Seq( 
(Vectors.dense(data), Vectors.dense(defaultBinarized)) - )).toDF("feature", "expected") + ).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -88,9 +88,9 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with setter") { val threshold: Double = 0.2 val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame(Seq( + val dataFrame: DataFrame = Seq( (Vectors.dense(data), Vectors.dense(defaultBinarized)) - )).toDF("feature", "expected") + ).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index c7f5093e74740..87cdceb267387 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.{DataFrame, Row} class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new Bucketizer) } @@ -38,8 +40,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val splits = Array(-0.5, 0.0, 0.5) val validData = Array(-0.5, -0.3, 0.0, 0.2) val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0) - val dataFrame: DataFrame = - spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") @@ -55,13 +56,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa // Check for exceptions when using a set of invalid 
feature values. val invalidData1: Array[Double] = Array(-0.9) ++ validData val invalidData2 = Array(0.51) ++ validData - val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") + val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx") withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF1).collect() } } - val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") + val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx") withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF2).collect() @@ -73,8 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) - val dataFrame: DataFrame = - spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") @@ -92,8 +92,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0) - val dataFrame: DataFrame = - spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6b56e4200250c..dfebfc87ea1d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -29,8 +29,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test Chi-Square selector") { - val spark = this.spark - import spark.implicits._ + import testImplicits._ val data = Seq( LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 863b66bf497fe..69d3033bb2189 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.Row class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) @@ -35,7 +37,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext private def split(s: String): Seq[String] = s.split("\\s+") test("CountVectorizerModel common cases") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("a b b c d a"), @@ -44,7 +46,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (3, split(""), Vectors.sparse(4, Seq())), // empty string (4, split("a notInDict d"), Vectors.sparse(4, Seq((0, 1.0), (3, 
1.0)))) // with words not in vocabulary - )).toDF("id", "words", "expected") + ).toDF("id", "words", "expected") val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") @@ -55,13 +57,13 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer common cases") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), - (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") @@ -76,11 +78,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer vocabSize and minDF") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), - (3, split("a"), Vectors.sparse(2, Seq((0, 1.0))))) + (3, split("a"), Vectors.sparse(2, Seq((0, 1.0)))) ).toDF("id", "words", "expected") val cvModel = new CountVectorizer() .setInputCol("words") @@ -118,9 +120,9 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a b b c c")), - (1, split("aa bb cc"))) + (1, split("aa bb cc")) ).toDF("id", "words") val cvModel = new CountVectorizer() .setInputCol("words") @@ -132,11 +134,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } 
test("CountVectorizerModel with minTF count") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq())), - (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + (3, split("e e e e e"), Vectors.sparse(4, Seq())) ).toDF("id", "words", "expected") // minTF: count @@ -151,11 +153,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF freq") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), - (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + (3, split("e e e e e"), Vectors.sparse(4, Seq())) ).toDF("id", "words", "expected") // minTF: set frequency @@ -170,12 +172,12 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel and CountVectorizer with binary") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a a a b b b b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))) - )).toDF("id", "words", "expected") + ).toDF("id", "words", "expected") // CountVectorizer test val cv = new CountVectorizer() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index c02e9610418bf..8dd3dd75e1be5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -32,6 +32,8 @@ case class DCTTestData(vec: Vector, wantedVec: Vector) class DCTSuite extends SparkFunSuite with 
MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) val inverse = false @@ -57,15 +59,13 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead private def testDCT(data: Vector, inverse: Boolean): Unit = { val expectedResultBuffer = data.toArray.clone() if (inverse) { - (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true) + new DoubleDCT_1D(data.size).inverse(expectedResultBuffer, true) } else { - (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true) + new DoubleDCT_1D(data.size).forward(expectedResultBuffer, true) } val expectedResult = Vectors.dense(expectedResultBuffer) - val dataset = spark.createDataFrame(Seq( - DCTTestData(data, expectedResult) - )) + val dataset = Seq(DCTTestData(data, expectedResult)).toDF() val transformer = new DCT() .setInputCol("vec") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 99b800776bb64..1d14866cc933b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -29,14 +29,14 @@ import org.apache.spark.util.Utils class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new HashingTF) } test("hashingTF") { - val df = spark.createDataFrame(Seq( - (0, "a a b b c d".split(" ").toSeq) - )).toDF("id", "words") + val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") val n = 100 val hashingTF = new HashingTF() .setInputCol("words") @@ -54,9 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("applying binary term freqs") { - val df 
= spark.createDataFrame(Seq( - (0, "a a b c c c".split(" ").toSeq) - )).toDF("id", "words") + val df = Seq((0, "a a b c c c".split(" ").toSeq)).toDF("id", "words") val n = 100 val hashingTF = new HashingTF() .setInputCol("words") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 09dc8b9b932fd..5325d95526a50 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.Row class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { case data: DenseVector => @@ -61,7 +63,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") @@ -87,7 +89,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 3429172a8c903..54f059e5f143e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import 
org.apache.spark.sql.functions.col class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("params") { ParamsSuite.checkParams(new Interaction()) } @@ -59,11 +62,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("numeric interaction") { - val data = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0))) - ).toDF("a", "b") + val data = Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0)) + ).toDF("a", "b") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -74,11 +76,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) - ).toDF("a", "b", "features") + val expected = Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "features") assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( @@ -90,11 +91,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("nominal interaction") { - val data = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0))) - ).toDF("a", "b") + val data = Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0)) + ).toDF("a", "b") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -106,11 +106,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new 
Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) - ).toDF("a", "b", "features") + val expected = Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "features") assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( @@ -126,10 +125,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("default attr names") { - val data = spark.createDataFrame( - Seq( + val data = Seq( (2, Vectors.dense(0.0, 4.0), 1.0), - (1, Vectors.dense(1.0, 5.0), 10.0)) + (1, Vectors.dense(1.0, 5.0), 10.0) ).toDF("a", "b", "c") val groupAttr = new AttributeGroup( "b", @@ -142,11 +140,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") val res = trans.transform(df) - val expected = spark.createDataFrame( - Seq( - (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), - (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) - ).toDF("a", "b", "c", "features") + val expected = Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)) + ).toDF("a", "b", "c", "features") assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index d6400ee02f951..a12174493b867 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -23,6 +23,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("MaxAbsScaler fit basic case") { val data = Array( Vectors.dense(1, 0, 100), @@ -36,7 +39,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(-1, -1)), Vectors.sparse(3, Array(0), Array(-0.75))) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val scaler = new MaxAbsScaler() .setInputCol("features") .setOutputCol("scaled") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 9f376b70035c5..b79eeb2d75ef0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.Row class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("MinMaxScaler fit basic case") { val data = Array( Vectors.dense(1, 0, Long.MinValue), @@ -38,7 +40,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(5, 5)), Vectors.sparse(3, Array(0), Array(-2.5))) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val scaler = new 
MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") @@ -57,14 +59,13 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De test("MinMaxScaler arguments max must be larger than min") { withClue("arguments max must be larger than min") { - val dummyDF = spark.createDataFrame(Seq( - (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") + val dummyDF = Seq((1, Vectors.dense(1.0, 2.0))).toDF("id", "features") intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") + val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("features") scaler.transformSchema(dummyDF.schema) } intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature") + val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("features") scaler.transformSchema(dummyDF.schema) } } @@ -104,7 +105,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.dense(-1.0, Double.NaN, -5.0, -5.0), Vectors.dense(5.0, 0.0, 5.0, Double.NaN)) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index e5288d9259d3c..d4975c0b4e20e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -28,17 +28,18 @@ import org.apache.spark.sql.{Dataset, Row} case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.NGramSuite._ + import testImplicits._ test("default 
behavior yields bigram features") { val nGram = new NGram() .setInputCol("inputTokens") .setOutputCol("nGrams") - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array("Test", "for", "ngram", "."), - Array("Test for", "for ngram", "ngram .") - ))) + val dataset = Seq(NGramTestData( + Array("Test", "for", "ngram", "."), + Array("Test for", "for ngram", "ngram .") + )).toDF() testNGram(nGram, dataset) } @@ -47,11 +48,10 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array("a", "b", "c", "d", "e"), - Array("a b c d", "b c d e") - ))) + val dataset = Seq(NGramTestData( + Array("a", "b", "c", "d", "e"), + Array("a b c d", "b c d e") + )).toDF() testNGram(nGram, dataset) } @@ -60,11 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array(), - Array() - ))) + val dataset = Seq(NGramTestData(Array(), Array())).toDF() testNGram(nGram, dataset) } @@ -73,11 +69,10 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(6) - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array("a", "b", "c", "d", "e"), - Array() - ))) + val dataset = Seq(NGramTestData( + Array("a", "b", "c", "d", "e"), + Array() + )).toDF() testNGram(nGram, dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index b692831714466..c75027fb4553d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.{DataFrame, Row} class 
NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ @transient var normalizer: Normalizer = _ @@ -61,7 +63,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.sparse(3, Seq()) ) - dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index d41eeec1329c5..c44c6813a94be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.types._ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + def stringIndexed(): DataFrame = { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -83,7 +85,7 @@ class OneHotEncoderSuite test("input column with ML attribute") { val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", attr.toMetadata())) val encoder = new OneHotEncoder() .setInputCol("size") @@ -96,7 +98,7 @@ 
class OneHotEncoderSuite } test("input column without ML attribute") { - val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index ddb51fb1706a7..a60e87590f060 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.Row class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] @@ -50,7 +52,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pc = mat.computePrincipalComponents(3) val expected = mat.multiply(pc).rows.map(_.asML) - val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + val df = dataRDD.zip(expected).toDF("features", "expected") val pca = new PCA() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 9ecd321b128f6..e4b0ddf98bfad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.Row class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new PolynomialExpansion) } @@ -59,7 +61,7 @@ class 
PolynomialExpansionSuite Vectors.sparse(19, Array.empty, Array.empty)) test("Polynomial expansion with default parameter") { - val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -76,7 +78,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with setter") { - val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + val df = data.zip(threeDegreeExpansion).toSeq.toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -94,7 +96,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with degree 1 is identity on vectors") { - val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected") + val df = data.zip(data).toSeq.toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -124,8 +126,7 @@ class PolynomialExpansionSuite (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375) ) - val df = spark.createDataFrame(data) - .toDF("features", "expectedPoly10size", "expectedPoly11size") + val df = data.toSeq.toDF("features", "expectedPoly10size", "expectedPoly11size") val t = new PolynomialExpansion() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 0794a049d9cd8..97c268f3d5c97 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -26,22 +26,23 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.DoubleType class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + 
test("params") { ParamsSuite.checkParams(new RFormula()) } test("transform numeric data") { val formula = new RFormula().setFormula("id ~ v1 + v2") - val original = spark.createDataFrame( - Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq( - (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), - (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) - ).toDF("id", "v1", "v2", "features", "label") + val expected = Seq( + (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), + (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) + ).toDF("id", "v1", "v2", "features", "label") // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) @@ -50,7 +51,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") - val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") intercept[IllegalArgumentException] { formula.fit(original) } @@ -58,7 +59,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -67,7 +68,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists but is not numeric 
type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") + val original = Seq((0, true), (2, false)).toDF("x", "y") val model = formula.fit(original) intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -79,7 +80,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") - val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "_not_y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -88,37 +89,32 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("allow empty label") { - val original = spark.createDataFrame( - Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)) - ).toDF("id", "a", "b") + val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq( - (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), - (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), - (7, 8.0, 9.0, Vectors.dense(8.0, 9.0))) - ).toDF("id", "a", "b", "features") + val expected = Seq( + (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), + (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), + (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) + ).toDF("id", "a", "b", "features") assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } test("encodes string terms") { val formula = new RFormula().setFormula("id ~ a + b") - val original = spark.createDataFrame( - Seq((1, "foo", 4), (2, "bar", 4), 
(3, "bar", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq( + val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), - (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -126,17 +122,16 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") - val original = spark.createDataFrame( + val original = Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) - ).toDF("id", "a", "b") + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( + val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), - ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)) + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -144,9 +139,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") - val original = spark.createDataFrame( - Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, 
"baz", 5)) - ).toDF("id", "a", "b") + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val attrs = AttributeGroup.fromStructField(result.schema("features")) @@ -161,9 +155,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation") { val formula = new RFormula().setFormula("id ~ vec") - val original = spark.createDataFrame( - Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) - ).toDF("id", "vec") + val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + .toDF("id", "vec") val model = formula.fit(original) val result = model.transform(original) val attrs = AttributeGroup.fromStructField(result.schema("features")) @@ -177,9 +170,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation with unnamed input attrs") { val formula = new RFormula().setFormula("id ~ vec2") - val base = spark.createDataFrame( - Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) - ).toDF("id", "vec") + val base = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + .toDF("id", "vec") val metadata = new AttributeGroup( "vec2", Array[Attribute]( @@ -199,16 +191,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") - val original = spark.createDataFrame( - Seq((1, 2, 4, 2), (2, 3, 4, 1)) - ).toDF("a", "b", "c", "d") + val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( - (1, 2, 4, 2, Vectors.dense(16.0), 1.0), - (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) - ).toDF("a", "b", "c", "d", "features", "label") + val expected = Seq( + (1, 2, 4, 
2, Vectors.dense(16.0), 1.0), + (2, 3, 4, 1, Vectors.dense(12.0), 2.0) + ).toDF("a", "b", "c", "d", "features", "label") assert(result.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( @@ -219,20 +208,19 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor numeric interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = spark.createDataFrame( + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( - (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), - (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), - (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) - ).toDF("id", "a", "b", "features", "label") + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) + ).toDF("id", "a", "b", "features", "label") assert(result.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( @@ -246,17 +234,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor factor interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = spark.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", 
"b") + val original = + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( - (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), - (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), - (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)) - ).toDF("id", "a", "b", "features", "label") + val expected = Seq( + (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), + (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) + ).toDF("id", "a", "b", "features", "label") assert(result.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( @@ -295,9 +281,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } - val dataset = spark.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", "b") + val dataset = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val rFormula = new RFormula().setFormula("id ~ a:b") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 1401ea9c4b431..23464073e6edb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -26,19 +26,19 @@ import org.apache.spark.sql.types.{LongType, StructField, StructType} class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new SQLTransformer()) } test("transform numeric data") { - val original = spark.createDataFrame( - Seq((0, 1.0, 3.0), (2, 2.0, 
5.0))).toDF("id", "v1", "v2") + val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") val result = sqlTrans.transform(original) val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 827ecb0fadbee..a928f93633011 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, Row} class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @transient var resWithMean: Array[Vector] = _ @@ -73,7 +75,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with default parameter") { - val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected") val standardScaler0 = new StandardScaler() .setInputCol("features") @@ -84,9 +86,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with setter") { - val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") - val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") - val df3 = 
spark.createDataFrame(data.zip(data)).toDF("features", "expected") + val df1 = data.zip(resWithBoth).toSeq.toDF("features", "expected") + val df2 = data.zip(resWithMean).toSeq.toDF("features", "expected") + val df3 = data.zip(data).toSeq.toDF("features", "expected") val standardScaler1 = new StandardScaler() .setInputCol("features") @@ -120,7 +122,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext Vectors.sparse(3, Array(1, 2), Array(-5.1, 1.0)), Vectors.dense(1.7, -0.6, 3.3) ) - val df = spark.createDataFrame(someSparseData.zip(resWithMean)).toDF("features", "expected") + val df = someSparseData.zip(resWithMean).toSeq.toDF("features", "expected") val standardScaler = new StandardScaler() .setInputCol("features") .setOutputCol("standardized_features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 125ad02ebcc02..957cf58a68f85 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -37,19 +37,20 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import StopWordsRemoverSuite._ + import testImplicits._ test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("test", "test"), Seq("test", "test")), (Seq("a", "b", "c", "d"), Seq("b", "c")), (Seq("a", "the", "an"), Seq()), (Seq("A", "The", "AN"), Seq()), (Seq(null), Seq(null)), (Seq(), Seq()) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -60,14 +61,14 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = spark.createDataFrame(Seq( + val dataSet = 
Seq( (Seq("test", "test"), Seq()), (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), (Seq("a", "the", "an"), Seq()), (Seq("A", "The", "AN"), Seq()), (Seq(null), Seq(null)), (Seq(), Seq()) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -77,10 +78,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setCaseSensitive(true) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("A"), Seq("A")), (Seq("The", "the"), Seq("The")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -98,10 +99,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("acaba", "ama", "biri"), Seq()), (Seq("hep", "her", "scala"), Seq("scala")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -112,10 +113,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords.toArray) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("python", "scala", "a"), Seq("python", "scala", "a")), (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -126,10 +127,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords.toArray) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("python", "scala", "a"), Seq()), (Seq("Python", "Scala", "swift"), Seq("swift")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -148,9 +149,7 @@ class StopWordsRemoverSuite val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol(outputCol) - val dataSet = spark.createDataFrame(Seq( - (Seq("The", "the", "swift"), Seq("swift")) - 
)).toDF("raw", outputCol) + val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) val thrown = intercept[IllegalArgumentException] { testStopWordsRemover(remover, dataSet) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b478fea5e74ec..a6bbb944a1bd7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructTy class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) @@ -38,8 +40,8 @@ class StringIndexerSuite } test("StringIndexer") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -61,10 +63,10 @@ class StringIndexerSuite } test("StringIndexerUnseen") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) - val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") - val df2 = spark.createDataFrame(data2).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (4, "b")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -92,8 +94,8 @@ class StringIndexerSuite } test("StringIndexer with a 
numeric input column") { - val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -119,7 +121,7 @@ class StringIndexerSuite } test("StringIndexerModel can't overwrite output column") { - val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val df = Seq((1, 2), (3, 4)).toDF("input", "output") intercept[IllegalArgumentException] { new StringIndexer() .setInputCol("input") @@ -161,9 +163,7 @@ class StringIndexerSuite test("IndexToString.transform") { val labels = Array("a", "b", "c") - val df0 = spark.createDataFrame(Seq( - (0, "a"), (1, "b"), (2, "c"), (0, "a") - )).toDF("index", "expected") + val df0 = Seq((0, "a"), (1, "b"), (2, "c"), (0, "a")).toDF("index", "expected") val idxToStr0 = new IndexToString() .setInputCol("index") @@ -187,8 +187,8 @@ class StringIndexerSuite } test("StringIndexer, IndexToString are inverses") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -220,8 +220,8 @@ class StringIndexerSuite } test("StringIndexer metadata") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index f30bdc3ddc0d7..c895659a2d8be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -46,6 +46,7 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import org.apache.spark.ml.feature.RegexTokenizerSuite._ + import testImplicits._ test("params") { ParamsSuite.checkParams(new RegexTokenizer) @@ -57,26 +58,26 @@ class RegexTokenizerSuite .setPattern("\\w+|\\p{Punct}") .setInputCol("rawText") .setOutputCol("tokens") - val dataset0 = spark.createDataFrame(Seq( + val dataset0 = Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) - )) + ).toDF() testRegexTokenizer(tokenizer0, dataset0) - val dataset1 = spark.createDataFrame(Seq( + val dataset1 = Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) - )) + ).toDF() tokenizer0.setMinTokenLength(3) testRegexTokenizer(tokenizer0, dataset1) val tokenizer2 = new RegexTokenizer() .setInputCol("rawText") .setOutputCol("tokens") - val dataset2 = spark.createDataFrame(Seq( + val dataset2 = Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), TokenizerTestData("Te,st. 
punct", Array("te,st.", "punct")) - )) + ).toDF() testRegexTokenizer(tokenizer2, dataset2) } @@ -85,10 +86,10 @@ class RegexTokenizerSuite .setInputCol("rawText") .setOutputCol("tokens") .setToLowercase(false) - val dataset = spark.createDataFrame(Seq( + val dataset = Seq( TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), TokenizerTestData("java scala", Array("java", "scala")) - )) + ).toDF() testRegexTokenizer(tokenizer, dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 561493fbafd6c..46cced3a9a6e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.functions.col class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new VectorAssembler) } @@ -57,9 +59,9 @@ class VectorAssemblerSuite } test("VectorAssembler") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) - )).toDF("id", "x", "y", "name", "z", "n") + ).toDF("id", "x", "y", "name", "z", "n") val assembler = new VectorAssembler() .setInputCols(Array("x", "y", "z", "n")) .setOutputCol("features") @@ -70,7 +72,7 @@ class VectorAssemblerSuite } test("transform should throw an exception in case of unsupported type") { - val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val df = Seq(("a", "b", "c")).toDF("a", "b", "c") val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") @@ -87,7 +89,7 @@ class VectorAssemblerSuite NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), NumericAttribute.defaultAttr.withName("salary"))) 
val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) - val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + val df = Seq(row).toDF("browser", "hour", "count", "user", "ad") .select( col("browser").as("browser", browser.toMetadata()), col("hour").as("hour", hour.toMetadata()), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 707142332349c..4da1b133e8cd5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.DataFrame class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { + import testImplicits._ import VectorIndexerSuite.FeatureData // identical, of length 3 @@ -85,11 +86,13 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) - sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) - densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) - sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) - badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + densePoints1 = densePoints1Seq.map(FeatureData).toDF() + sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() + // TODO: If we directly use `toDF` without parallelize, the test in + // "Throws error when given RDDs with different size vectors" is failed for an unknown reason. 
+ densePoints2 = sc.parallelize(densePoints2Seq, 2).map(FeatureData).toDF() + sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() + badPoints = badPointsSeq.map(FeatureData).toDF() } private def getIndexer: VectorIndexer = @@ -102,7 +105,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Cannot fit an empty DataFrame") { - val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val rdd = Array.empty[Vector].map(FeatureData).toSeq.toDF() val vectorIndexer = getIndexer intercept[IllegalArgumentException] { vectorIndexer.fit(rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 1c70b702de063..0fdfdf37cf38d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -31,23 +31,22 @@ import org.apache.spark.sql.{DataFrame, Row} class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() - datasetUnivariate = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) - datasetMultivariate = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) - datasetUnivariateScaled = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => - AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) - }) + datasetUnivariate = generateAFTInput( + 1, Array(5.5), 
Array(0.8), 1000, 42, 1.0, 2.0, 2.0).toDF() + datasetMultivariate = generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0).toDF() + datasetUnivariateScaled = sc.parallelize( + generateAFTInput(1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }.toDF() } /** @@ -396,9 +395,8 @@ class AFTSurvivalRegressionSuite // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s // being merged incorrectly when it has an empty partition, running the codes below // should not throw an exception. - val dataset = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3)) + val dataset = sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3).toDF() val trainer = new AFTSurvivalRegression() trainer.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 7b5df8f31bb38..dcf3f9a1ea9b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -37,6 +37,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs + import testImplicits._ // Combinations for estimators, learning rates and subsamplingRate private val testCombinations = @@ -76,14 +77,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext } test("GBTRegressor behaves reasonably on toy data") { - val df = spark.createDataFrame(Seq( + val df = Seq( LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), 
LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) - )) + ).toDF() val gbt = new GBTRegressor() .setMaxDepth(2) .setMaxIter(2) @@ -103,7 +104,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString sc.setCheckpointDir(path) - val df = spark.createDataFrame(data) + val df = data.toDF() val gbt = new GBTRegressor() .setMaxDepth(2) .setMaxIter(5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d8032c4e1705b..937aa7d3c2045 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -35,6 +35,8 @@ import org.apache.spark.sql.functions._ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + private val seed: Int = 42 @transient var datasetGaussianIdentity: DataFrame = _ @transient var datasetGaussianLog: DataFrame = _ @@ -52,23 +54,20 @@ class GeneralizedLinearRegressionSuite import GeneralizedLinearRegressionSuite._ - datasetGaussianIdentity = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gaussian", link = "identity"), 2)) + datasetGaussianIdentity = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity").toDF() - datasetGaussianLog = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 
0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gaussian", link = "log"), 2)) + datasetGaussianLog = generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log").toDF() - datasetGaussianInverse = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gaussian", link = "inverse"), 2)) + datasetGaussianInverse = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse").toDF() datasetBinomial = { val nPoints = 10000 @@ -80,44 +79,38 @@ class GeneralizedLinearRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - spark.createDataFrame(sc.parallelize(testData, 2)) + testData.toDF() } - datasetPoissonLog = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "poisson", link = "log"), 2)) - - datasetPoissonIdentity = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "poisson", link = "identity"), 2)) - - datasetPoissonSqrt = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( 
- intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "poisson", link = "sqrt"), 2)) - - datasetGammaInverse = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gamma", link = "inverse"), 2)) - - datasetGammaIdentity = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gamma", link = "identity"), 2)) - - datasetGammaLog = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gamma", link = "log"), 2)) + datasetPoissonLog = generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "log").toDF() + + datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity").toDF() + + datasetPoissonSqrt = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "sqrt").toDF() + + datasetGammaInverse = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 
0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse").toDF() + + datasetGammaIdentity = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity").toDF() + + datasetGammaLog = generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log").toDF() } /** @@ -540,12 +533,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + ).toDF() /* R code: @@ -668,12 +661,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) - ), 2)) + ).toDF() /* R code: @@ -782,12 +775,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + ).toDF() /* R code: @@ -899,12 
+892,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + ).toDF() /* R code: @@ -1054,12 +1047,12 @@ class GeneralizedLinearRegressionSuite [1] 12.92681 [1] 13.32836 */ - val dataset = spark.createDataFrame(Seq( + val dataset = Seq( LabeledPoint(1, Vectors.dense(5, 0)), LabeledPoint(0, Vectors.dense(2, 1)), LabeledPoint(1, Vectors.dense(1, 2)), LabeledPoint(0, Vectors.dense(3, 3)) - )) + ).toDF() val expected = Seq(12.88188, 12.92681, 13.32836) var idx = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 14d8a4e4e3345..c2c79476e8b2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -27,15 +27,15 @@ import org.apache.spark.sql.{DataFrame, Row} class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - spark.createDataFrame( - labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } - ).toDF("label", "features", "weight") + labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } + .toDF("label", "features", "weight") } private def generatePredictionInput(features: Seq[Double]): DataFrame = { - spark.createDataFrame(features.map(Tuple1.apply)) - .toDF("features") + features.map(Tuple1.apply).toDF("features") } test("isotonic regression predictions") { @@ -145,10 +145,10 @@ class 
IsotonicRegressionSuite } test("vector features column with feature index") { - val dataset = spark.createDataFrame(Seq( + val dataset = Seq( (4.0, Vectors.dense(0.0, 1.0)), (3.0, Vectors.dense(0.0, 2.0)), - (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) + (5.0, Vectors.sparse(2, Array(1), Array(3.0))) ).toDF("label", "features") val ir = new IsotonicRegression() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 265f2f45c45fe..5ae371b489aa5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @@ -42,29 +44,27 @@ class LinearRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - datasetWithDenseFeature = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML)) + datasetWithDenseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() /* datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing but is useful for illustrating training model without intercept */ - datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( + 
datasetWithDenseFeatureWithoutIntercept = sc.parallelize( + LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML)) + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() val r = new Random(seed) // When feature size is larger than 4096, normal optimizer is choosed // as the solver of linear regression in the case of "auto" mode. val featureSize = 4100 - datasetWithSparseFeature = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( + datasetWithSparseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, - seed, eps = 0.1, sparsity = 0.7), 2).map(_.asML)) + seed, eps = 0.1, sparsity = 0.7), 2).map(_.asML).toDF() /* R code: @@ -74,13 +74,12 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - datasetWithWeight = spark.createDataFrame( - sc.parallelize(Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + datasetWithWeight = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2).toDF() /* R code: @@ -90,20 +89,18 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df.const.label <- as.data.frame(cbind(A, b.const)) */ - datasetWithWeightConstantLabel = spark.createDataFrame( - sc.parallelize(Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(17.0, 2.0, Vectors.dense(1.0, 
7.0)), - Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) - datasetWithWeightZeroLabel = spark.createDataFrame( - sc.parallelize(Seq( - Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + datasetWithWeightConstantLabel = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2).toDF() + datasetWithWeightZeroLabel = sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2).toDF() } /** @@ -839,8 +836,7 @@ class LinearRegressionSuite } val data2 = weightedSignedData ++ weightedNoiseData - (spark.createDataFrame(sc.parallelize(data1, 4)), - spark.createDataFrame(sc.parallelize(data2, 4))) + (sc.parallelize(data1, 4).toDF(), sc.parallelize(data2, 4).toDF()) } val trainer1a = (new LinearRegression).setFitIntercept(true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index 5c50a88c8314a..4109a299091dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -32,13 +32,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext */ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + import testImplicits._ + test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so 
that it stops early. val numIterations = 20 val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML) val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML) - val trainDF = spark.createDataFrame(trainRdd) - val validateDF = spark.createDataFrame(validateRdd) + val trainDF = trainRdd.toDF() + val validateDF = validateRdd.toDF() val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 750dc5bf01e6a..7116265474f22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.types.StructType class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() - dataset = spark.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF() } test("cross validation with logistic regression") { @@ -67,9 +68,10 @@ class CrossValidatorSuite } test("cross validation with linear regression") { - val dataset = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) + val dataset = sc.parallelize( + LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2) + .map(_.asML).toDF() val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9971371e47288..87100ae2e342f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -33,9 +33,11 @@ import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("train validation with logistic regression") { - val dataset = spark.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF() val lr = new LogisticRegression val lrParamMaps = new ParamGridBuilder() @@ -58,9 +60,10 @@ class TrainValidationSplitSuite } test("train validation with linear regression") { - val dataset = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) + val dataset = sc.parallelize( + LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2) + .map(_.asML).toDF() val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 6aa93c9076007..e4e9be39ff6f9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { + import testImplicits._ + test("epsilon computation") { assert(1.0 + EPSILON > 
1.0, s"EPSILON is too small: $EPSILON.") assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.") @@ -255,9 +257,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Vectors.dense(4.0) val p = (5.0, z) val w = Vectors.dense(6.0).asML - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertVectorColumnsToML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") @@ -282,9 +282,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Vectors.dense(4.0).asML val p = (5.0, z) val w = Vectors.dense(6.0) - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertVectorColumnsFromML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") @@ -309,9 +307,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Matrices.ones(1, 1) val p = (5.0, z) val w = Matrices.dense(1, 1, Array(4.5)).asML - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertMatrixColumnsToML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") @@ -336,9 +332,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Matrices.ones(1, 1).asML val p = (5.0, z) val w = Matrices.dense(1, 1, Array(4.5)) - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertMatrixColumnsFromML(df) 
assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index db56aff63102c..6bb7ed9c9513c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -23,7 +23,7 @@ import org.scalatest.Suite import org.apache.spark.SparkContext import org.apache.spark.ml.util.TempDirectory -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits} import org.apache.spark.util.Utils trait MLlibTestSparkContext extends TempDirectory { self: Suite => @@ -55,4 +55,15 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => super.afterAll() } } + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } } From bde85f8b70138a51052b613664facbc981378c38 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 26 Sep 2016 10:44:35 -0700 Subject: [PATCH 12/96] [SPARK-17649][CORE] Log how many Spark events got dropped in LiveListenerBus ## What changes were proposed in this pull request? Log how many Spark events got dropped in LiveListenerBus so that the user can get insights on how to set a correct event queue size. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15220 from zsxwing/SPARK-17649. 
--- .../spark/scheduler/LiveListenerBus.scala | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index bfa3c408f2284..5533f7b1f2363 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.util.DynamicVariable @@ -57,6 +57,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa // Indicate if `stop()` is called private val stopped = new AtomicBoolean(false) + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + // Indicate if we are processing some event // Guarded by `self` private var processingEvent = false @@ -123,6 +129,24 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa eventLock.release() } else { onDropEvent(event) + droppedEventsCounter.incrementAndGet() + } + + val droppedEvents = droppedEventsCounter.get + if (droppedEvents > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. 
+ if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + + new java.util.Date(prevLastReportTimestamp)) + } + } } } From 8135e0e5ebdb9c7f5ac41c675dc8979a5127a31a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 26 Sep 2016 13:07:11 -0700 Subject: [PATCH 13/96] [SPARK-17153][SQL] Should read partition data when reading new files in filestream without globbing ## What changes were proposed in this pull request? When reading file stream with non-globbing path, the results return data with all `null`s for the partitioned columns. E.g., case class A(id: Int, value: Int) val data = spark.createDataset(Seq( A(1, 1), A(2, 2), A(2, 3)) ) val url = "/tmp/test" data.write.partitionBy("id").parquet(url) spark.read.parquet(url).show +-----+---+ |value| id| +-----+---+ | 2| 2| | 3| 2| | 1| 1| +-----+---+ val s = spark.readStream.schema(spark.read.load(url).schema).parquet(url) s.writeStream.queryName("test").format("memory").start() sql("SELECT * FROM test").show +-----+----+ |value| id| +-----+----+ | 2|null| | 3|null| | 1|null| +-----+----+ ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #14803 from viirya/filestreamsource-option. 
--- .../structured-streaming-programming-guide.md | 6 ++ .../execution/datasources/DataSource.scala | 7 +- .../streaming/FileStreamSource.scala | 9 +- .../sql/streaming/FileStreamSourceSuite.scala | 83 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 8 ++ 5 files changed, 110 insertions(+), 3 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index c7ed3b04bced1..2e6df94823d38 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -512,6 +512,12 @@ csvDF = spark \ These examples generate streaming DataFrames that are untyped, meaning that the schema of the DataFrame is not checked at compile time, only checked at runtime when the query is submitted. Some operations like `map`, `flatMap`, etc. need the type to be known at compile time. To do those, you can convert these untyped streaming DataFrames to typed streaming Datasets using the same methods as static DataFrame. See the [SQL Programming Guide](sql-programming-guide.html) for more details. Additionally, more details on the supported streaming sources are discussed later in the document. +### Schema inference and partition of streaming DataFrames/Datasets + +By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. + +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. 
The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). + ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 32067011c3dff..e75e7d2770b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -197,10 +197,15 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None) - format.inferSchema( + val partitionCols = fileCatalog.partitionSpec().partitionColumns.fields + val inferred = format.inferSchema( sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) + + inferred.map { inferredSchema => + StructType(inferredSchema ++ partitionCols) + } }.getOrElse { throw new AnalysisException("Unable to infer schema. 
It must be specified manually.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index be023273db2f2..614a6261e7c28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -47,6 +47,13 @@ class FileStreamSource( fs.makeQualified(new Path(path)) // can contains glob patterns } + private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ { + if (!SparkHadoopUtil.get.isGlobPath(new Path(path)) && options.contains("path")) { + Map("basePath" -> path) + } else { + Map() + }} + private val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, sparkSession, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) @@ -136,7 +143,7 @@ class FileStreamSource( paths = files.map(_.path), userSpecifiedSchema = Some(schema), className = fileFormatClassName, - options = sourceOptions.optionMapWithoutPath) + options = optionsWithPartitionBasePath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( checkFilesExist = false))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 55c95ae285c1b..3157afe5a56c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -102,6 +102,12 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext with Private } } + case class DeleteFile(file: File) extends ExternalAction { + def runAction(): Unit = { + Utils.deleteRecursively(file) + } + } + /** Use `format` and `path` to create FileStreamSource via 
DataFrameReader */ def createFileStream( format: String, @@ -608,6 +614,81 @@ class FileStreamSourceSuite extends FileStreamSourceTest { // =============== other tests ================ + test("read new files in partitioned table without globbing, should read partition data") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + val schema = new StructType().add("value", StringType).add("partition", StringType) + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}", Some(schema)) + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Create new partition=foo sub dir and write to it + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")) + ) + } + } + + test("when schema inference is turned on, should read partition data") { + def createFile(content: String, src: File, tmp: File): Unit = { + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + src.mkdirs() + require(stringToFile(tempFile, content).renameTo(finalFile)) + } + + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + // Create file in 
partition, so we can infer the schema. + createFile("{'value': 'drop0'}", partitionFooSubDir, tmp) + + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}") + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")), + + // Delete the two partition dirs + DeleteFile(partitionFooSubDir), + DeleteFile(partitionBarSubDir), + + AddTextFileData("{'value': 'keep6'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar"), + ("keep6", "bar")) + ) + } + } + } + test("fault tolerance") { withTempDirs { case (src, tmp) => val fileStream = createFileStream("text", src.getCanonicalPath) @@ -792,7 +873,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } assert(src.listFiles().size === numFiles) - val files = spark.readStream.text(root.getCanonicalPath).as[String] + val files = spark.readStream.text(root.getCanonicalPath).as[(String, Int)] // Note this query will use constant folding to eliminate the file scan. 
// This is to avoid actually running a Spark job with 10000 tasks diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 6c5b170d9c7c3..aa6515bc7a909 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -95,6 +95,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def addData(query: Option[StreamExecution]): (Source, Offset) } + /** A trait that can be extended when testing a source. */ + trait ExternalAction extends StreamAction { + def runAction(): Unit + } + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" @@ -429,6 +434,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { failTest("Error adding data", e) } + case e: ExternalAction => + e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => verify(currentStream != null, "stream not running") // Get the map of source index to the current source objects From 7c7586aef9243081d02ea5065435234b5950ab66 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 26 Sep 2016 13:21:08 -0700 Subject: [PATCH 14/96] [SPARK-17652] Fix confusing exception message while reserving capacity ## What changes were proposed in this pull request? This minor patch fixes a confusing exception message while reserving additional capacity in the vectorized parquet reader. ## How was this patch tested? Existing Unit Tests Author: Sameer Agarwal Closes #15225 from sameeragarwal/error-msg. 
--- .../sql/execution/vectorized/ColumnVector.java | 14 +++++++------- .../execution/vectorized/ColumnarBatchSuite.scala | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a7cb3b11f687a..ff07940422a0b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -285,19 +285,19 @@ public void reserve(int requiredCapacity) { try { reserveInternal(newCapacity); } catch (OutOfMemoryError outOfMemoryError) { - throwUnsupportedException(newCapacity, requiredCapacity, outOfMemoryError); + throwUnsupportedException(requiredCapacity, outOfMemoryError); } } else { - throwUnsupportedException(newCapacity, requiredCapacity, null); + throwUnsupportedException(requiredCapacity, null); } } } - private void throwUnsupportedException(int newCapacity, int requiredCapacity, Throwable cause) { - String message = "Cannot reserve more than " + newCapacity + - " bytes in the vectorized reader (requested = " + requiredCapacity + " bytes). As a" + - " workaround, you can disable the vectorized reader by setting " - + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " to false."; + private void throwUnsupportedException(int requiredCapacity, Throwable cause) { + String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + + "(requested = " + requiredCapacity + " bytes). 
As a workaround, you can disable the " + + "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + + " to false."; if (cause != null) { throw new RuntimeException(message, cause); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca875..e3943f31a48ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -802,8 +802,8 @@ class ColumnarBatchSuite extends SparkFunSuite { // Over-allocating beyond MAX_CAPACITY throws an exception column.appendBytes(10, 0.toByte) } - assert(ex.getMessage.contains(s"Cannot reserve more than ${column.MAX_CAPACITY} bytes in " + - s"the vectorized reader")) + assert(ex.getMessage.contains(s"Cannot reserve additional contiguous bytes in the " + + s"vectorized reader")) } } } From 00be16df642317137f17d2d7d2887c41edac3680 Mon Sep 17 00:00:00 2001 From: Andrew Mills Date: Mon, 26 Sep 2016 16:41:10 -0400 Subject: [PATCH 15/96] [Docs] Update spark-standalone.md to fix link Corrected a link to the configuration.html page, it was pointing to a page that does not exist (configurations.html). Documentation change, verified in preview. Author: Andrew Mills Closes #15244 from ammills01/master. 
--- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1097f1fabef6c..7b82b957d5299 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -348,7 +348,7 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] +For more information about these configurations please refer to the [configuration doc](configuration.html#deploy) Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). From 93c743f1aca433144611b11d4e1b169d66e0f57b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 26 Sep 2016 16:47:57 -0700 Subject: [PATCH 16/96] [SPARK-17577][FOLLOW-UP][SPARKR] SparkR spark.addFile supports adding directory recursively ## What changes were proposed in this pull request? #15140 exposed ```JavaSparkContext.addFile(path: String, recursive: Boolean)``` to Python/R, then we can update SparkR ```spark.addFile``` to support adding directory recursively. ## How was this patch tested? Added unit test. Author: Yanbo Liang Closes #15216 from yanboliang/spark-17577-2. 
--- R/pkg/R/context.R | 9 +++++++-- R/pkg/inst/tests/testthat/test_context.R | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 4793578ad684e..fe2f3e3d10a9b 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -231,17 +231,22 @@ setCheckpointDir <- function(sc, dirName) { #' filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, #' use spark.getSparkFiles(fileName) to find its download location. #' +#' A directory can be given if the recursive option is set to true. +#' Currently directories are only supported for Hadoop-supported filesystems. +#' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. +#' #' @rdname spark.addFile #' @param path The path of the file to be added +#' @param recursive Whether to add files recursively from the path. Default is FALSE. #' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") #'} #' @note spark.addFile since 2.1.0 -spark.addFile <- function(path) { +spark.addFile <- function(path, recursive = FALSE) { sc <- getSparkContext() - invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)))) + invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)), recursive)) } #' Get the root directory that contains files added through spark.addFile. diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 0495418bb7779..caca06933952b 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -169,6 +169,7 @@ test_that("spark.lapply should perform simple transforms", { test_that("add and get file to be downloaded with Spark job on every node", { sparkR.sparkContext() + # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) words <- "Hello World!" 
@@ -177,5 +178,26 @@ test_that("add and get file to be downloaded with Spark job on every node", { download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) unlink(path) + + # Test add directory recursively. + path <- paste0(tempdir(), "/", "recursive_dir") + dir.create(path) + dir_name <- basename(path) + path1 <- paste0(path, "/", "hello.txt") + file.create(path1) + sub_path <- paste0(path, "/", "sub_hello") + dir.create(sub_path) + path2 <- paste0(sub_path, "/", "sub_hello.txt") + file.create(path2) + words <- "Hello World!" + sub_words <- "Sub Hello World!" + writeLines(words, path1) + writeLines(sub_words, path2) + spark.addFile(path, recursive = TRUE) + download_path1 <- spark.getSparkFiles(paste0(dir_name, "/", "hello.txt")) + expect_equal(readLines(download_path1), words) + download_path2 <- spark.getSparkFiles(paste0(dir_name, "/", "sub_hello/sub_hello.txt")) + expect_equal(readLines(download_path2), sub_words) + unlink(path, recursive = TRUE) sparkR.session.stop() }) From 6ee28423ad1b2e6089b82af64a31d77d3552bb38 Mon Sep 17 00:00:00 2001 From: Ding Fei Date: Mon, 26 Sep 2016 23:09:51 -0700 Subject: [PATCH 17/96] Fix two comments since Actor is not used anymore. ## What changes were proposed in this pull request? Fix two comments since Actor is not used anymore. Author: Ding Fei Closes #15251 from danix800/comment-fixing. 
--- .../scala/org/apache/spark/deploy/worker/WorkerWatcher.scala | 3 ++- .../test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index af29de3b0896e..23efcab6caad1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -21,7 +21,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc._ /** - * Actor which connects to a worker process and terminates the JVM if the connection is severed. + * Endpoint which connects to a worker process and terminates the JVM if the + * connection is severed. * Provides fate sharing between a worker and its associated child processes. */ private[spark] class WorkerWatcher( diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index c6aebc19fd12d..bb24c6ce4d33c 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -253,7 +253,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.stop(masterTracker.trackerEndpoint) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. 
masterTracker.registerShuffle(20, 100) From 85b0a157543201895557d66306b38b3ca52f2151 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Sep 2016 14:18:32 +0800 Subject: [PATCH 18/96] [SPARK-15962][SQL] Introduce implementation with a dense format for UnsafeArrayData ## What changes were proposed in this pull request? This PR introduces a more compact representation for ```UnsafeArrayData```. ```UnsafeArrayData``` needs to accept a ```null``` value in each entry of an array. In the current version, it has three parts ``` [numElements] [offsets] [values] ``` The `offsets` region has `numElements` entries, and a negative offset represents `null`. This may increase the memory footprint, and it introduces an indirection for accessing each of the `values`. This PR uses bitvectors to represent nullability for each element like `UnsafeRow`, and eliminates an indirection for accessing each element. The new ```UnsafeArrayData``` has four parts. ``` [numElements][null bits][values or offset&length][variable length portion] ``` In the `null bits` region, we store 1 bit per element, which represents whether an element is null. Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries. In the `values or offset&length` region, we store the content of elements. For fields that hold fixed-length primitive types, such as long, double, or int, we store the value directly in the field. For fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the base address of the array) that points to the beginning of the variable-length field and length (they are combined into a long). Each is word-aligned. For `variable length portion`, each is aligned to 8-byte boundaries. The new format can reduce memory footprint and improve performance of accessing each element. 
An example of memory foot comparison: 1024x1024 elements integer array Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024 + 1024x1024 = 2M bytes Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024/8 + 1024x1024 = 1.25M bytes In summary, we got 1.0-2.6x performance improvements over the code before applying this PR. Here are performance results of [benchmark programs](https://github.com/kiszk/spark/blob/04d2e4b6dbdc4eff43ce18b3c9b776e0129257c7/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala): **Read UnsafeArrayData**: 1.7x and 1.6x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 430 / 436 390.0 2.6 1.0X Double 456 / 485 367.8 2.7 0.9X With SPARK-15962 Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 252 / 260 666.1 1.5 1.0X Double 281 / 292 597.7 1.7 0.9X ```` **Write UnsafeArrayData**: 1.0x and 1.1x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 203 / 273 103.4 9.7 1.0X Double 239 / 356 87.9 11.4 0.8X With SPARK-15962 Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 196 / 249 107.0 9.3 1.0X Double 227 / 367 92.3 10.8 0.9X 
```` **Get primitive array from UnsafeArrayData**: 2.6x and 1.6x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 207 / 217 304.2 3.3 1.0X Double 257 / 363 245.2 4.1 0.8X With SPARK-15962 Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 151 / 198 415.8 2.4 1.0X Double 214 / 394 293.6 3.4 0.7X ```` **Create UnsafeArrayData from primitive array**: 1.7x and 2.1x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 340 / 385 185.1 5.4 1.0X Double 479 / 705 131.3 7.6 0.7X With SPARK-15962 Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 206 / 211 306.0 3.3 1.0X Double 232 / 406 271.6 3.7 0.9X ```` 1.7x and 1.4x performance improvements in [```UDTSerializationBenchmark```](https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala) over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) 
Relative ------------------------------------------------------------------------------------------------ serialize 442 / 533 0.0 441927.1 1.0X deserialize 217 / 274 0.0 217087.6 2.0X With SPARK-15962 VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ serialize 265 / 318 0.0 265138.5 1.0X deserialize 155 / 197 0.0 154611.4 1.7X ```` ## How was this patch tested? Added unit tests into ```UnsafeArraySuite``` Author: Kazuaki Ishizaki Closes #13680 from kiszk/SPARK-15962. --- .../org/apache/spark/unsafe/Platform.java | 4 + .../linalg/UDTSerializationBenchmark.scala | 13 +- .../catalyst/expressions/UnsafeArrayData.java | 269 ++++++++++-------- .../catalyst/expressions/UnsafeMapData.java | 13 +- .../codegen/UnsafeArrayWriter.java | 193 +++++++++---- .../codegen/GenerateUnsafeProjection.scala | 31 +- .../expressions/UnsafeRowConverterSuite.scala | 23 +- .../sql/catalyst/util/UnsafeArraySuite.scala | 195 +++++++++++-- .../sql/execution/columnar/ColumnType.scala | 4 +- .../benchmark/UnsafeArrayDataBenchmark.scala | 232 +++++++++++++++ .../execution/columnar/ColumnTypeSuite.scala | 4 +- 11 files changed, 750 insertions(+), 231 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index c892b9cdaf49c..671b8c7475943 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -29,6 +29,8 @@ public final class Platform { private static final Unsafe _UNSAFE; + public static final int BOOLEAN_ARRAY_OFFSET; + public static final int BYTE_ARRAY_OFFSET; public static final int SHORT_ARRAY_OFFSET; @@ -235,6 +237,7 @@ public static void 
throwException(Throwable t) { _UNSAFE = unsafe; if (_UNSAFE != null) { + BOOLEAN_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(boolean[].class); BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); @@ -242,6 +245,7 @@ public static void throwException(Throwable t) { FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); } else { + BOOLEAN_ARRAY_OFFSET = 0; BYTE_ARRAY_OFFSET = 0; SHORT_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index 8b439e6b7a017..5973479dfb5ed 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -57,13 +57,12 @@ object UDTSerializationBenchmark { } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - serialize 380 / 392 0.0 379730.0 1.0X - deserialize 138 / 142 0.0 137816.6 2.8X + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + serialize 265 / 318 0.0 265138.5 1.0X + deserialize 155 / 197 0.0 154611.4 1.7X */ benchmark.run() } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 6302660548ec1..86523c1474015 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -32,23 +33,31 @@ /** * An Unsafe implementation of Array which is backed by raw memory instead of Java objects. * - * Each tuple has three parts: [numElements] [offsets] [values] + * Each array has four parts: + * [numElements][null bits][values or offset&length][variable length portion] * - * The `numElements` is 4 bytes storing the number of elements of this array. + * The `numElements` is 8 bytes storing the number of elements of this array. * - * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the - * base address of the array) of this element in `values` region. We can get the length of this - * element by subtracting next offset. - * Note that offset can by negative which means this element is null. + * In the `null bits` region, we store 1 bit per element, represents whether an element is null + * Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries. * - * In the `values` region, we store the content of elements. As we can get length info, so elements - * can be variable-length. + * In the `values or offset&length` region, we store the content of elements. For fields that hold + * fixed-length primitive types, such as long, double, or int, we store the value directly + * in the field. 
The whole fixed-length portion (even for byte) is aligned to 8-byte boundaries. + * For fields with non-primitive or variable-length values, we store a relative offset + * (w.r.t. the base address of the array) that points to the beginning of the variable-length field + * and length (they are combined into a long). For variable length portion, each is aligned + * to 8-byte boundaries. * * Instances of `UnsafeArrayData` act as pointers to row data stored in this format. */ -// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData. + public final class UnsafeArrayData extends ArrayData { + public static int calculateHeaderPortionInBytes(int numFields) { + return 8 + ((numFields + 63)/ 64) * 8; + } + private Object baseObject; private long baseOffset; @@ -56,24 +65,19 @@ public final class UnsafeArrayData extends ArrayData { private int numElements; // The size of this array's backing data, in bytes. - // The 4-bytes header of `numElements` is also included. + // The 8-bytes header of `numElements` is also included. 
private int sizeInBytes; - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } + /** The position to start storing array elements, */ + private long elementOffset; - private int getElementOffset(int ordinal) { - return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L); + private long getElementOffset(int ordinal, int elementSize) { + return elementOffset + ordinal * elementSize; } - private int getElementSize(int offset, int ordinal) { - if (ordinal == numElements - 1) { - return sizeInBytes - offset; - } else { - return Math.abs(getElementOffset(ordinal + 1)) - offset; - } - } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } private void assertIndexIsValid(int ordinal) { assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0"; @@ -102,20 +106,22 @@ public UnsafeArrayData() { } * @param sizeInBytes the size of this array's backing data, in bytes */ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { - // Read the number of elements from the first 4 bytes. - final int numElements = Platform.getInt(baseObject, baseOffset); + // Read the number of elements from the first 8 bytes. 
+ final long numElements = Platform.getLong(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; + assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; - this.numElements = numElements; + this.numElements = (int)numElements; this.baseObject = baseObject; this.baseOffset = baseOffset; this.sizeInBytes = sizeInBytes; + this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements); } @Override public boolean isNullAt(int ordinal) { assertIndexIsValid(ordinal); - return getElementOffset(ordinal) < 0; + return BitSetMethods.isSet(baseObject, baseOffset + 8, ordinal); } @Override @@ -165,68 +171,50 @@ public Object get(int ordinal, DataType dataType) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return false; - return Platform.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, getElementOffset(ordinal, 1)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, getElementOffset(ordinal, 1)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, getElementOffset(ordinal, 2)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, getElementOffset(ordinal, 4)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - final 
int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, getElementOffset(ordinal, 8)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, getElementOffset(ordinal, 4)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, getElementOffset(ordinal, 8)); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - + if (isNullAt(ordinal)) return null; if (precision <= Decimal.MAX_LONG_DIGITS()) { - final long value = Platform.getLong(baseObject, baseOffset + offset); - return Decimal.apply(value, precision, scale); + return Decimal.apply(getLong(ordinal), precision, scale); } else { final byte[] bytes = getBinary(ordinal); final BigInteger bigInteger = new BigInteger(bytes); @@ -237,19 +225,19 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { @Override public UTF8String getUTF8String(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override public byte[] getBinary(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = 
getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final byte[] bytes = new byte[size]; Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; @@ -257,9 +245,9 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); final int months = (int) Platform.getLong(baseObject, baseOffset + offset); final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); @@ -267,10 +255,10 @@ public CalendarInterval getInterval(int ordinal) { @Override public UnsafeRow getStruct(int ordinal, int numFields) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final UnsafeRow row = new UnsafeRow(numFields); row.pointTo(baseObject, baseOffset + offset, size); return row; @@ -278,10 +266,10 @@ public UnsafeRow getStruct(int ordinal, int numFields) { @Override public UnsafeArrayData getArray(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = 
(int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final UnsafeArrayData array = new UnsafeArrayData(); array.pointTo(baseObject, baseOffset + offset, size); return array; @@ -289,10 +277,10 @@ public UnsafeArrayData getArray(int ordinal) { @Override public UnsafeMapData getMap(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final UnsafeMapData map = new UnsafeMapData(); map.pointTo(baseObject, baseOffset + offset, size); return map; @@ -341,63 +329,108 @@ public UnsafeArrayData copy() { return arrayCopy; } - public static UnsafeArrayData fromPrimitiveArray(int[] arr) { - if (arr.length > (Integer.MAX_VALUE - 4) / 8) { - throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + - "it's too big."); - } + @Override + public boolean[] toBooleanArray() { + boolean[] values = new boolean[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.BOOLEAN_ARRAY_OFFSET, numElements); + return values; + } - final int offsetRegionSize = 4 * arr.length; - final int valueRegionSize = 4 * arr.length; - final int totalSize = 4 + offsetRegionSize + valueRegionSize; - final byte[] data = new byte[totalSize]; + @Override + public byte[] toByteArray() { + byte[] values = new byte[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.BYTE_ARRAY_OFFSET, numElements); + return values; + } - Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length); + @Override + public short[] toShortArray() { + short[] values = new short[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + return values; + } - int offsetPosition 
= Platform.BYTE_ARRAY_OFFSET + 4; - int valueOffset = 4 + offsetRegionSize; - for (int i = 0; i < arr.length; i++) { - Platform.putInt(data, offsetPosition, valueOffset); - offsetPosition += 4; - valueOffset += 4; - } + @Override + public int[] toIntArray() { + int[] values = new int[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + return values; + } - Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data, - Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize); + @Override + public long[] toLongArray() { + long[] values = new long[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + return values; + } - UnsafeArrayData result = new UnsafeArrayData(); - result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize); - return result; + @Override + public float[] toFloatArray() { + float[] values = new float[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + return values; } - public static UnsafeArrayData fromPrimitiveArray(double[] arr) { - if (arr.length > (Integer.MAX_VALUE - 4) / 12) { + @Override + public double[] toDoubleArray() { + double[] values = new double[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + return values; + } + + private static UnsafeArrayData fromPrimitiveArray( + Object arr, int offset, int length, int elementSize) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + "it's too big."); } - final int offsetRegionSize = 4 * arr.length; - final int 
valueRegionSize = 8 * arr.length; - final int totalSize = 4 + offsetRegionSize + valueRegionSize; - final byte[] data = new byte[totalSize]; + final long[] data = new long[(int)totalSizeInLongs]; - Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length); - - int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4; - int valueOffset = 4 + offsetRegionSize; - for (int i = 0; i < arr.length; i++) { - Platform.putInt(data, offsetPosition, valueOffset); - offsetPosition += 4; - valueOffset += 8; - } - - Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data, - Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize); + Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); UnsafeArrayData result = new UnsafeArrayData(); - result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize); + result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } - // TODO: add more specialized methods. 
+ public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { + return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); + } + + public static UnsafeArrayData fromPrimitiveArray(byte[] arr) { + return fromPrimitiveArray(arr, Platform.BYTE_ARRAY_OFFSET, arr.length, 1); + } + + public static UnsafeArrayData fromPrimitiveArray(short[] arr) { + return fromPrimitiveArray(arr, Platform.SHORT_ARRAY_OFFSET, arr.length, 2); + } + + public static UnsafeArrayData fromPrimitiveArray(int[] arr) { + return fromPrimitiveArray(arr, Platform.INT_ARRAY_OFFSET, arr.length, 4); + } + + public static UnsafeArrayData fromPrimitiveArray(long[] arr) { + return fromPrimitiveArray(arr, Platform.LONG_ARRAY_OFFSET, arr.length, 8); + } + + public static UnsafeArrayData fromPrimitiveArray(float[] arr) { + return fromPrimitiveArray(arr, Platform.FLOAT_ARRAY_OFFSET, arr.length, 4); + } + + public static UnsafeArrayData fromPrimitiveArray(double[] arr) { + return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 0700148becaba..35029f5a50e3e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -25,7 +25,7 @@ /** * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. * - * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head + * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head * to indicate the number of bytes of the unsafe key array. 
* [unsafe key array numBytes] [unsafe key array] [unsafe value array] */ @@ -65,14 +65,15 @@ public UnsafeMapData() { * @param sizeInBytes the size of this map's backing data, in bytes */ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { - // Read the numBytes of key array from the first 4 bytes. - final int keyArraySize = Platform.getInt(baseObject, baseOffset); - final int valueArraySize = sizeInBytes - keyArraySize - 4; + // Read the numBytes of key array from the first 8 bytes. + final long keyArraySize = Platform.getLong(baseObject, baseOffset); assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; + assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; + final int valueArraySize = sizeInBytes - (int)keyArraySize - 8; assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; - keys.pointTo(baseObject, baseOffset + 4, keyArraySize); - values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize); + keys.pointTo(baseObject, baseOffset + 8, (int)keyArraySize); + values.pointTo(baseObject, baseOffset + 8 + keyArraySize, valueArraySize); assert keys.numElements() == values.numElements(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7dd932d1981b7..afea4676893ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -19,9 +19,13 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.CalendarInterval; import 
org.apache.spark.unsafe.types.UTF8String; +import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; + /** * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. @@ -33,134 +37,213 @@ public class UnsafeArrayWriter { // The offset of the global buffer where we start to write this array. private int startingOffset; - public void initialize(BufferHolder holder, int numElements, int fixedElementSize) { - // We need 4 bytes to store numElements and 4 bytes each element to store offset. - final int fixedSize = 4 + 4 * numElements; + // The number of elements in this array + private int numElements; + + private int headerInBytes; + + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numElements : "index (" + index + ") should < " + numElements; + } + + public void initialize(BufferHolder holder, int numElements, int elementSize) { + // We need 8 bytes to store numElements in header + this.numElements = numElements; + this.headerInBytes = calculateHeaderPortionInBytes(numElements); this.holder = holder; this.startingOffset = holder.cursor; - holder.grow(fixedSize); - Platform.putInt(holder.buffer, holder.cursor, numElements); - holder.cursor += fixedSize; + // Grows the global buffer ahead for header and fixed size data. 
+ int fixedPartInBytes = + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements); + holder.grow(headerInBytes + fixedPartInBytes); + + // Write numElements and clear out null bits to header + Platform.putLong(holder.buffer, startingOffset, numElements); + for (int i = 8; i < headerInBytes; i += 8) { + Platform.putLong(holder.buffer, startingOffset + i, 0L); + } + + // fill 0 into reminder part of 8-bytes alignment in unsafe array + for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { + Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + } + holder.cursor += (headerInBytes + fixedPartInBytes); + } + + private void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + } + } + + private long getElementOffset(int ordinal, int elementSize) { + return startingOffset + headerInBytes + ordinal * elementSize; + } + + public void setOffsetAndSize(int ordinal, long currentCursor, int size) { + assertIndexIsValid(ordinal); + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; - // Grows the global buffer ahead for fixed size data. - holder.grow(fixedElementSize * numElements); + write(ordinal, offsetAndSize); } - private long getElementOffset(int ordinal) { - return startingOffset + 4 + 4 * ordinal; + private void setNullBit(int ordinal) { + assertIndexIsValid(ordinal); + BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); } - public void setNullAt(int ordinal) { - final int relativeOffset = holder.cursor - startingOffset; - // Writes negative offset value to represent null element. 
- Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset); + public void setNullBoolean(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); } - public void setOffset(int ordinal) { - final int relativeOffset = holder.cursor - startingOffset; - Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset); + public void setNullByte(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); } + public void setNullShort(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + } + + public void setNullInt(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0); + } + + public void setNullLong(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + } + + public void setNullFloat(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); + } + + public void setNullDouble(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); + } + + public void setNull(int ordinal) { setNullLong(ordinal); } + public void write(int ordinal, boolean value) { - Platform.putBoolean(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 1; + assertIndexIsValid(ordinal); + Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); } public 
void write(int ordinal, byte value) { - Platform.putByte(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 1; + assertIndexIsValid(ordinal); + Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); } public void write(int ordinal, short value) { - Platform.putShort(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 2; + assertIndexIsValid(ordinal); + Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); } public void write(int ordinal, int value) { - Platform.putInt(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 4; + assertIndexIsValid(ordinal); + Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 8; + assertIndexIsValid(ordinal); + Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); } public void write(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - Platform.putFloat(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 4; + assertIndexIsValid(ordinal); + Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); } public void write(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - Platform.putDouble(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 8; + assertIndexIsValid(ordinal); + Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType + assertIndexIsValid(ordinal); if (input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { - Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); - setOffset(ordinal); - holder.cursor += 8; + write(ordinal, 
input.toUnscaledLong()); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; - holder.grow(bytes.length); + final int numBytes = bytes.length; + assert numBytes <= 16; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + holder.grow(roundedSize); + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); - setOffset(ordinal); - holder.cursor += bytes.length; + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); + setOffsetAndSize(ordinal, holder.cursor, numBytes); + + // move the cursor forward with 8-bytes boundary + holder.cursor += roundedSize; } } else { - setNullAt(ordinal); + setNull(ordinal); } } public void write(int ordinal, UTF8String input) { final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(numBytes); + holder.grow(roundedSize); + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. input.writeToMemory(holder.buffer, holder.cursor); - setOffset(ordinal); + setOffsetAndSize(ordinal, holder.cursor, numBytes); // move the cursor forward. - holder.cursor += numBytes; + holder.cursor += roundedSize; } public void write(int ordinal, byte[] input) { + final int numBytes = input.length; + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + // grow the global buffer before writing data. - holder.grow(input.length); + holder.grow(roundedSize); + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. 
Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length); + input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffset(ordinal); + setOffsetAndSize(ordinal, holder.cursor, numBytes); // move the cursor forward. - holder.cursor += input.length; + holder.cursor += roundedSize; } public void write(int ordinal, CalendarInterval input) { @@ -171,7 +254,7 @@ public void write(int ordinal, CalendarInterval input) { Platform.putLong(holder.buffer, holder.cursor, input.months); Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - setOffset(ordinal); + setOffsetAndSize(ordinal, holder.cursor, 16); // move the cursor forward. holder.cursor += 16; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a6087..75bb6936b49e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -124,7 +124,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => @@ -134,7 +133,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ case t: 
DecimalType => @@ -189,29 +187,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val jt = ctx.javaType(et) - val fixedElementSize = et match { + val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 case _ if ctx.isPrimitiveType(jt) => et.defaultSize - case _ => 0 + case _ => 8 // we need 8 bytes to store offset and length } + val tmpCursor = ctx.freshName("tmpCursor") val writeElement = et match { case t: StructType => s""" - $arrayWriter.setOffset($index); + final int $tmpCursor = $bufferHolder.cursor; ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case a @ ArrayType(et, _) => s""" - $arrayWriter.setOffset($index); + final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, element, et, bufferHolder)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => s""" - $arrayWriter.setOffset($index); + final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case t: DecimalType => @@ -222,16 +224,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } + val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" s""" if ($input instanceof UnsafeArrayData) { ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} } else { final int $numElements = $input.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); + $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { if ($input.isNullAt($index)) { - $arrayWriter.setNullAt($index); + 
$arrayWriter.setNull$primitiveTypeName($index); } else { final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement @@ -261,16 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final ArrayData $keys = $input.keyArray(); final ArrayData $values = $input.valueArray(); - // preserve 4 bytes to write the key array numBytes later. - $bufferHolder.grow(4); - $bufferHolder.cursor += 4; + // preserve 8 bytes to write the key array numBytes later. + $bufferHolder.grow(8); + $bufferHolder.cursor += 8; // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} - // Write the numBytes of key array into the first 4 bytes. - Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + // Write the numBytes of key array into the first 8 bytes. + Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 1265908182b3a..90790dda753f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -300,7 +300,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = { assert(array.numElements == values.length) - assert(array.getSizeInBytes == 4 + (4 + 4) * values.length) + assert(array.getSizeInBytes == + 8 + scala.math.ceil(values.length / 64.toDouble) * 8 + roundedSize(4 * values.length)) values.zipWithIndex.foreach { case 
(value, index) => assert(array.getInt(index) == value) } @@ -313,7 +314,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { testArrayInt(map.keyArray, keys) testArrayInt(map.valueArray, values) - assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) + assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } test("basic conversion with array type") { @@ -339,7 +340,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val nestedArray = unsafeArray2.getArray(0) testArrayInt(nestedArray, Seq(3, 4)) - assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes) + assert(unsafeArray2.getSizeInBytes == 8 + 8 + 8 + nestedArray.getSizeInBytes) val array1Size = roundedSize(unsafeArray1.getSizeInBytes) val array2Size = roundedSize(unsafeArray2.getSizeInBytes) @@ -382,10 +383,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val nestedMap = valueArray.getMap(0) testMapInt(nestedMap, Seq(5, 6), Seq(7, 8)) - assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes) + assert(valueArray.getSizeInBytes == 8 + 8 + 8 + roundedSize(nestedMap.getSizeInBytes)) } - assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(unsafeMap2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) val map1Size = roundedSize(unsafeMap1.getSizeInBytes) val map2Size = roundedSize(unsafeMap2.getSizeInBytes) @@ -425,7 +426,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(innerStruct.getLong(0) == 2L) } - assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) + assert(field2.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) @@ -468,10 +469,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 
assert(innerStruct.getSizeInBytes == 8 + 8) assert(innerStruct.getLong(0) == 4L) - assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) + assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes) } - assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) @@ -497,7 +498,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerMap = field1.getMap(0) testMapInt(innerMap, Seq(1), Seq(2)) - assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes) + assert(field1.getSizeInBytes == 8 + 8 + 8 + roundedSize(innerMap.getSizeInBytes)) val field2 = unsafeRow.getMap(1) assert(field2.numElements == 1) @@ -513,10 +514,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerArray = valueArray.getArray(0) testArrayInt(innerArray, Seq(4)) - assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes)) + assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerArray.getSizeInBytes) } - assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala index 1685276ff1201..f0e247bf46c44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala @@ -18,27 +18,190 @@ package org.apache.spark.sql.catalyst.util import 
org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class UnsafeArraySuite extends SparkFunSuite { - test("from primitive int array") { - val array = Array(1, 10, 100) - val unsafe = UnsafeArrayData.fromPrimitiveArray(array) - assert(unsafe.numElements == 3) - assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3) - assert(unsafe.getInt(0) == 1) - assert(unsafe.getInt(1) == 10) - assert(unsafe.getInt(2) == 100) + val booleanArray = Array(false, true) + val shortArray = Array(1.toShort, 10.toShort, 100.toShort) + val intArray = Array(1, 10, 100) + val longArray = Array(1.toLong, 10.toLong, 100.toLong) + val floatArray = Array(1.1.toFloat, 2.2.toFloat, 3.3.toFloat) + val doubleArray = Array(1.1, 2.2, 3.3) + val stringArray = Array("1", "10", "100") + val dateArray = Array( + DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1")).get, + DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26")).get) + val timestampArray = Array( + DateTimeUtils.stringToTimestamp(UTF8String.fromString("1970-1-1 00:00:00")).get, + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2016-7-26 00:00:00")).get) + val decimalArray4_1 = Array( + BigDecimal("123.4").setScale(1, BigDecimal.RoundingMode.FLOOR), + BigDecimal("567.8").setScale(1, BigDecimal.RoundingMode.FLOOR)) + val decimalArray20_20 = Array( + BigDecimal("1.2345678901234567890123456").setScale(21, BigDecimal.RoundingMode.FLOOR), + BigDecimal("2.3456789012345678901234567").setScale(21, BigDecimal.RoundingMode.FLOOR)) + + val calenderintervalArray = Array(new CalendarInterval(3, 321), new CalendarInterval(1, 123)) + + val intMultiDimArray = Array(Array(1), Array(2, 20), Array(3, 30, 300)) + val doubleMultiDimArray = Array( + Array(1.1, 11.1), Array(2.2, 22.2, 
222.2), Array(3.3, 33.3, 333.3, 3333.3)) + + test("read array") { + val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind(). + toRow(booleanArray).getArray(0) + assert(unsafeBoolean.isInstanceOf[UnsafeArrayData]) + assert(unsafeBoolean.numElements == booleanArray.length) + booleanArray.zipWithIndex.map { case (e, i) => + assert(unsafeBoolean.getBoolean(i) == e) + } + + val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind(). + toRow(shortArray).getArray(0) + assert(unsafeShort.isInstanceOf[UnsafeArrayData]) + assert(unsafeShort.numElements == shortArray.length) + shortArray.zipWithIndex.map { case (e, i) => + assert(unsafeShort.getShort(i) == e) + } + + val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind(). + toRow(intArray).getArray(0) + assert(unsafeInt.isInstanceOf[UnsafeArrayData]) + assert(unsafeInt.numElements == intArray.length) + intArray.zipWithIndex.map { case (e, i) => + assert(unsafeInt.getInt(i) == e) + } + + val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind(). + toRow(longArray).getArray(0) + assert(unsafeLong.isInstanceOf[UnsafeArrayData]) + assert(unsafeLong.numElements == longArray.length) + longArray.zipWithIndex.map { case (e, i) => + assert(unsafeLong.getLong(i) == e) + } + + val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind(). + toRow(floatArray).getArray(0) + assert(unsafeFloat.isInstanceOf[UnsafeArrayData]) + assert(unsafeFloat.numElements == floatArray.length) + floatArray.zipWithIndex.map { case (e, i) => + assert(unsafeFloat.getFloat(i) == e) + } + + val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind(). + toRow(doubleArray).getArray(0) + assert(unsafeDouble.isInstanceOf[UnsafeArrayData]) + assert(unsafeDouble.numElements == doubleArray.length) + doubleArray.zipWithIndex.map { case (e, i) => + assert(unsafeDouble.getDouble(i) == e) + } + + val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind(). 
+ toRow(stringArray).getArray(0) + assert(unsafeString.isInstanceOf[UnsafeArrayData]) + assert(unsafeString.numElements == stringArray.length) + stringArray.zipWithIndex.map { case (e, i) => + assert(unsafeString.getUTF8String(i).toString().equals(e)) + } + + val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind(). + toRow(dateArray).getArray(0) + assert(unsafeDate.isInstanceOf[UnsafeArrayData]) + assert(unsafeDate.numElements == dateArray.length) + dateArray.zipWithIndex.map { case (e, i) => + assert(unsafeDate.get(i, DateType) == e) + } + + val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind(). + toRow(timestampArray).getArray(0) + assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData]) + assert(unsafeTimestamp.numElements == timestampArray.length) + timestampArray.zipWithIndex.map { case (e, i) => + assert(unsafeTimestamp.get(i, TimestampType) == e) + } + + Seq(decimalArray4_1, decimalArray20_20).map { decimalArray => + val decimal = decimalArray(0) + val schema = new StructType().add( + "array", ArrayType(DecimalType(decimal.precision, decimal.scale))) + val encoder = RowEncoder(schema).resolveAndBind() + val externalRow = Row(decimalArray) + val ir = encoder.toRow(externalRow) + + val unsafeDecimal = ir.getArray(0) + assert(unsafeDecimal.isInstanceOf[UnsafeArrayData]) + assert(unsafeDecimal.numElements == decimalArray.length) + decimalArray.zipWithIndex.map { case (e, i) => + assert(unsafeDecimal.getDecimal(i, e.precision, e.scale).toBigDecimal == e) + } + } + + val schema = new StructType().add("array", ArrayType(CalendarIntervalType)) + val encoder = RowEncoder(schema).resolveAndBind() + val externalRow = Row(calenderintervalArray) + val ir = encoder.toRow(externalRow) + val unsafeCalendar = ir.getArray(0) + assert(unsafeCalendar.isInstanceOf[UnsafeArrayData]) + assert(unsafeCalendar.numElements == calenderintervalArray.length) + calenderintervalArray.zipWithIndex.map { case (e, i) => + assert(unsafeCalendar.getInterval(i) == e) + } + + 
val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind(). + toRow(intMultiDimArray).getArray(0) + assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData]) + assert(unsafeMultiDimInt.numElements == intMultiDimArray.length) + intMultiDimArray.zipWithIndex.map { case (a, j) => + val u = unsafeMultiDimInt.getArray(j) + assert(u.isInstanceOf[UnsafeArrayData]) + assert(u.numElements == a.length) + a.zipWithIndex.map { case (e, i) => + assert(u.getInt(i) == e) + } + } + + val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind(). + toRow(doubleMultiDimArray).getArray(0) + assert(unsafeMultiDimDouble.isInstanceOf[UnsafeArrayData]) + assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length) + doubleMultiDimArray.zipWithIndex.map { case (a, j) => + val u = unsafeMultiDimDouble.getArray(j) + assert(u.isInstanceOf[UnsafeArrayData]) + assert(u.numElements == a.length) + a.zipWithIndex.map { case (e, i) => + assert(u.getDouble(i) == e) + } + } } - test("from primitive double array") { - val array = Array(1.1, 2.2, 3.3) - val unsafe = UnsafeArrayData.fromPrimitiveArray(array) - assert(unsafe.numElements == 3) - assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3) - assert(unsafe.getDouble(0) == 1.1) - assert(unsafe.getDouble(1) == 2.2) - assert(unsafe.getDouble(2) == 3.3) + test("from primitive array") { + val unsafeInt = UnsafeArrayData.fromPrimitiveArray(intArray) + assert(unsafeInt.numElements == 3) + assert(unsafeInt.getSizeInBytes == + ((8 + scala.math.ceil(3/64.toDouble) * 8 + 4 * 3 + 7).toInt / 8) * 8) + intArray.zipWithIndex.map { case (e, i) => + assert(unsafeInt.getInt(i) == e) + } + + val unsafeDouble = UnsafeArrayData.fromPrimitiveArray(doubleArray) + assert(unsafeDouble.numElements == 3) + assert(unsafeDouble.getSizeInBytes == + ((8 + scala.math.ceil(3/64.toDouble) * 8 + 8 * 3 + 7).toInt / 8) * 8) + doubleArray.zipWithIndex.map { case (e, i) => + assert(unsafeDouble.getDouble(i) == e) + } + } + + test("to primitive
array") { + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray)) + + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index f9d606e37ea89..fa9619eb07fec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -601,7 +601,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeArray = getField(row, ordinal) - 4 + unsafeArray.getSizeInBytes + 8 + unsafeArray.getSizeInBytes } override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { @@ -640,7 +640,7 @@ private[columnar] case class MAP(dataType: MapType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeMap = getField(row, ordinal) - 4 + unsafeMap.getSizeInBytes + 8 + unsafeMap.getSizeInBytes } override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala new file mode 100644 index 0000000000000..6c7779b5790d0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.util.Random + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter} +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[UnsafeArrayDataBenchmark]] for UnsafeArrayData + * To run this: + * 1. replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.UnsafeArrayDataBenchmark" + * + * Benchmarks in this file are skipped in normal builds. 
+ */ +class UnsafeArrayDataBenchmark extends BenchmarkBase { + + def calculateHeaderPortionInBytes(count: Int) : Int = { + /* 4 + 4 * count // Use this expression for SPARK-15962 */ + UnsafeArrayData.calculateHeaderPortionInBytes(count) + } + + def readUnsafeArray(iters: Int): Unit = { + val count = 1024 * 1024 * 16 + val rand = new Random(42) + + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0) + val readIntArray = { i: Int => + var n = 0 + while (n < iters) { + val len = intUnsafeArray.numElements + var sum = 0 + var i = 0 + while (i < len) { + sum += intUnsafeArray.getInt(i) + i += 1 + } + n += 1 + } + } + + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0) + val readDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + val len = doubleUnsafeArray.numElements + var sum = 0.0 + var i = 0 + while (i < len) { + sum += doubleUnsafeArray.getDouble(i) + i += 1 + } + n += 1 + } + } + + val benchmark = new Benchmark("Read UnsafeArrayData", count * iters) + benchmark.addCase("Int")(readIntArray) + benchmark.addCase("Double")(readDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 252 / 260 666.1 1.5 1.0X + Double 281 / 292 597.7 1.7 0.9X + */ + } + + def writeUnsafeArray(iters: Int): Unit = { + val count = 1024 * 1024 * 2 + val rand = new Random(42) + + var intTotalLength: Int = 0 + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val intEncoder = 
ExpressionEncoder[Array[Int]].resolveAndBind() + val writeIntArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += intEncoder.toRow(intPrimitiveArray).getArray(0).numElements() + n += 1 + } + intTotalLength = len + } + + var doubleTotalLength: Int = 0 + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + val writeDoubleArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += doubleEncoder.toRow(doublePrimitiveArray).getArray(0).numElements() + n += 1 + } + doubleTotalLength = len + } + + val benchmark = new Benchmark("Write UnsafeArrayData", count * iters) + benchmark.addCase("Int")(writeIntArray) + benchmark.addCase("Double")(writeDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 196 / 249 107.0 9.3 1.0X + Double 227 / 367 92.3 10.8 0.9X + */ + } + + def getPrimitiveArray(iters: Int): Unit = { + val count = 1024 * 1024 * 12 + val rand = new Random(42) + + var intTotalLength: Int = 0 + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0) + val readIntArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += intUnsafeArray.toIntArray.length + n += 1 + } + intTotalLength = len + } + + var doubleTotalLength: Int = 0 + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0) + val readDoubleArray = { i: Int => + var len = 0 + var n = 0 + 
while (n < iters) { + len += doubleUnsafeArray.toDoubleArray.length + n += 1 + } + doubleTotalLength = len + } + + val benchmark = new Benchmark("Get primitive array from UnsafeArrayData", count * iters) + benchmark.addCase("Int")(readIntArray) + benchmark.addCase("Double")(readDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 151 / 198 415.8 2.4 1.0X + Double 214 / 394 293.6 3.4 0.7X + */ + } + + def putPrimitiveArray(iters: Int): Unit = { + val count = 1024 * 1024 * 12 + val rand = new Random(42) + + var intTotalLen: Int = 0 + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val createIntArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += UnsafeArrayData.fromPrimitiveArray(intPrimitiveArray).numElements + n += 1 + } + intTotalLen = len + } + + var doubleTotalLen: Int = 0 + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val createDoubleArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += UnsafeArrayData.fromPrimitiveArray(doublePrimitiveArray).numElements + n += 1 + } + doubleTotalLen = len + } + + val benchmark = new Benchmark("Create UnsafeArrayData from primitive array", count * iters) + benchmark.addCase("Int")(createIntArray) + benchmark.addCase("Double")(createDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 206 / 211 306.0 3.3 1.0X + Double 232 / 406 271.6 3.7 0.9X + */ + } + + 
ignore("Benchmark UnsafeArrayData") { + readUnsafeArray(10) + writeUnsafeArray(10) + getPrimitiveArray(5) + putPrimitiveArray(5) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 052f4cbaebc8e..0b93c633b2d93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -73,8 +73,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) - checkActualSize(ARRAY_TYPE, Array[Any](1), 16) - checkActualSize(MAP_TYPE, Map(1 -> "a"), 29) + checkActualSize(ARRAY_TYPE, Array[Any](1), 8 + 8 + 8 + 8) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 8 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8)) checkActualSize(STRUCT_TYPE, Row("hello"), 28) } From 7f16affa262b059580ed2775a7b05a767aa72315 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 27 Sep 2016 00:00:21 -0700 Subject: [PATCH 19/96] [SPARK-17138][ML][MLIB] Add Python API for multinomial logistic regression ## What changes were proposed in this pull request? Add Python API for multinomial logistic regression. - add `family` param in python api. - expose `coefficientMatrix` and `interceptVector` for `LogisticRegressionModel` - add python-side testcase for multinomial logistic regression - update python doc. ## How was this patch tested? existing and added doc tests. Author: WeichenXu Closes #14852 from WeichenXu123/add_MLOR_python. 
--- python/pyspark/ml/classification.py | 90 ++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 20 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4c01fd5c4ffb..505e7bffd1763 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -67,21 +67,34 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): """ Logistic regression. - Currently, this class only supports binary classification. + This class supports multinomial logistic (softmax) and binomial logistic regression. >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> df = sc.parallelize([ + >>> bdf = sc.parallelize([ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") - >>> model = lr.fit(df) - >>> model.coefficients + >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + >>> blorModel = blor.fit(bdf) + >>> blorModel.coefficients DenseVector([5.5...]) - >>> model.intercept + >>> blorModel.intercept -2.68... + >>> mdf = sc.parallelize([ + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), + ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF() + >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", + ... family="multinomial") + >>> mlorModel = mlor.fit(mdf) + >>> print(mlorModel.coefficientMatrix) + DenseMatrix([[-2.3...], + [ 0.2...], + [ 2.1... 
]]) + >>> mlorModel.interceptVector + DenseVector([2.0..., 0.8..., -2.8...]) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() - >>> result = model.transform(test0).head() + >>> result = blorModel.transform(test0).head() >>> result.prediction 0.0 >>> result.probability @@ -89,23 +102,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> result.rawPrediction DenseVector([8.22..., -8.22...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() - >>> model.transform(test1).head().prediction + >>> blorModel.transform(test1).head().prediction 1.0 - >>> lr.setParams("vector") + >>> blor.setParams("vector") Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. >>> lr_path = temp_path + "/lr" - >>> lr.save(lr_path) + >>> blor.save(lr_path) >>> lr2 = LogisticRegression.load(lr_path) >>> lr2.getMaxIter() 5 >>> model_path = temp_path + "/lr_model" - >>> model.save(model_path) + >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) - >>> model.coefficients[0] == model2.coefficients[0] + >>> blorModel.coefficients[0] == model2.coefficients[0] True - >>> model.intercept == model2.intercept + >>> blorModel.intercept == model2.intercept True .. versionadded:: 1.3.0 @@ -117,24 +130,29 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti "e.g. if threshold is p, then thresholds must be equal to [1-p, p].", typeConverter=TypeConverters.toFloat) + family = Param(Params._dummy(), "family", + "The name of family which is a description of the label distribution to " + + "be used in the model. 
Supported options: auto, binomial, multinomial", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, - aggregationDepth=2): + aggregationDepth=2, family="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ - aggregationDepth=2) + aggregationDepth=2, family="auto") If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -145,13 +163,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, - aggregationDepth=2): + aggregationDepth=2, family="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ 
- aggregationDepth=2) + aggregationDepth=2, family="auto") Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -232,6 +250,20 @@ def _checkThresholdConsistency(self): raise ValueError("Logistic Regression getThreshold found inconsistent values for" + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) + @since("2.1.0") + def setFamily(self, value): + """ + Sets the value of :py:attr:`family`. + """ + return self._set(family=value) + + @since("2.1.0") + def getFamily(self): + """ + Gets the value of :py:attr:`family` or its default value. + """ + return self.getOrDefault(self.family) + class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ @@ -244,7 +276,8 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable @since("2.0.0") def coefficients(self): """ - Model coefficients. + Model coefficients of binomial logistic regression. + An exception is thrown in the case of multinomial logistic regression. """ return self._call_java("coefficients") @@ -252,10 +285,27 @@ def coefficients(self): @since("1.4.0") def intercept(self): """ - Model intercept. + Model intercept of binomial logistic regression. + An exception is thrown in the case of multinomial logistic regression. """ return self._call_java("intercept") + @property + @since("2.1.0") + def coefficientMatrix(self): + """ + Model coefficients. + """ + return self._call_java("coefficientMatrix") + + @property + @since("2.1.0") + def interceptVector(self): + """ + Model intercept. + """ + return self._call_java("interceptVector") + @property @since("2.0.0") def summary(self): From 6a68c5d7b4eb07e4ed6b702dd1536cd08d9bba7d Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 27 Sep 2016 08:10:38 -0500 Subject: [PATCH 20/96] [SPARK-16757] Set up Spark caller context to HDFS and YARN ## What changes were proposed in this pull request? 1. Pass `jobId` to Task. 2. 
Invoke Hadoop APIs. * A new function `setCallerContext` is added in `Utils`. The `setCallerContext` function invokes APIs of `org.apache.hadoop.ipc.CallerContext` to set up Spark caller contexts, which will be written into `hdfs-audit.log` and Yarn RM audit log. * For HDFS: Spark sets up its caller context by invoking `org.apache.hadoop.ipc.CallerContext` in `Task` and Yarn `Client` and `ApplicationMaster`. * For Yarn: Spark sets up its caller context by invoking `org.apache.hadoop.ipc.CallerContext` in Yarn `Client`. ## How was this patch tested? Manual tests against some Spark applications in Yarn client mode and Yarn cluster mode. Need to check if Spark caller contexts are written into HDFS hdfs-audit.log and Yarn RM audit log successfully. For example, run SparkKMeans in Yarn client mode: ``` ./bin/spark-submit --verbose --executor-cores 3 --num-executors 1 --master yarn --deploy-mode client --class org.apache.spark.examples.SparkKMeans examples/target/original-spark-examples_2.11-2.1.0-SNAPSHOT.jar hdfs://localhost:9000/lr_big.txt 2 5 ``` **Before**: There will be no Spark caller context in records of `hdfs-audit.log` and Yarn RM audit log. **After**: Spark caller contexts will be written in records of `hdfs-audit.log` and Yarn RM audit log.
These are records in `hdfs-audit.log`: ``` 2016-09-20 11:54:24,116 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_CLIENT_AppId_application_1474394339641_0005 2016-09-20 11:54:28,164 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0005_JobId_0_StageId_0_AttemptId_0_TaskId_2_AttemptNum_0 2016-09-20 11:54:28,164 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0005_JobId_0_StageId_0_AttemptId_0_TaskId_1_AttemptNum_0 2016-09-20 11:54:28,164 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0005_JobId_0_StageId_0_AttemptId_0_TaskId_0_AttemptNum_0 ``` ``` 2016-09-20 11:59:33,868 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=mkdirs src=/private/tmp/hadoop-wyang/nm-local-dir/usercache/wyang/appcache/application_1474394339641_0006/container_1474394339641_0006_01_000001/spark-warehouse dst=null perm=wyang:supergroup:rwxr-xr-x proto=rpc callerContext=SPARK_APPLICATION_MASTER_AppId_application_1474394339641_0006_AttemptId_1 2016-09-20 11:59:37,214 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_1_AttemptNum_0 2016-09-20 11:59:37,215 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc 
callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_2_AttemptNum_0 2016-09-20 11:59:37,215 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_0_AttemptNum_0 2016-09-20 11:59:42,391 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_3_AttemptNum_0 ``` This is a record in Yarn RM log: ``` 2016-09-20 11:59:24,050 INFO org.apache.hadoop.yarn.server.resourcemanager.RMAuditLogger: USER=wyang IP=127.0.0.1 OPERATION=Submit Application Request TARGET=ClientRMService RESULT=SUCCESS APPID=application_1474394339641_0006 CALLERCONTEXT=SPARK_CLIENT_AppId_application_1474394339641_0006 ``` Author: Weiqing Yang Closes #14659 from Sherry302/callercontextSubmit. 
--- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../apache/spark/scheduler/ResultTask.scala | 15 ++++- .../spark/scheduler/ShuffleMapTask.scala | 13 +++- .../org/apache/spark/scheduler/Task.scala | 17 ++++- .../scala/org/apache/spark/util/Utils.scala | 62 +++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 12 ++++ .../spark/deploy/yarn/ApplicationMaster.scala | 7 +++ .../org/apache/spark/deploy/yarn/Client.scala | 4 +- 8 files changed, 126 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dd47c1dbbec06..5ea0b48f6e4c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1015,7 +1015,8 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.latestInfo.taskMetrics, properties) + taskBinary, part, locs, stage.latestInfo.taskMetrics, properties, Option(jobId), + Option(sc.applicationId), sc.applicationAttemptId) } case stage: ResultStage => @@ -1024,7 +1025,8 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics) + taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics, + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 609f10aee940d..1e7c63af2e797 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -43,7 +43,12 @@ import org.apache.spark.rdd.RDD * 
input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. - */ + * + * The parameters below are optional: + * @param jobId id of the job this task belongs to + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to + */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -52,8 +57,12 @@ private[spark] class ResultTask[T, U]( locs: Seq[TaskLocation], val outputId: Int, localProperties: Properties, - metrics: TaskMetrics) - extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties) + metrics: TaskMetrics, + jobId: Option[Int] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None) + extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, + appId, appAttemptId) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 448fe02084e0d..66d6790e168f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -44,6 +44,11 @@ import org.apache.spark.shuffle.ShuffleWriter * @param locs preferred task execution locations for locality scheduling * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. 
+ * + * The parameters below are optional: + * @param jobId id of the job this task belongs to + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to */ private[spark] class ShuffleMapTask( stageId: Int, @@ -52,8 +57,12 @@ private[spark] class ShuffleMapTask( partition: Partition, @transient private var locs: Seq[TaskLocation], metrics: TaskMetrics, - localProperties: Properties) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties) + localProperties: Properties, + jobId: Option[Int] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, + appId, appAttemptId) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 48daa344f3c88..9385e3c31e1e4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -29,7 +29,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util._ /** * A unit of execution. We have two kinds of Task's in Spark: @@ -47,6 +47,11 @@ import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOu * @param partitionId index of the number in the RDD * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. 
+ * + * The parameters below are optional: + * @param jobId id of the job this task belongs to + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to */ private[spark] abstract class Task[T]( val stageId: Int, @@ -54,7 +59,10 @@ private[spark] abstract class Task[T]( val partitionId: Int, // The default value is only used in tests. val metrics: TaskMetrics = TaskMetrics.registered, - @transient var localProperties: Properties = new Properties) extends Serializable { + @transient var localProperties: Properties = new Properties, + val jobId: Option[Int] = None, + val appId: Option[String] = None, + val appAttemptId: Option[String] = None) extends Serializable { /** * Called by [[org.apache.spark.executor.Executor]] to run this task. @@ -79,9 +87,14 @@ private[spark] abstract class Task[T]( metrics) TaskContext.setTaskContext(context) taskThread = Thread.currentThread() + if (_killed) { kill(interruptThread = false) } + + new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId), + Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + try { runTask(context) } catch { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e09666c6103c6..caa768cfbdc6c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2440,6 +2440,68 @@ private[spark] object Utils extends Logging { } } +/** + * An utility class used to set up Spark caller contexts to HDFS and Yarn. The `context` will be + * constructed by parameters passed in. + * When Spark applications run on Yarn and HDFS, its caller contexts will be written into Yarn RM + * audit log and hdfs-audit.log. 
That can help users to better diagnose and understand how + * specific applications impacting parts of the Hadoop system and potential problems they may be + * creating (e.g. overloading NN). As HDFS mentioned in HDFS-9184, for a given HDFS operation, it's + * very helpful to track which upper level job issues it. + * + * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) + * + * The parameters below are optional: + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to + * @param jobId id of the job this task belongs to + * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to + * @param taskId task id + * @param taskAttemptNumber task attempt id + */ +private[spark] class CallerContext( + from: String, + appId: Option[String] = None, + appAttemptId: Option[String] = None, + jobId: Option[Int] = None, + stageId: Option[Int] = None, + stageAttemptId: Option[Int] = None, + taskId: Option[Long] = None, + taskAttemptNumber: Option[Int] = None) extends Logging { + + val appIdStr = if (appId.isDefined) s"_${appId.get}" else "" + val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else "" + val jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else "" + val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else "" + val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else "" + val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else "" + val taskAttemptNumberStr = + if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else "" + + val context = "SPARK_" + from + appIdStr + appAttemptIdStr + + jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr + + /** + * Set up the caller context [[context]] by invoking Hadoop CallerContext API of + * [[org.apache.hadoop.ipc.CallerContext]], which was added in hadoop 2.8. 
+ */ + def setCurrentContext(): Boolean = { + var succeed = false + try { + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + val Builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") + val builderInst = Builder.getConstructor(classOf[String]).newInstance(context) + val hdfsContext = Builder.getMethod("build").invoke(builderInst) + callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) + succeed = true + } catch { + case NonFatal(e) => logInfo("Fail to set Spark caller context", e) + } + succeed + } +} + /** * A utility class to redirect the child process's stdout or stderr. */ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4715fd29375d6..bc28b2d9cb831 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -788,6 +788,18 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { .set("spark.executor.instances", "1")) === 3) } + test("Set Spark CallerContext") { + val context = "test" + try { + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + assert(new CallerContext(context).setCurrentContext()) + assert(s"SPARK_$context" === + callerContext.getMethod("getCurrent").invoke(null).toString) + } catch { + case e: ClassNotFoundException => + assert(!new CallerContext(context).setCurrentContext()) + } + } test("encodeFileNameToURIRawPath") { assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ad50ea789a913..aabae140af8b1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -184,6 
+184,8 @@ private[spark] class ApplicationMaster( try { val appAttemptId = client.getAttemptId() + var attemptID: Option[String] = None + if (isClusterMode) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box @@ -196,8 +198,13 @@ private[spark] class ApplicationMaster( // Set this internal configuration if it is running on cluster mode, this // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + + attemptID = Option(appAttemptId.getAttemptId.toString) } + new CallerContext("APPMASTER", + Option(appAttemptId.getApplicationId.toString), attemptID).setCurrentContext() + logInfo("ApplicationAttemptId: " + appAttemptId) val fs = FileSystem.get(yarnConf) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 2398f0aea316a..ea4e1160b7672 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,7 +54,7 @@ import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallerContext, Utils} private[spark] class Client( val args: ClientArguments, @@ -161,6 +161,8 @@ private[spark] class Client( reportLauncherState(SparkAppHandle.State.SUBMITTED) launcherBackend.setAppId(appId.toString) + new CallerContext("CLIENT", Option(appId.toString)).setCurrentContext() + // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) From 5de1737b02710e36f6804d2ae243d1aeb30a0b32 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 
Sep 2016 00:39:47 +0800 Subject: [PATCH 21/96] [SPARK-16777][SQL] Do not use deprecated listType API in ParquetSchemaConverter ## What changes were proposed in this pull request? This PR removes the build warnings shown below. ```scala [WARNING] .../spark/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala:448: method listType in object ConversionPatterns is deprecated: see corresponding Javadoc for more information. [WARNING] ConversionPatterns.listType( [WARNING] ^ [WARNING] .../spark/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala:464: method listType in object ConversionPatterns is deprecated: see corresponding Javadoc for more information. [WARNING] ConversionPatterns.listType( [WARNING] ^ ``` This should not use `listOfElements` (the recommended replacement for `listType`) instead, because the new method checks if the name of elements in Parquet's `LIST` is `element` in Parquet schema and throws an exception if not. However, it seems Spark prior to 1.4.x writes `ArrayType` with Parquet's `LIST` but with `array` as its element name. Therefore, this PR avoids using both `listOfElements` and `listType`, and instead uses the existing schema builder to construct the same `GroupType`. ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Closes #14399 from HyukjinKwon/SPARK-16777.
--- .../parquet/ParquetSchemaConverter.scala | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c81a65f4973e3..b4f36ce3752c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -445,14 +445,20 @@ private[parquet] class ParquetSchemaConverter( // repeated array; // } // } - ConversionPatterns.listType( - repetition, - field.name, - Types + + // This should not use `listOfElements` here because this new method checks if the + // element name is `element` in the `GroupType` and throws an exception if not. + // As mentioned above, Spark prior to 1.4.x writes `ArrayType` as `LIST` but with + // `array` as its element name as below. Therefore, we build manually + // the correct group type here via the builder. (See SPARK-16777) + Types + .buildGroup(repetition).as(LIST) + .addField(Types .buildGroup(REPEATED) - // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) + // "array" is the name chosen by parquet-hive (1.7.0 and prior version) .addField(convertField(StructField("array", elementType, nullable))) .named("bag")) + .named(field.name) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is @@ -461,11 +467,13 @@ private[parquet] class ParquetSchemaConverter( // group (LIST) { // repeated element; // } - ConversionPatterns.listType( - repetition, - field.name, + + // Here too, we should not use `listOfElements`. 
(See SPARK-16777) + Types + .buildGroup(repetition).as(LIST) // "array" is the name chosen by parquet-avro (1.7.0 and prior version) - convertField(StructField("array", elementType, nullable), REPEATED)) + .addField(convertField(StructField("array", elementType, nullable), REPEATED)) + .named(field.name) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. From 2cac3b2d4a4a4f3d0d45af4defc23bb0ba53484b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Sep 2016 00:50:12 +0800 Subject: [PATCH 22/96] [SPARK-16516][SQL] Support for pushing down filters for decimal and timestamp types in ORC ## What changes were proposed in this pull request? It seems ORC supports all the types in ([`PredicateLeaf.Type`](https://github.com/apache/hive/blob/e085b7e9bd059d91aaf013df0db4d71dca90ec6f/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/PredicateLeaf.java#L50-L56)) which includes timestamp type and decimal type. In more detail, the types listed in [`SearchArgumentImpl.boxLiteral()`](https://github.com/apache/hive/blob/branch-1.2/ql/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java#L1068-L1093) can be used as a filter value. FYI, the initial `case` clause for supported types was introduced in https://github.com/apache/spark/commit/65d71bd9fbfe6fe1b741c80fed72d6ae3d22b028 and has not changed over time. At that time, the Hive version was 0.13, which supported only some types for filter push-down (See [SearchArgumentImpl.java#L945-L965](https://github.com/apache/hive/blob/branch-0.13/ql/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java#L945-L965) at 0.13).
However, the version was upgraded into 1.2.x and now it supports more types (See [SearchArgumentImpl.java#L1068-L1093](https://github.com/apache/hive/blob/branch-1.2/ql/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java#L1068-L1093) at 1.2.0) ## How was this patch tested? Unit tests in `OrcFilterSuite` and `OrcQuerySuite` Author: hyukjinkwon Closes #14172 from HyukjinKwon/SPARK-16516. --- .../spark/sql/hive/orc/OrcFilters.scala | 1 + .../spark/sql/hive/orc/OrcFilterSuite.scala | 62 ++++++++++++++++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 35 +++++++++++ 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 6ab824455929d..d9efd0cb457cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -84,6 +84,7 @@ private[orc] object OrcFilters extends Logging { // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. 
case ByteType | ShortType | FloatType | DoubleType => true case IntegerType | LongType | StringType | BooleanType => true + case TimestampType | _: DecimalType => true case _ => false } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 471192a369f4a..222c24927a763 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -229,6 +229,59 @@ class OrcFilterSuite extends QueryTest with OrcTest { } } + test("filter pushdown - decimal") { + withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(BigDecimal.valueOf(3)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + 
+ test("filter pushdown - timestamp") { + val timeString = "2015-08-20 14:57:00" + val timestamps = (1 to 4).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + new Timestamp(milliseconds) + } + withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked @@ -277,19 +330,10 @@ class OrcFilterSuite extends QueryTest with OrcTest { withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => checkNoFilterPredicate('_1.isNull) } - // DecimalType - withOrcDataFrame((1 to 4).map(i => Tuple1(BigDecimal.valueOf(i)))) { implicit df => - checkNoFilterPredicate('_1 <= BigDecimal.valueOf(4)) - } // BinaryType 
withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkNoFilterPredicate('_1 <=> 1.b) } - // TimestampType - val stringTimestamp = "2015-08-20 15:57:00" - withOrcDataFrame(Seq(Tuple1(Timestamp.valueOf(stringTimestamp)))) { implicit df => - checkNoFilterPredicate('_1 <=> Timestamp.valueOf(stringTimestamp)) - } // DateType val stringDate = "2015-01-01" withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b13878d578603..b2ee49c441ef2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets +import java.sql.Timestamp import org.scalatest.BeforeAndAfterAll @@ -500,6 +501,40 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("Support for pushing down filters for decimal types") { + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val data = (0 until 10).map(i => Tuple1(BigDecimal.valueOf(i))) + withTempPath { file => + // It needs to repartition data so that we can have several ORC files + // in order to skip stripes in ORC. 
+ createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + val df = spark.read.orc(file.getCanonicalPath).where("a == 2") + val actual = stripSparkFilter(df).count() + + assert(actual < 10) + } + } + } + + test("Support for pushing down filters for timestamp types") { + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val timeString = "2015-08-20 14:57:00" + val data = (0 until 10).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + Tuple1(new Timestamp(milliseconds)) + } + withTempPath { file => + // It needs to repartition data so that we can have several ORC files + // in order to skip stripes in ORC. + createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'") + val actual = stripSparkFilter(df).count() + + assert(actual < 10) + } + } + } + test("column nullability and comment - write and then read") { val schema = (new StructType) .add("cl1", IntegerType, nullable = false, comment = "test") From 120723f934dc386a46a043d2833bfcee60d14e74 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 27 Sep 2016 10:20:30 -0700 Subject: [PATCH 23/96] [SPARK-17682][SQL] Mark children as final for unary, binary, leaf expressions and plan nodes ## What changes were proposed in this pull request? This patch marks the children method as final in unary, binary, and leaf expressions and plan nodes (both logical plan and physical plan), as brought up in http://apache-spark-developers-list.1001551.n3.nabble.com/Should-LeafExpression-have-children-final-override-like-Nondeterministic-td19104.html ## How was this patch tested? This is a simple modifier change and has no impact on test coverage. Author: Reynold Xin Closes #15256 from rxin/SPARK-17682. 
--- .../apache/spark/sql/catalyst/expressions/Expression.scala | 6 +++--- .../apache/spark/sql/catalyst/expressions/generators.scala | 4 ---- .../apache/spark/sql/catalyst/plans/logical/Command.scala | 1 - .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 6 +++--- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 6 +++--- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7abbbe257d830..fa1a2ad56ccb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -295,7 +295,7 @@ trait Nondeterministic extends Expression { */ abstract class LeafExpression extends Expression { - def children: Seq[Expression] = Nil + override final def children: Seq[Expression] = Nil } @@ -307,7 +307,7 @@ abstract class UnaryExpression extends Expression { def child: Expression - override def children: Seq[Expression] = child :: Nil + override final def children: Seq[Expression] = child :: Nil override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -394,7 +394,7 @@ abstract class BinaryExpression extends Expression { def left: Expression def right: Expression - override def children: Seq[Expression] = Seq(left, right) + override final def children: Seq[Expression] = Seq(left, right) override def foldable: Boolean = left.foldable && right.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9d5c856a23e2a..f74208ff66db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -152,8 +152,6 @@ case class Stack(children: Seq[Expression]) abstract class ExplodeBase(child: Expression, position: Boolean) extends UnaryExpression with Generator with CodegenFallback with Serializable { - override def children: Seq[Expression] = child :: Nil - override def checkInputDataTypes(): TypeCheckResult = { if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { TypeCheckResult.TypeCheckSuccess @@ -257,8 +255,6 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]") case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { - override def children: Seq[Expression] = child :: Nil - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(et, _) if et.isInstanceOf[StructType] => TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 64f57835c8898..38f47081b6f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -25,6 +25,5 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * eagerly executed. 
*/ trait Command extends LeafNode { - final override def children: Seq[LogicalPlan] = Seq.empty override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d7799151d93b..09725473a384d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -276,7 +276,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * A logical plan node with no children. */ abstract class LeafNode extends LogicalPlan { - override def children: Seq[LogicalPlan] = Nil + override final def children: Seq[LogicalPlan] = Nil override def producedAttributes: AttributeSet = outputSet } @@ -286,7 +286,7 @@ abstract class LeafNode extends LogicalPlan { abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan - override def children: Seq[LogicalPlan] = child :: Nil + override final def children: Seq[LogicalPlan] = child :: Nil /** * Generates an additional set of aliased constraints by replacing the original constraint @@ -330,5 +330,5 @@ abstract class BinaryNode extends LogicalPlan { def left: LogicalPlan def right: LogicalPlan - override def children: Seq[LogicalPlan] = Seq(left, right) + override final def children: Seq[LogicalPlan] = Seq(left, right) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 6aeefa6eddafe..48d6ef6dcd44a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -380,7 +380,7 @@ object SparkPlan { } trait LeafExecNode extends SparkPlan { - override def children: Seq[SparkPlan] = Nil + 
override final def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet } @@ -394,7 +394,7 @@ object UnaryExecNode { trait UnaryExecNode extends SparkPlan { def child: SparkPlan - override def children: Seq[SparkPlan] = child :: Nil + override final def children: Seq[SparkPlan] = child :: Nil override def outputPartitioning: Partitioning = child.outputPartitioning } @@ -403,5 +403,5 @@ trait BinaryExecNode extends SparkPlan { def left: SparkPlan def right: SparkPlan - override def children: Seq[SparkPlan] = Seq(left, right) + override final def children: Seq[SparkPlan] = Seq(left, right) } From 2ab24a7bf6687ec238306772c4c7ddef6ac93e99 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Sep 2016 10:52:26 -0700 Subject: [PATCH 24/96] [SPARK-17660][SQL] DESC FORMATTED for VIEW Lacks View Definition ### What changes were proposed in this pull request? Before this PR, `DESC FORMATTED` does not have a section for the view definition. We should add it for permanent views, like what Hive does. 
``` +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------+-------+ |col_name |data_type |comment| +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------+-------+ |a |int |null | | | | | |# Detailed Table Information| | | |Database: |default | | |Owner: |xiaoli | | |Create Time: |Sat Sep 24 21:46:19 PDT 2016 | | |Last Access Time: |Wed Dec 31 16:00:00 PST 1969 | | |Location: | | | |Table Type: |VIEW | | |Table Parameters: | | | | transient_lastDdlTime |1474778779 | | | | | | |# Storage Information | | | |SerDe Library: |org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe | | |InputFormat: |org.apache.hadoop.mapred.SequenceFileInputFormat | | |OutputFormat: |org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat | | |Compressed: |No | | |Storage Desc Parameters: | | | | serialization.format |1 | | | | | | |# View Information | | | |View Original Text: |SELECT * FROM tbl | | |View Expanded Text: |SELECT `gen_attr_0` AS `a` FROM (SELECT `gen_attr_0` FROM (SELECT `a` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0) AS tbl| | +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------+-------+ ``` ### How was this patch tested? Added a test case Author: gatorsmile Closes #15234 from gatorsmile/descFormattedView. 
--- .../spark/sql/execution/command/tables.scala | 9 +++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 0f61629317c81..6a91c997bac63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -462,6 +462,8 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } describeStorageInfo(table, buffer) + + if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) } private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { @@ -479,6 +481,13 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } } + private def describeViewInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# View Information", "", "") + append(buffer, "View Original Text:", metadata.viewOriginalText.getOrElse(""), "") + append(buffer, "View Expanded Text:", metadata.viewText.getOrElse(""), "") + } + private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { metadata.bucketSpec match { case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c927e5d802c90..751e976c7b908 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -506,6 +506,25 @@ class HiveDDLSuite } } + test("desc formatted table for permanent view") { + withTable("tbl") { + 
withView("view1") { + sql("CREATE TABLE tbl(a int)") + sql("CREATE VIEW view1 AS SELECT * FROM tbl") + assert(sql("DESC FORMATTED view1").collect().containsSlice( + Seq( + Row("# View Information", "", ""), + Row("View Original Text:", "SELECT * FROM tbl", ""), + Row("View Expanded Text:", + "SELECT `gen_attr_0` AS `a` FROM (SELECT `gen_attr_0` FROM " + + "(SELECT `a` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0) AS tbl", + "") + ) + )) + } + } + } + test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" From 67c73052b877a8709ae6fa22b844a45f114b1f7e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 27 Sep 2016 12:37:19 -0700 Subject: [PATCH 25/96] [SPARK-17677][SQL] Break WindowExec.scala into multiple files ## What changes were proposed in this pull request? As of Spark 2.0, all the window function execution code is in WindowExec.scala. This file is pretty large (over 1k loc) and has a lot of different abstractions in it. This patch creates a new package sql.execution.window, moves WindowExec.scala into it, and breaks WindowExec.scala into multiple, more maintainable pieces: - AggregateProcessor.scala - BoundOrdering.scala - RowBuffer.scala - WindowExec.scala - WindowFunctionFrame.scala ## How was this patch tested? This patch mostly moves code around, and should not change any existing test coverage. Author: Reynold Xin Closes #15252 from rxin/SPARK-17677. 
--- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../spark/sql/execution/WindowExec.scala | 1013 ----------------- .../execution/window/AggregateProcessor.scala | 159 +++ .../sql/execution/window/BoundOrdering.scala | 58 + .../sql/execution/window/RowBuffer.scala | 115 ++ .../sql/execution/window/WindowExec.scala | 412 +++++++ .../window/WindowFunctionFrame.scala | 367 ++++++ 7 files changed, 1112 insertions(+), 1015 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3441ccf53b45b..7cfae5ce283bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -387,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, 
e.output, planLater(child)) :: Nil case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala deleted file mode 100644 index 9d006d21d9440..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ /dev/null @@ -1,1013 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import java.util - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} - -/** - * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) - * partition. The aggregates are calculated for each row in the group. Special processing - * instructions, frames, are used to calculate these aggregates. Frames are processed in the order - * specified in the window specification (the ORDER BY ... clause). There are four different frame - * types: - * - Entire partition: The frame is the entire partition, i.e. - * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all - * rows as inputs and be evaluated once. - * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND .... - * Every time we move to a new row to process, we add some rows to the frame. We do not remove - * rows from this frame. - * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING. - * Every time we move to a new row to process, we remove some rows from the frame. We do not add - * rows to this frame. - * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame - * and we add some rows to the frame. Examples are: - * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. - * - Offset frame: The frame consist of one row, which is an offset number of rows away from the - * current row. 
Only [[OffsetWindowFunction]]s can be processed in an offset frame. - * - * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame - * boundary can be either Row or Range based: - * - Row Based: A row based boundary is based on the position of the row within the partition. - * An offset indicates the number of rows above or below the current row, the frame for the - * current row starts or ends. For instance, given a row based sliding frame with a lower bound - * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from - * index 4 to index 6. - * - Range based: A range based boundary is based on the actual value of the ORDER BY - * expression(s). An offset is used to alter the value of the ORDER BY expression, for - * instance if the current order by expression has a value of 10 and the lower bound offset - * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a - * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical data type. An exception can be made when the offset is 0, - * because no value modification is needed, in this case multiple and non-numeric ORDER BY - * expression are allowed. - * - * This is quite an expensive operator because every row for a single group must be in the same - * partition and partitions must be sorted according to the grouping and sort order. The operator - * requires the planner to take care of the partitioning and sorting. - * - * The operator is semi-blocking. The window functions and aggregates are calculated one group at - * a time, the result will only be made available after the processing for the entire group has - * finished. The operator is able to process different frame configurations at the same time. This - * is done by delegating the actual frame processing (i.e. 
calculation of the window functions) to - * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type: - * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair - * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. - */ -case class WindowExec( - windowExpression: Seq[NamedExpression], - partitionSpec: Seq[Expression], - orderSpec: Seq[SortOrder], - child: SparkPlan) - extends UnaryExecNode { - - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else ClusteredDistribution(partitionSpec) :: Nil - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frameType to evaluate. This can either be Row or Range based. - * @param offset with respect to the row. - * @return a bound ordering object. - */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { - frameType match { - case RangeFrame => - val (exprs, current, bound) = if (offset == 0) { - // Use the entire order expression when the offset is 0. 
- val exprs = orderSpec.map(_.child) - val buildProjection = () => newMutableProjection(exprs, child.output) - (orderSpec, buildProjection(), buildProjection()) - } else if (orderSpec.size == 1) { - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output) - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => -offset - case Ascending => offset - } - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output) - (sortExpr :: Nil, current, bound) - } else { - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - // Construct the ordering. This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val sortExprs = exprs.zipWithIndex.map { case (e, i) => - SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) - } - val ordering = newOrdering(sortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - case RowFrame => RowBoundOrdering(offset) - } - } - - /** - * Collection containing an entry for each window frame to process. Each entry contains a frames' - * WindowExpressions and factory function for the WindowFrameFunction. 
- */ - private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. - def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. - windowExpression.foreach { x => - x.foreach { - case e @ WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) - case f => sys.error(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. - var numExpressions = 0 - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. - def processor = AggregateProcessor( - functions, - ordinal, - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) - - // Create the factory - val factory = key match { - // Offset Frame - case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => - target: MutableRow => - new OffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunctions. 
- functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled), - offset) - - // Growing Frame. - case ("AGGREGATE", frameType, None, Some(high)) => - target: MutableRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, high)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, Some(low), None) => - target: MutableRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, low)) - } - - // Moving Frame. - case ("AGGREGATE", frameType, Some(low), Some(high)) => - target: MutableRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, low), - createBoundOrdering(frameType, high)) - } - - // Entire Partition Frame. - case ("AGGREGATE", frameType, None, None) => - target: MutableRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Frame Expression - Factory pair. - (expressions, factory) - } - } - - /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. 
- */ - private[this] def createResultProjection( - expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map{ case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = expressions.zip(references).toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) - } - - protected override def doExecute(): RDD[InternalRow] = { - // Unwrap the expressions and factories from the map. - val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - - // Start processing. - child.execute().mapPartitions { stream => - new Iterator[InternalRow] { - - // Get all relevant projections. - val result = createResultProjection(expressions) - val grouping = UnsafeProjection.create(partitionSpec, child.output) - - // Manage the stream and the grouping. - var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow() { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. - val rows = ArrayBuffer.empty[UnsafeRow] - val inputFields = child.output.length - var sorter: UnsafeExternalSorter = null - var rowBuffer: RowBuffer = null - val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) - val frames = factories.map(_(windowFunctionResult)) - val numFrames = frames.length - private[this] def fetchNextPartition() { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. 
- val currentGroup = nextGroup.copy() - - // clear last partition - if (sorter != null) { - // the last sorter of this task will be cleaned up via task completion listener - sorter.cleanupResources() - sorter = null - } else { - rows.clear() - } - - while (nextRowAvailable && nextGroup == currentGroup) { - if (sorter == null) { - rows += nextRow.copy() - - if (rows.length >= 4096) { - // We will not sort the rows, so prefixComparator and recordComparator are null. - sorter = UnsafeExternalSorter.create( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) - rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) - } - rows.clear() - } - } else { - sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) - } - fetchNextRow() - } - if (sorter != null) { - rowBuffer = new ExternalRowBuffer(sorter, inputFields) - } else { - rowBuffer = new ArrayRowBuffer(rows) - } - - // Setup the frames. - var i = 0 - while (i < numFrames) { - frames(i).prepare(rowBuffer.copy()) - i += 1 - } - - // Setup iteration - rowIndex = 0 - rowsSize = rowBuffer.size() - } - - // Iteration - var rowIndex = 0 - var rowsSize = 0L - - override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable - - val join = new JoinedRow - override final def next(): InternalRow = { - // Load the next partition if we need to. - if (rowIndex >= rowsSize && nextRowAvailable) { - fetchNextPartition() - } - - if (rowIndex < rowsSize) { - // Get the results for the window frames. 
- var i = 0 - val current = rowBuffer.next() - while (i < numFrames) { - frames(i).write(rowIndex, current) - i += 1 - } - - // 'Merge' the input row with the window function result - join(current, windowFunctionResult) - rowIndex += 1 - - // Return the projection. - result(join) - } else throw new NoSuchElementException - } - } - } - } -} - -/** - * Function for comparing boundary values. - */ -private[execution] abstract class BoundOrdering { - def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int -} - -/** - * Compare the input index to the bound of the output index. - */ -private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { - override def compare( - inputRow: InternalRow, - inputIndex: Int, - outputRow: InternalRow, - outputIndex: Int): Int = - inputIndex - (outputIndex + offset) -} - -/** - * Compare the value of the input index to the value bound of the output index. - */ -private[execution] final case class RangeBoundOrdering( - ordering: Ordering[InternalRow], - current: Projection, - bound: Projection) extends BoundOrdering { - override def compare( - inputRow: InternalRow, - inputIndex: Int, - outputRow: InternalRow, - outputIndex: Int): Int = - ordering.compare(current(inputRow), bound(outputRow)) -} - -/** - * The interface of row buffer for a partition - */ -private[execution] abstract class RowBuffer { - - /** Number of rows. */ - def size(): Int - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer -} - -/** - * A row buffer based on ArrayBuffer (the number of rows is limited) - */ -private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - - private[this] var cursor: Int = -1 - - /** Number of rows. 
*/ - def size(): Int = buffer.length - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow = { - cursor += 1 - if (cursor < buffer.length) { - buffer(cursor) - } else { - null - } - } - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit = { - cursor += n - } - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer = { - new ArrayRowBuffer(buffer) - } -} - -/** - * An external buffer of rows based on UnsafeExternalSorter - */ -private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) - extends RowBuffer { - - private[this] val iter: UnsafeSorterIterator = sorter.getIterator - - private[this] val currentRow = new UnsafeRow(numFields) - - /** Number of rows. */ - def size(): Int = iter.getNumRecords() - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow = { - if (iter.hasNext) { - iter.loadNext() - currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - currentRow - } else { - null - } - } - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit = { - var i = 0 - while (i < n && iter.hasNext) { - iter.loadNext() - i += 1 - } - } - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer = { - new ExternalRowBuffer(sorter, numFields) - } -} - -/** - * A window function calculates the results of a number of window functions for a window frame. - * Before use a frame must be prepared by passing it all the rows in the current partition. After - * preparation the update method can be called to fill the output rows. - */ -private[execution] abstract class WindowFunctionFrame { - /** - * Prepare the frame for calculating the results for a partition. - * - * @param rows to calculate the frame results for. - */ - def prepare(rows: RowBuffer): Unit - - /** - * Write the current results to the target row. 
- */ - def write(index: Int, current: InternalRow): Unit -} - -/** - * The offset window frame calculates frames containing LEAD/LAG statements. - * - * @param target to write results to. - * @param ordinal the ordinal is the starting offset at which the results of the window frame get - * written into the (shared) target row. The result of the frame expression with - * index 'i' will be written to the 'ordinal' + 'i' position in the target row. - * @param expressions to shift a number of rows. - * @param inputSchema required for creating a projection. - * @param newMutableProjection function used to create the projection. - * @param offset by which rows get moved within a partition. - */ -private[execution] final class OffsetWindowFunctionFrame( - target: MutableRow, - ordinal: Int, - expressions: Array[OffsetWindowFunction], - inputSchema: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, - offset: Int) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** Index of the input row currently used for output. */ - private[this] var inputIndex = 0 - - /** - * Create the projection used when the offset row exists. - * Please note that this project always respect null input values (like PostgreSQL). - */ - private[this] val projection = { - // Collect the expressions and bind them. - val inputAttrs = inputSchema.map(_.withNullability(true)) - val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => - BindReferences.bindReference(e.input, inputAttrs) - } - - // Create the projection. - newMutableProjection(boundExpressions, Nil).target(target) - } - - /** Create the projection used when the offset row DOES NOT exists. */ - private[this] val fillDefaultValue = { - // Collect the expressions and bind them. 
- val inputAttrs = inputSchema.map(_.withNullability(true)) - val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => - if (e.default == null || e.default.foldable && e.default.eval() == null) { - // The default value is null. - Literal.create(null, e.dataType) - } else { - // The default value is an expression. - BindReferences.bindReference(e.default, inputAttrs) - } - } - - // Create the projection. - newMutableProjection(boundExpressions, Nil).target(target) - } - - override def prepare(rows: RowBuffer): Unit = { - input = rows - // drain the first few rows if offset is larger than zero - inputIndex = 0 - while (inputIndex < offset) { - input.next() - inputIndex += 1 - } - inputIndex = offset - } - - override def write(index: Int, current: InternalRow): Unit = { - if (inputIndex >= 0 && inputIndex < input.size) { - val r = input.next() - projection(r) - } else { - // Use default values since the offset row does not exist. - fillDefaultValue(current) - } - inputIndex += 1 - } -} - -/** - * The sliding window frame calculates frames with the following SQL form: - * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING - * - * @param target to write results to. - * @param processor to calculate the row values with. - * @param lbound comparator used to identify the lower bound of an output row. - * @param ubound comparator used to identify the upper bound of an output row. - */ -private[execution] final class SlidingWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor, - lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** The next row from `input`. */ - private[this] var nextRow: InternalRow = null - - /** The rows within current sliding window. 
*/ - private[this] val buffer = new util.ArrayDeque[InternalRow]() - - /** - * Index of the first input row with a value greater than the upper bound of the current - * output row. - */ - private[this] var inputHighIndex = 0 - - /** - * Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. - */ - private[this] var inputLowIndex = 0 - - /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: RowBuffer): Unit = { - input = rows - nextRow = rows.next() - inputHighIndex = 0 - inputLowIndex = 0 - buffer.clear() - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - var bufferUpdated = index == 0 - - // Add all rows to the buffer for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - buffer.add(nextRow.copy()) - nextRow = input.next() - inputHighIndex += 1 - bufferUpdated = true - } - - // Drop all rows from the buffer for which the input row value is smaller than - // the output row lower bound. - while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { - buffer.remove() - inputLowIndex += 1 - bufferUpdated = true - } - - // Only recalculate and update when the buffer changes. - if (bufferUpdated) { - processor.initialize(input.size) - val iter = buffer.iterator() - while (iter.hasNext) { - processor.update(iter.next()) - } - processor.evaluate(target) - } - } -} - -/** - * The unbounded window frame calculates frames with the following SQL forms: - * ... (No Frame Definition) - * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - * - * Its results are the same for each and every row in the partition. 
This class can be seen as a - * special case of a sliding window, but is optimized for the unbound case. - * - * @param target to write results to. - * @param processor to calculate the row values with. - */ -private[execution] final class UnboundedWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor) extends WindowFunctionFrame { - - /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: RowBuffer): Unit = { - val size = rows.size() - processor.initialize(size) - var i = 0 - while (i < size) { - processor.update(rows.next()) - i += 1 - } - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate - // for each row. - processor.evaluate(target) - } -} - -/** - * The UnboundPreceding window frame calculates frames with the following SQL form: - * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * - * There is only an upper bound. Very common use cases are for instance running sums or counts - * (row_number). Technically this is a special case of a sliding window. However a sliding window - * has to maintain a buffer, and it must do a full evaluation everytime the buffer changes. This - * is not the case when there is no lower bound, given the additive nature of most aggregates - * streaming updates and partial evaluation suffice and no buffering is needed. - * - * @param target to write results to. - * @param processor to calculate the row values with. - * @param ubound comparator used to identify the upper bound of an output row. - */ -private[execution] final class UnboundedPrecedingWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor, - ubound: BoundOrdering) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. 
*/ - private[this] var input: RowBuffer = null - - /** The next row from `input`. */ - private[this] var nextRow: InternalRow = null - - /** - * Index of the first input row with a value greater than the upper bound of the current - * output row. - */ - private[this] var inputIndex = 0 - - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { - input = rows - nextRow = rows.next() - inputIndex = 0 - processor.initialize(input.size) - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - var bufferUpdated = index == 0 - - // Add all rows to the aggregates for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { - processor.update(nextRow) - nextRow = input.next() - inputIndex += 1 - bufferUpdated = true - } - - // Only recalculate and update when the buffer changes. - if (bufferUpdated) { - processor.evaluate(target) - } - } -} - -/** - * The UnboundFollowing window frame calculates frames with the following SQL form: - * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING - * - * There is only an upper bound. This is a slightly modified version of the sliding window. The - * sliding window operator has to check if both upper and the lower bound change when a new row - * gets processed, where as the unbounded following only has to check the lower bound. - * - * This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a - * buffer and must do full recalculation after each row. Reverse iteration would be possible, if - * the commutativity of the used window functions can be guaranteed. - * - * @param target to write results to. - * @param processor to calculate the row values with. - * @param lbound comparator used to identify the lower bound of an output row. 
- */ -private[execution] final class UnboundedFollowingWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor, - lbound: BoundOrdering) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** - * Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. - */ - private[this] var inputIndex = 0 - - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { - input = rows - inputIndex = 0 - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - var bufferUpdated = index == 0 - - // Duplicate the input to have a new iterator - val tmp = input.copy() - - // Drop all rows from the buffer for which the input row value is smaller than - // the output row lower bound. - tmp.skip(inputIndex) - var nextRow = tmp.next() - while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { - nextRow = tmp.next() - inputIndex += 1 - bufferUpdated = true - } - - // Only recalculate and update when the buffer changes. - if (bufferUpdated) { - processor.initialize(input.size) - while (nextRow != null) { - processor.update(nextRow) - nextRow = tmp.next() - } - processor.evaluate(target) - } - } -} - -/** - * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a - * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, - * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying - * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. - * - * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. 
These functions - * require the size of the partition processed, this value is exposed to them when the processor is - * constructed. - * - * Processing of distinct aggregates is currently not supported. - * - * The implementation is split into an object which takes care of construction, and a the actual - * processor class. - */ -private[execution] object AggregateProcessor { - def apply( - functions: Array[Expression], - ordinal: Int, - inputAttributes: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection): - AggregateProcessor = { - val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] - val initialValues = mutable.Buffer.empty[Expression] - val updateExpressions = mutable.Buffer.empty[Expression] - val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) - val imperatives = mutable.Buffer.empty[ImperativeAggregate] - - // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then - // serialized to executor side. These functions all reference a global singleton window - // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect - // the singleton instance created on driver side instead of using executor side - // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. - val partitionSize: Option[AttributeReference] = { - val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) - aggs.headOption.map(_.n) - } - - // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to - // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. - partitionSize.foreach { n => - aggBufferAttributes += n - initialValues += NoOp - updateExpressions += NoOp - } - - // Add an AggregateFunction to the AggregateProcessor. 
- functions.foreach { - case agg: DeclarativeAggregate => - aggBufferAttributes ++= agg.aggBufferAttributes - initialValues ++= agg.initialValues - updateExpressions ++= agg.updateExpressions - evaluateExpressions += agg.evaluateExpression - case agg: ImperativeAggregate => - val offset = aggBufferAttributes.size - val imperative = BindReferences.bindReference(agg - .withNewInputAggBufferOffset(offset) - .withNewMutableAggBufferOffset(offset), - inputAttributes) - imperatives += imperative - aggBufferAttributes ++= imperative.aggBufferAttributes - val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) - initialValues ++= noOps - updateExpressions ++= noOps - evaluateExpressions += imperative - case other => - sys.error(s"Unsupported Aggregate Function: $other") - } - - // Create the projections. - val initialProjection = newMutableProjection( - initialValues, - partitionSize.toSeq) - val updateProjection = newMutableProjection( - updateExpressions, - aggBufferAttributes ++ inputAttributes) - val evaluateProjection = newMutableProjection( - evaluateExpressions, - aggBufferAttributes) - - // Create the processor - new AggregateProcessor( - aggBufferAttributes.toArray, - initialProjection, - updateProjection, - evaluateProjection, - imperatives.toArray, - partitionSize.isDefined) - } -} - -/** - * This class manages the processing of a number of aggregate functions. See the documentation of - * the object for more information. 
- */ -private[execution] final class AggregateProcessor( - private[this] val bufferSchema: Array[AttributeReference], - private[this] val initialProjection: MutableProjection, - private[this] val updateProjection: MutableProjection, - private[this] val evaluateProjection: MutableProjection, - private[this] val imperatives: Array[ImperativeAggregate], - private[this] val trackPartitionSize: Boolean) { - - private[this] val join = new JoinedRow - private[this] val numImperatives = imperatives.length - private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) - initialProjection.target(buffer) - updateProjection.target(buffer) - - /** Create the initial state. */ - def initialize(size: Int): Unit = { - // Some initialization expressions are dependent on the partition size so we have to - // initialize the size before initializing all other fields, and we have to pass the buffer to - // the initialization projection. - if (trackPartitionSize) { - buffer.setInt(0, size) - } - initialProjection(buffer) - var i = 0 - while (i < numImperatives) { - imperatives(i).initialize(buffer) - i += 1 - } - } - - /** Update the buffer. */ - def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) - var i = 0 - while (i < numImperatives) { - imperatives(i).update(buffer, input) - i += 1 - } - } - - /** Evaluate buffer. */ - def evaluate(target: MutableRow): Unit = - evaluateProjection.target(target)(buffer) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala new file mode 100644 index 0000000000000..d3a46d020dbbf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ + + +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, + * this reduces the processing of an [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition processed, this value is exposed to them when the processor is + * constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and the actual + * processor class.
+ */ +private[window] object AggregateProcessor { + def apply( + functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection) + : AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then + // serialized to executor side. These functions all reference a global singleton window + // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect + // the singleton instance created on driver side instead of using executor side + // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. + val partitionSize: Option[AttributeReference] = { + val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) + aggs.headOption.map(_.n) + } + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + partitionSize.foreach { n => + aggBufferAttributes += n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. 
+ functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") + } + + // Create the projections. + val initialProj = newMutableProjection(initialValues, partitionSize.toSeq) + val updateProj = newMutableProjection(updateExpressions, aggBufferAttributes ++ inputAttributes) + val evalProj = newMutableProjection(evaluateExpressions, aggBufferAttributes) + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProj, + updateProj, + evalProj, + imperatives.toArray, + partitionSize.isDefined) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. 
+ */ +private[window] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. + if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } + } + + /** Evaluate buffer. */ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala new file mode 100644 index 0000000000000..d6a801954c1ac --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Projection + + +/** + * Function for comparing boundary values. + */ +private[window] abstract class BoundOrdering { + def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int +} + +/** + * Compare the input index to the bound of the output index. + */ +private[window] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + inputIndex - (outputIndex + offset) +} + +/** + * Compare the value of the input index to the value bound of the output index. 
+ */ +private[window] final case class RangeBoundOrdering( + ordering: Ordering[InternalRow], + current: Projection, + bound: Projection) + extends BoundOrdering { + + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + ordering.compare(current(inputRow), bound(outputRow)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala new file mode 100644 index 0000000000000..ee36c84251519 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} + + +/** + * The interface of row buffer for a partition. 
In absence of a buffer pool (with locking), the + * row buffer is used to materialize a partition of rows since we need to repeatedly scan these + * rows in window function processing. + */ +private[window] abstract class RowBuffer { + + /** Number of rows. */ + def size: Int + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow + + /** Skip the next `n` rows. */ + def skip(n: Int): Unit + + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer +} + +/** + * A row buffer based on ArrayBuffer (the number of rows is limited). + */ +private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { + + private[this] var cursor: Int = -1 + + /** Number of rows. */ + override def size: Int = buffer.length + + /** Return next row in the buffer, null if no more left. */ + override def next(): InternalRow = { + cursor += 1 + if (cursor < buffer.length) { + buffer(cursor) + } else { + null + } + } + + /** Skip the next `n` rows. */ + override def skip(n: Int): Unit = { + cursor += n + } + + /** Return a new RowBuffer that has the same rows. */ + override def copy(): RowBuffer = { + new ArrayRowBuffer(buffer) + } +} + +/** + * An external buffer of rows based on UnsafeExternalSorter. + */ +private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) + extends RowBuffer { + + private[this] val iter: UnsafeSorterIterator = sorter.getIterator + + private[this] val currentRow = new UnsafeRow(numFields) + + /** Number of rows. */ + override def size: Int = iter.getNumRecords() + + /** Return next row in the buffer, null if no more left. */ + override def next(): InternalRow = { + if (iter.hasNext) { + iter.loadNext() + currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + currentRow + } else { + null + } + } + + /** Skip the next `n` rows. 
*/ + override def skip(n: Int): Unit = { + var i = 0 + while (i < n && iter.hasNext) { + iter.loadNext() + i += 1 + } + } + + /** Return a new RowBuffer that has the same rows. */ + override def copy(): RowBuffer = { + new ExternalRowBuffer(sorter, numFields) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala new file mode 100644 index 0000000000000..7a6a30f120386 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +/** + * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) + * partition. The aggregates are calculated for each row in the group. Special processing + * instructions, frames, are used to calculate these aggregates. Frames are processed in the order + * specified in the window specification (the ORDER BY ... clause). There are five different frame + * types: + * - Entire partition: The frame is the entire partition, i.e. + * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all + * rows as inputs and be evaluated once. + * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND .... + * Every time we move to a new row to process, we add some rows to the frame. We do not remove + * rows from this frame. + * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING. + * Every time we move to a new row to process, we remove some rows from the frame. We do not add + * rows to this frame. + * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame + * and we add some rows to the frame. Examples are: + * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consists of one row, which is an offset number of rows away from the + * current row.
Only [[OffsetWindowFunction]]s can be processed in an offset frame. + * + * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame + * boundary can be either Row or Range based: + * - Row Based: A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and an upper bound offset of +2, the frame for row with index 5 would range from + * index 4 to index 7. + * - Range based: A range based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expressions are allowed. + * + * This is quite an expensive operator because every row for a single group must be in the same + * partition and partitions must be sorted according to the grouping and sort order. The operator + * requires the planner to take care of the partitioning and sorting. + * + * The operator is semi-blocking. The window functions and aggregates are calculated one group at + * a time, the result will only be made available after the processing for the entire group has + * finished. The operator is able to process different frame configurations at the same time. This + * is done by delegating the actual frame processing (i.e.
calculation of the window functions) to + * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type: + * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair + * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. + */ +case class WindowExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) + extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frameType to evaluate. This can either be Row or Range based. + * @param offset with respect to the row. + * @return a bound ordering object. + */ + private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { + frameType match { + case RangeFrame => + val (exprs, current, bound) = if (offset == 0) { + // Use the entire order expression when the offset is 0. 
+ val exprs = orderSpec.map(_.child) + val buildProjection = () => newMutableProjection(exprs, child.output) + (orderSpec, buildProjection(), buildProjection()) + } else if (orderSpec.size == 1) { + // Use only the first order expression when the offset is non-zero. + val sortExpr = orderSpec.head + val expr = sortExpr.child + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + // Flip the sign of the offset when the processing order is descending + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output) + (sortExpr :: Nil, current, bound) + } else { + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") + } + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). + val sortExprs = exprs.zipWithIndex.map { case (e, i) => + SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) + } + val ordering = newOrdering(sortExprs, Nil) + RangeBoundOrdering(ordering, current, bound) + case RowFrame => RowBoundOrdering(offset) + } + } + + /** + * Collection containing an entry for each window frame to process. Each entry contains a frame's + * WindowExpressions and factory function for the WindowFunctionFrame.
+ */ + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add an expression and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es += e + fns += fn + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + // OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
+ functions.map(_.asInstanceOf[OffsetWindowFunction]), + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. + (expressions, factory) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. 
+ */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map{ case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + + // Start processing. + child.execute().mapPartitions { stream => + new Iterator[InternalRow] { + + // Get all relevant projections. + val result = createResultProjection(expressions) + val grouping = UnsafeProjection.create(partitionSpec, child.output) + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow() { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. + val rows = ArrayBuffer.empty[UnsafeRow] + val inputFields = child.output.length + var sorter: UnsafeExternalSorter = null + var rowBuffer: RowBuffer = null + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) + val numFrames = frames.length + private[this] def fetchNextPartition() { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. 
+ val currentGroup = nextGroup.copy() + + // clear last partition + if (sorter != null) { + // the last sorter of this task will be cleaned up via task completion listener + sorter.cleanupResources() + sorter = null + } else { + rows.clear() + } + + while (nextRowAvailable && nextGroup == currentGroup) { + if (sorter == null) { + rows += nextRow.copy() + + if (rows.length >= 4096) { + // We will not sort the rows, so prefixComparator and recordComparator are null. + sorter = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + false) + rows.foreach { r => + sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) + } + rows.clear() + } + } else { + sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, + nextRow.getSizeInBytes, 0, false) + } + fetchNextRow() + } + if (sorter != null) { + rowBuffer = new ExternalRowBuffer(sorter, inputFields) + } else { + rowBuffer = new ArrayRowBuffer(rows) + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(rowBuffer.copy()) + i += 1 + } + + // Setup iteration + rowIndex = 0 + rowsSize = rowBuffer.size + } + + // Iteration + var rowIndex = 0 + var rowsSize = 0L + + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable + + val join = new JoinedRow + override final def next(): InternalRow = { + // Load the next partition if we need to. + if (rowIndex >= rowsSize && nextRowAvailable) { + fetchNextPartition() + } + + if (rowIndex < rowsSize) { + // Get the results for the window frames. 
+ var i = 0 + val current = rowBuffer.next() + while (i < numFrames) { + frames(i).write(rowIndex, current) + i += 1 + } + + // 'Merge' the input row with the window function result + join(current, windowFunctionResult) + rowIndex += 1 + + // Return the projection. + result(join) + } else throw new NoSuchElementException + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala new file mode 100644 index 0000000000000..2ab9faab7a59b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import java.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp + + +/** + * A window function calculates the results of a number of window functions for a window frame. + * Before use a frame must be prepared by passing it all the rows in the current partition. 
After + * preparation the write method can be called to fill the output rows. + */ +private[window] abstract class WindowFunctionFrame { + /** + * Prepare the frame for calculating the results for a partition. + * + * @param rows to calculate the frame results for. + */ + def prepare(rows: RowBuffer): Unit + + /** + * Write the current results to the target row. + */ + def write(index: Int, current: InternalRow): Unit +} + +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param ordinal the ordinal is the starting offset at which the results of the window frame get + * written into the (shared) target row. The result of the frame expression with + * index 'i' will be written to the 'ordinal' + 'i' position in the target row. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[window] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[OffsetWindowFunction], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + offset: Int) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 + + /** + * Create the projection used when the offset row exists. + * Please note that this projection always respects null input values (like PostgreSQL). + */ + private[this] val projection = { + // Collect the expressions and bind them.
+ val inputAttrs = inputSchema.map(_.withNullability(true)) + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => + BindReferences.bindReference(e.input, inputAttrs) + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil).target(target) + } + + /** Create the projection used when the offset row DOES NOT exist. */ + private[this] val fillDefaultValue = { + // Collect the expressions and bind them. + val inputAttrs = inputSchema.map(_.withNullability(true)) + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // The default value is null. + Literal.create(null, e.dataType) + } else { + // The default value is an expression. + BindReferences.bindReference(e.default, inputAttrs) + } + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil).target(target) + } + + override def prepare(rows: RowBuffer): Unit = { + input = rows + // drain the first few rows if offset is larger than zero + inputIndex = 0 + while (inputIndex < offset) { + input.next() + inputIndex += 1 + } + inputIndex = offset + } + + override def write(index: Int, current: InternalRow): Unit = { + if (inputIndex >= 0 && inputIndex < input.size) { + val r = input.next() + projection(r) + } else { + // Use default values since the offset row does not exist. + fillDefaultValue(current) + } + inputIndex += 1 + } +} + +/** + * The sliding window frame calculates frames with the following SQL form: + * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING + * + * @param target to write results to. + * @param processor to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row. + * @param ubound comparator used to identify the upper bound of an output row.
+ */ +private[window] final class SlidingWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering, + ubound: BoundOrdering) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null + + /** The rows within current sliding window. */ + private[this] val buffer = new util.ArrayDeque[InternalRow]() + + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ + private[this] var inputHighIndex = 0 + + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ + private[this] var inputLowIndex = 0 + + /** Prepare the frame for calculating a new partition. Reset all variables. */ + override def prepare(rows: RowBuffer): Unit = { + input = rows + nextRow = rows.next() + inputHighIndex = 0 + inputLowIndex = 0 + buffer.clear() + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + buffer.add(nextRow.copy()) + nextRow = input.next() + inputHighIndex += 1 + bufferUpdated = true + } + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. + while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + buffer.remove() + inputLowIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. 
+ if (bufferUpdated) { + processor.initialize(input.size) + val iter = buffer.iterator() + while (iter.hasNext) { + processor.update(iter.next()) + } + processor.evaluate(target) + } + } +} + +/** + * The unbounded window frame calculates frames with the following SQL forms: + * ... (No Frame Definition) + * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + * + * Its results are the same for each and every row in the partition. This class can be seen as a + * special case of a sliding window, but is optimized for the unbounded case. + * + * @param target to write results to. + * @param processor to calculate the row values with. + */ +private[window] final class UnboundedWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor) + extends WindowFunctionFrame { + + /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ + override def prepare(rows: RowBuffer): Unit = { + val size = rows.size + processor.initialize(size) + var i = 0 + while (i < size) { + processor.update(rows.next()) + i += 1 + } + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) + } +} + +/** + * The UnboundedPreceding window frame calculates frames with the following SQL form: + * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * + * There is only an upper bound. Very common use cases are for instance running sums or counts + * (row_number). Technically this is a special case of a sliding window. However, a sliding window + * has to maintain a buffer, and it must do a full evaluation every time the buffer changes. This + * is not the case when there is no lower bound: given the additive nature of most aggregates, + * streaming updates and partial evaluation suffice and no buffering is needed.
+ * + * @param target to write results to. + * @param processor to calculate the row values with. + * @param ubound comparator used to identify the upper bound of an output row. + */ +private[window] final class UnboundedPrecedingWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null + + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ + private[this] var inputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: RowBuffer): Unit = { + input = rows + nextRow = rows.next() + inputIndex = 0 + processor.initialize(input.size) + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Add all rows to the aggregates for which the input row value is equal to or less than + // the output row upper bound. + while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { + processor.update(nextRow) + nextRow = input.next() + inputIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + processor.evaluate(target) + } + } +} + +/** + * The UnboundedFollowing window frame calculates frames with the following SQL form: + * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + * + * There is only a lower bound. This is a slightly modified version of the sliding window. The + * sliding window operator has to check if both the upper and the lower bound change when a new row + * gets processed, whereas the unbounded following frame only has to check the lower bound.
+ * + * This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a + * buffer and must do full recalculation after each row. Reverse iteration would be possible, if + * the commutativity of the used window functions can be guaranteed. + * + * @param target to write results to. + * @param processor to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row. + */ +private[window] final class UnboundedFollowingWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ + private[this] var inputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: RowBuffer): Unit = { + input = rows + inputIndex = 0 + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Duplicate the input to have a new iterator + val tmp = input.copy() + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. + tmp.skip(inputIndex) + var nextRow = tmp.next() + while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { + nextRow = tmp.next() + inputIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. 
+ if (bufferUpdated) { + processor.initialize(input.size) + while (nextRow != null) { + processor.update(nextRow) + nextRow = tmp.next() + } + processor.evaluate(target) + } + } +} From 2f84a686604b298537bfd4d087b41594d2aa7ec6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 27 Sep 2016 14:14:27 -0700 Subject: [PATCH 26/96] [SPARK-17618] Guard against invalid comparisons between UnsafeRow and other formats This patch ports changes from #15185 to Spark 2.x. That patch fixed a correctness bug in Spark 1.6.x which was caused by an invalid `equals()` comparison between an `UnsafeRow` and another row of a different format. Spark 2.x is not affected by that specific correctness bug but it can still reap the error-prevention benefits of that patch's changes, which modify `UnsafeRow.equals()` to throw an IllegalArgumentException if it is called with an object that is not an `UnsafeRow`. Author: Josh Rosen Closes #15265 from JoshRosen/SPARK-17618-master. --- .../apache/spark/sql/catalyst/expressions/UnsafeRow.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index dd2f39eb816f2..9027652d57f14 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -31,6 +31,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -577,8 +578,12 @@ public boolean equals(Object other) { return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); + } else if (!(other
instanceof InternalRow)) { + return false; + } else { + throw new IllegalArgumentException( + "Cannot compare UnsafeRow to " + other.getClass().getName()); } - return false; } /** From e7bce9e1876de6ee975ccc89351db58119674aef Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Sep 2016 16:00:39 -0700 Subject: [PATCH 27/96] [SPARK-17056][CORE] Fix a wrong assert regarding unroll memory in MemoryStore ## What changes were proposed in this pull request? There is an assert in MemoryStore's putIteratorAsValues method which is used to check if unroll memory is not released too much. This assert looks wrong. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #14642 from viirya/fix-unroll-memory. --- .../scala/org/apache/spark/storage/memory/MemoryStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 205d469f48144..095d32407f345 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -273,7 +273,7 @@ private[spark] class MemoryStore( blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(size) } else { - assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, + assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, "released too much unroll memory") Left(new PartiallyUnrolledIterator( this, From b03b4adf6d8f4c6d92575c0947540cb474bf7de1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 27 Sep 2016 17:52:57 -0700 Subject: [PATCH 28/96] [SPARK-17666] Ensure that RecordReaders are closed by data source file scans ## What changes were proposed in this pull request? This patch addresses a potential cause of resource leaks in data source file scans. 
As reported in [SPARK-17666](https://issues.apache.org/jira/browse/SPARK-17666), tasks which do not fully-consume their input may cause file handles / network connections (e.g. S3 connections) to be leaked. Spark's `NewHadoopRDD` uses a TaskContext callback to [close its record readers](https://github.com/apache/spark/blame/master/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala#L208), but the new data source file scans will only close record readers once their iterators are fully-consumed. This patch modifies `RecordReaderIterator` and `HadoopFileLinesReader` to add `close()` methods and modifies all six implementations of `FileFormat.buildReader()` to register TaskContext task completion callbacks to guarantee that cleanup is eventually performed. ## How was this patch tested? Tested manually for now. Author: Josh Rosen Closes #15245 from JoshRosen/SPARK-17666-close-recordreader. --- .../ml/source/libsvm/LibSVMRelation.scala | 7 +++++-- .../datasources/HadoopFileLinesReader.scala | 6 +++++- .../datasources/RecordReaderIterator.scala | 21 +++++++++++++++++-- .../datasources/csv/CSVFileFormat.scala | 5 ++++- .../datasources/json/JsonFileFormat.scala | 5 ++++- .../parquet/ParquetFileFormat.scala | 3 ++- .../datasources/text/TextFileFormat.scala | 2 ++ .../spark/sql/hive/orc/OrcFileFormat.scala | 6 +++++- 8 files changed, 46 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 5c79c6905801c..8577803743c8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import 
org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLUtils @@ -159,8 +160,10 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val points = - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + + val points = linesReader .map(_.toString.trim) .filterNot(line => line.isEmpty || line.startsWith("#")) .map { line => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 18f9b55895a64..83cf26c63a175 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable import java.net.URI import org.apache.hadoop.conf.Configuration @@ -30,7 +31,8 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. 
*/ -class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] { +class HadoopFileLinesReader( + file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -48,4 +50,6 @@ class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends override def hasNext: Boolean = iterator.hasNext override def next(): Text = iterator.next() + + override def close(): Unit = iterator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala index f03ae94d55838..938af25a96844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable + import org.apache.hadoop.mapreduce.RecordReader import org.apache.spark.sql.catalyst.InternalRow @@ -27,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow * Note that this returns [[Object]]s instead of [[InternalRow]] because we rely on erasure to pass * column batches by pretending they are rows. */ -class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] { +class RecordReaderIterator[T]( + private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { private[this] var havePair = false private[this] var finished = false @@ -38,7 +41,7 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] // Close and release the reader here; close() will also be called when the task // completes, but for tasks that read from many files, it helps to release the // resources early. 
- rowReader.close() + close() } havePair = !finished } @@ -52,4 +55,18 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] havePair = false rowReader.getCurrentValue } + + override def close(): Unit = { + if (rowReader != null) { + try { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues + // when reading compressed input. + rowReader.close() + } finally { + rowReader = null + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9a118fe5a273d..9610746a81ef7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -112,7 +113,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val lineIterator = { val conf = broadcastedHadoopConf.value.value - new HadoopFileLinesReader(file, conf).map { line => + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => new String(line.getBytes, 0, line.getLength, csvOptions.charset) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 7421314df7aa5..6882a6cdcac26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -104,7 +105,9 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) (file: PartitionedFile) => { - val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val lines = linesReader.map(_.toString) val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions) lines.flatMap(parser.parse) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e7c3545630fea..4a308ff1a32f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -37,7 +37,7 @@ import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler -import 
org.apache.spark.SparkException +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -388,6 +388,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(parquetReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index a0c3fd53fb53b..a875b01ec2d7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -100,6 +101,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 03b508e11aa76..15b72d8d2179f 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, Re import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.spark.TaskContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -146,12 +147,15 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) } + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), - new RecordReaderIterator[OrcStruct](orcRecordReader)) + recordsIterator) } } } From 4a83395681e0bca356363a6cfb25c952f235560d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 27 Sep 2016 21:19:59 -0700 Subject: [PATCH 29/96] [SPARK-17499][SPARKR][FOLLOWUP] Check null first for layers in spark.mlp to avoid warnings in test results ## What changes were proposed in this pull request? Some tests in `test_mllib.r` are as below: ```r expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") ``` The problem is, `is.na` is internally called via `na.omit` in `spark.mlp` which causes warnings as below: ``` Warnings ----------------------------------------------------------------------- 1. spark.mlp (test_mllib.R#400) - is.na() applied to non-(list or vector) of type 'NULL' 2. 
spark.mlp (test_mllib.R#401) - is.na() applied to non-(list or vector) of type 'NULL' ``` ## How was this patch tested? Manually tested. Also, Jenkins tests and AppVeyor. Author: hyukjinkwon Closes #15232 from HyukjinKwon/remove-warnnings. --- R/pkg/R/mllib.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 971c16658fe9a..b901307f8f409 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -696,6 +696,9 @@ setMethod("predict", signature(object = "KMeansModel"), setMethod("spark.mlp", signature(data = "SparkDataFrame"), function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, tol = 1E-6, stepSize = 0.03, seed = NULL) { + if (is.null(layers)) { + stop ("layers must be a integer vector with length > 1.") + } layers <- as.integer(na.omit(layers)) if (length(layers) <= 1) { stop ("layers must be a integer vector with length > 1.") From b2a7eedcddf0e682ff46afd1b764d0b81ccdf395 Mon Sep 17 00:00:00 2001 From: Shuai Lin Date: Wed, 28 Sep 2016 06:12:48 -0400 Subject: [PATCH 30/96] [SPARK-17017][ML][MLLIB][ML][DOC] Updated the ml/mllib feature selection docs for ChiSqSelector ## What changes were proposed in this pull request? A follow up for #14597 to update feature selection docs about ChiSqSelector. ## How was this patch tested? Generated html docs. It can be previewed at: * ml: http://sparkdocs.lins05.pw/spark-17017/ml-features.html#chisqselector * mllib: http://sparkdocs.lins05.pw/spark-17017/mllib-feature-extraction.html#chisqselector Author: Shuai Lin Closes #15236 from lins05/spark-17017-update-docs-for-chisq-selector-fpr. --- docs/ml-features.md | 14 ++++++++++---- docs/mllib-feature-extraction.md | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a39b31c8f7ffc..a7f710fa52e64 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1331,10 +1331,16 @@ for more details on the API. 
## ChiSqSelector `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with -categorical features. ChiSqSelector orders features based on a -[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) -from the class, and then filters (selects) the top features which the class label depends on the -most. This is akin to yielding the features with the most predictive power. +categorical features. ChiSqSelector uses the +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which +features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: + +* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. +* `FPR` chooses all features whose false positive rate meets some threshold. + +By default, the selection method is `KBest`, the default number of top features is 50. User can use +`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. **Examples** diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 353d391249973..87e1e027e945b 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -225,10 +225,16 @@ features for use in model construction. It reduces the size of the feature space both speed and statistical learning behavior. [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements -Chi-Squared feature selection. It operates on labeled data with categorical features. -`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, -and then filters (selects) the top features which the class label depends on the most. -This is akin to yielding the features with the most predictive power. 
+Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which +features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: + +* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. +* `FPR` chooses all features whose false positive rate meets some threshold. + +By default, the selection method is `KBest`, the default number of top features is 50. User can use +`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. The number of features to select can be tuned using a held-out validation set. From 2190037757a81d3172f75227f7891d968e1f0d90 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Sep 2016 06:19:04 -0400 Subject: [PATCH 31/96] [MINOR][PYSPARK][DOCS] Fix examples in PySpark documentation ## What changes were proposed in this pull request? This PR proposes to fix wrongly indented examples in PySpark documentation ``` - >>> json_sdf = spark.readStream.format("json")\ - .schema(sdf_schema)\ - .load(tempfile.mkdtemp()) + >>> json_sdf = spark.readStream.format("json") \\ + ... .schema(sdf_schema) \\ + ... .load(tempfile.mkdtemp()) ``` ``` - people.filter(people.age > 30).join(department, people.deptId == department.id)\ + people.filter(people.age > 30).join(department, people.deptId == department.id) \\ ``` ``` - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), + ... 
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] ``` ``` - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] ``` ``` - ... for x in iterator: - ... print(x) + ... for x in iterator: + ... print(x) ``` ## How was this patch tested? Manually tested. **Before** ![2016-09-26 8 36 02](https://cloud.githubusercontent.com/assets/6477701/18834471/05c7a478-8431-11e6-94bb-09aa37b12ddb.png) ![2016-09-26 9 22 16](https://cloud.githubusercontent.com/assets/6477701/18834472/06c8735c-8431-11e6-8775-78631eab0411.png) 2016-09-27 2 29 27 2016-09-27 2 29 58 2016-09-27 2 30 05 **After** ![2016-09-26 9 29 47](https://cloud.githubusercontent.com/assets/6477701/18834467/0367f9da-8431-11e6-86d9-a490d3297339.png) ![2016-09-26 9 30 24](https://cloud.githubusercontent.com/assets/6477701/18834463/f870fae0-8430-11e6-9482-01fc47898492.png) 2016-09-27 2 28 19 2016-09-27 3 50 59 2016-09-27 3 51 03 Author: hyukjinkwon Closes #15242 from HyukjinKwon/minor-example-pyspark. --- python/pyspark/mllib/util.py | 8 ++++---- python/pyspark/rdd.py | 4 ++-- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/streaming.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 48867a08dbfad..ed6fd4bca4c54 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -140,8 +140,8 @@ def saveAsLibSVMFile(data, dir): >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), + ... 
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> MLUtils.saveAsLibSVMFile(sc.parallelize(examples), tempFile.name) @@ -166,8 +166,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 0508235c1c9ee..5fb10f86f4692 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -754,8 +754,8 @@ def foreachPartition(self, f): Applies a function to each partition of this RDD. >>> def f(iterator): - ... for x in iterator: - ... print(x) + ... for x in iterator: + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0f7d8fba3bd54..0ac481a8a8b56 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -61,7 +61,7 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)\ + people.filter(people.age > 30).join(department, people.deptId == department.id) \\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. 
versionadded:: 1.3 diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index cbd827950bbb4..4e438fd5bee22 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -315,9 +315,9 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options - >>> json_sdf = spark.readStream.format("json")\ - .schema(sdf_schema)\ - .load(tempfile.mkdtemp()) + >>> json_sdf = spark.readStream.format("json") \\ + ... .schema(sdf_schema) \\ + ... .load(tempfile.mkdtemp()) >>> json_sdf.isStreaming True >>> json_sdf.schema == sdf_schema From 46d1203bf2d01b219c4efc7e0e77a844c0c664da Mon Sep 17 00:00:00 2001 From: w00228970 Date: Wed, 28 Sep 2016 12:02:59 -0700 Subject: [PATCH 32/96] [SPARK-17644][CORE] Do not add failedStages when abortStage for fetch failure ## What changes were proposed in this pull request? | Time |Thread 1 , Job1 | Thread 2 , Job2 | |:-------------:|:-------------:|:-----:| | 1 | abort stage due to FetchFailed | | | 2 | failedStages += failedStage | | | 3 | | task failed due to FetchFailed | | 4 | | cannot post ResubmitFailedStages because failedStages is not empty | As a result, Job2 in Thread 2 never resubmits the failed stage and hangs. We should not add the stages to failedStages when a stage is aborted because of a fetch failure. ## How was this patch tested? Added a unit test. Author: w00228970 Author: wangfei Closes #15213 from scwf/dag-resubmit.
--- .../apache/spark/scheduler/DAGScheduler.scala | 24 ++++---- .../spark/scheduler/DAGSchedulerSuite.scala | 58 ++++++++++++++++++- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5ea0b48f6e4c4..f2517401cb76b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1263,18 +1263,20 @@ class DAGScheduler( s"has failed the maximum allowable number of " + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + s"Most recent failure reason: ${failureMessage}", None) - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } else { + if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. 
+ // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage } - failedStages += failedStage - failedStages += mapStage // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6787b302614e6..bec95d13d193a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -31,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -2105,6 +2106,61 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC)) } + test("SPARK-17644: After one stage is aborted for too many failed attempts, subsequent stages" + + "still behave 
correctly on fetch failures") { + // Runs a job that always encounters a fetch failure, so should eventually be aborted + def runJobWithPersistentFetchFailure: Unit = { + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + case (x, _) => x + }.count() + } + + // Runs a job that encounters a single fetch failure but succeeds on the second attempt + def runJobWithTemporaryFetchFailure: Unit = { + object FailThisAttempt { + val _fail = new AtomicBoolean(true) + } + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + } + } + + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + // Run a second job that will fail due to a fetch failure. + // This job will hang without the fix for SPARK-17644. + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + failAfter(10.seconds) { + try { + runJobWithTemporaryFetchFailure + } catch { + case e: Throwable => fail("A job with one fetch failure should eventually succeed") + } + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. 
From a6cfa3f38bcf6ba154d5ed2a53748fbc90c8872a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 28 Sep 2016 13:22:45 -0700 Subject: [PATCH 33/96] [SPARK-17673][SQL] Incorrect exchange reuse with RowDataSourceScan ## What changes were proposed in this pull request? It seems the equality check for reuse of `RowDataSourceScanExec` nodes doesn't respect the output schema. This can cause self-joins or unions over the same underlying data source to return incorrect results if they select different fields. ## How was this patch tested? New unit test passes after the fix. Author: Eric Liang Closes #15273 from ericl/spark-17673. --- .../sql/execution/datasources/DataSourceStrategy.scala | 4 ++++ .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 63f01c5bb9e3c..693b4c4d0e5e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -340,6 +340,8 @@ object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + // These metadata values make scan plans uniquely identifiable for equality checking. 
+ // TODO(SPARK-17701) using strings for equality checking is brittle val metadata: Map[String, String] = { val pairs = ArrayBuffer.empty[(String, String)] @@ -350,6 +352,8 @@ object DataSourceStrategy extends Strategy with Logging { } pairs += ("PushedFilters" -> markedFilters.mkString("[", ", ", "]")) } + pairs += ("ReadSchema" -> + StructType.fromAttributes(projects.map(_.toAttribute)).catalogString) pairs.toMap } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 10f15ca280689..c94cb3b69dfbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -791,4 +791,12 @@ class JDBCSuite extends SparkFunSuite val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } + + test("SPARK-17673: Exchange reuse respects differences in output schema") { + val df = sql("SELECT * FROM inttypes WHERE a IS NOT NULL") + val df1 = df.groupBy("a").agg("c" -> "min") + val df2 = df.groupBy("a").agg("d" -> "min") + val res = df1.union(df2) + assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused + } } From 557d6e32272dee4eaa0f426cc3e2f82ea361c3da Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 28 Sep 2016 16:20:49 -0700 Subject: [PATCH 34/96] [SPARK-17713][SQL] Move row-datasource related tests out of JDBCSuite ## What changes were proposed in this pull request? As a followup for https://github.com/apache/spark/pull/15273 we should move non-JDBC specific tests out of that suite. ## How was this patch tested? Ran the test. Author: Eric Liang Closes #15287 from ericl/spark-17713. 
--- .../RowDataSourceStrategySuite.scala | 72 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 8 --- 2 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala new file mode 100644 index 0000000000000..d9afa4635318f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.sql.DriverManager +import java.util.Properties + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + + val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" + var conn: java.sql.Connection = null + + before { + Utils.classForName("org.h2.Driver") + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) + conn.prepareStatement("create schema test").executeUpdate() + conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate() + conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE inttypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + + after { + conn.close() + } + + test("SPARK-17673: Exchange reuse respects differences in output schema") { + val df = sql("SELECT * FROM inttypes") + val df1 = df.groupBy("a").agg("b" -> "min") + val df2 = df.groupBy("a").agg("c" -> "min") + val res = df1.union(df2) + assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index c94cb3b69dfbe..10f15ca280689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -791,12 +791,4 @@ class JDBCSuite extends SparkFunSuite val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } - - test("SPARK-17673: Exchange reuse respects differences in output schema") { - val df = sql("SELECT * FROM inttypes WHERE a IS NOT NULL") - val df1 = df.groupBy("a").agg("c" -> "min") - val df2 = df.groupBy("a").agg("d" -> "min") - val res = df1.union(df2) - assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused - } } From 7d09232028967978d9db314ec041a762599f636b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 28 Sep 2016 16:25:10 -0700 Subject: [PATCH 35/96] [SPARK-17641][SQL] Collect_list/Collect_set should not collect null values. ## What changes were proposed in this pull request? We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, which differs from the original Hive implementation. This PR fixes this by adding a null check to the `Collect.update` method. ## How was this patch tested? Added a regression test to `DataFrameAggregateSuite`. Author: Herman van Hovell Closes #15208 from hvanhovell/SPARK-17641.
--- .../sql/catalyst/expressions/aggregate/collect.scala | 7 ++++++- .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 896ff61b23093..78a388d20630b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate { } override def update(b: MutableRow, input: InternalRow): Unit = { - buffer += child.eval(input) + // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. + // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator + val value = child.eval(input) + if (value != null) { + buffer += value + } } override def merge(buffer: MutableRow, input: InternalRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0e172bee4f661..7aa4f0026f275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(error.message.contains("collect_set() cannot have map type data")) } + test("SPARK-17641: collect functions should not collect null values") { + val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b") + checkAnswer( + df.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq("1", "1"), Seq(2, 2, 4))) + ) + checkAnswer( + df.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq("1"), 
Seq(2, 4))) + ) + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), From 7dfad4b132bc46263ef788ced4a935862f5c8756 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Wed, 28 Sep 2016 20:20:03 -0500 Subject: [PATCH 36/96] [SPARK-17710][HOTFIX] Fix ClassCircularityError in ReplSuite tests in Maven build: use 'Class.forName' instead of 'Utils.classForName' ## What changes were proposed in this pull request? Fix ClassCircularityError in ReplSuite tests when Spark is built by Maven build. ## How was this patch tested? (1) ``` build/mvn -DskipTests -Phadoop-2.3 -Pyarn -Phive -Phive-thriftserver -Pkinesis-asl -Pmesos clean package ``` Then test: ``` build/mvn -Dtest=none -DwildcardSuites=org.apache.spark.repl.ReplSuite test ``` ReplSuite tests passed (2) Manual Tests against some Spark applications in Yarn client mode and Yarn cluster mode. Need to check if spark caller contexts are written into HDFS hdfs-audit.log and Yarn RM audit log successfully. Author: Weiqing Yang Closes #15286 from Sherry302/SPARK-16757. 
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index caa768cfbdc6c..f3493bd96b1ee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2489,8 +2489,10 @@ private[spark] class CallerContext( def setCurrentContext(): Boolean = { var succeed = false try { - val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") - val Builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:off classforname + val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") + val Builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:on classforname val builderInst = Builder.getConstructor(classOf[String]).newInstance(context) val hdfsContext = Builder.getMethod("build").invoke(builderInst) callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) From 37eb9184f1e9f1c07142c66936671f4711ef407d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 28 Sep 2016 19:03:05 -0700 Subject: [PATCH 37/96] [SPARK-17712][SQL] Fix invalid pushdown of data-independent filters beneath aggregates ## What changes were proposed in this pull request? This patch fixes a minor correctness issue impacting the pushdown of filters beneath aggregates. Specifically, if a filter condition references no grouping or aggregate columns (e.g. `WHERE false`) then it would be incorrectly pushed beneath an aggregate. Intuitively, the only case where you can push a filter beneath an aggregate is when that filter is deterministic and is defined over the grouping columns / expressions, since in that case the filter is acting to exclude entire groups from the query (like a `HAVING` clause). 
The existing code would only push deterministic filters beneath aggregates when all of the filter's references were grouping columns, but this logic missed the case where a filter has no references. For example, `WHERE false` is deterministic but is independent of the actual data. This patch fixes this minor bug by adding a new check to ensure that we don't push filters beneath aggregates when those filters don't reference any columns. ## How was this patch tested? New regression test in FilterPushdownSuite. Author: Josh Rosen Closes #15289 from JoshRosen/SPARK-17712. --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0df16b7a56c56..4952ba3b2b99d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -710,7 +710,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) - replaced.references.subsetOf(aggregate.child.outputSet) + cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) } val stayUp = rest ++ containingNonDeterministic diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 55836f96f7e0e..019f132d94cb2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -687,6 +687,23 @@ 
class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-17712: aggregate: don't push down filters that are data-independent") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy('a)(count('a)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)(count('a)) + .where(false) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = BroadcastHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) From a19a1bb59411177caaf99581e89098826b7d0c7b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 29 Sep 2016 00:54:26 -0700 Subject: [PATCH 38/96] [SPARK-16356][FOLLOW-UP][ML] Enforce ML test of exception for local/distributed Dataset. ## What changes were proposed in this pull request? #14035 added ```testImplicits``` to ML unit tests and promoted ```toDF()```, but left one minor issue at ```VectorIndexerSuite```. If we create the DataFrame by ```Seq(...).toDF()```, it throws a different error/exception than ```sc.parallelize(Seq(...)).toDF()``` for one of the test cases. After in-depth study, I found it was caused by the different behavior of local and distributed Datasets when the UDF fails at ```assert```. If the data is a local Dataset, it throws ```AssertionError``` directly; if the data is a distributed Dataset, it throws ```SparkException```, which wraps the ```AssertionError```. I think we should enforce this test to cover both cases. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #15261 from yanboliang/spark-16356.
--- .../spark/ml/feature/VectorIndexerSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 4da1b133e8cd5..b28ce2ab45b45 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -88,9 +88,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext densePoints1 = densePoints1Seq.map(FeatureData).toDF() sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() - // TODO: If we directly use `toDF` without parallelize, the test in - // "Throws error when given RDDs with different size vectors" is failed for an unknown reason. - densePoints2 = sc.parallelize(densePoints2Seq, 2).map(FeatureData).toDF() + densePoints2 = densePoints2Seq.map(FeatureData).toDF() sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() badPoints = badPointsSeq.map(FeatureData).toDF() } @@ -121,10 +119,17 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work - intercept[SparkException] { + // If the data is local Dataset, it throws AssertionError directly. + intercept[AssertionError] { model.transform(densePoints2).collect() logInfo("Did not throw error when fit, transform were called on vectors of different lengths") } + // If the data is distributed Dataset, it throws SparkException + // which is the wrapper of AssertionError. 
+ intercept[SparkException] { + model.transform(densePoints2.repartition(2)).collect() + logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + } intercept[SparkException] { vectorIndexer.fit(badPoints) logInfo("Did not throw error when fitting vectors of different lengths in same RDD.") From f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 29 Sep 2016 04:30:42 -0700 Subject: [PATCH 39/96] [SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Several performance improvements for ```ChiSqSelector```: 1. Keep ```selectedFeatures``` sorted in ascending order. ```ChiSqSelectorModel.transform``` needs ```selectedFeatures``` to be sorted in order to make predictions. We should sort it when training the model rather than when making predictions, since users usually train a model once and then use it to make predictions many times. 2. When training an ```fpr```-type ```ChiSqSelectorModel```, it's not necessary to sort the ChiSq test result by statistic. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang Closes #15277 from yanboliang/spark-17704. --- .../spark/mllib/feature/ChiSqSelector.scala | 45 ++++++++++++------- project/MimaExcludes.scala | 3 -- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 0f7c6e8bc04bb..706ce78f260a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). + * @param selectedFeatures list of indices to select (filter).
Must be ordered asc */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { + require(isSorted(selectedFeatures), "Array has to be sorted asc") + + protected def isSorted(array: Array[Int]): Boolean = { + var i = 1 + val len = array.length + while (i < len) { + if (array(i) < array(i-1)) return false + i += 1 + } + true + } + /** * Applies transformation on a vector. * @@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. * Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter + * @param filterIndices indices of features to filter, must be ordered asc */ private def compress(features: Vector, filterIndices: Array[Int]): Vector = { - val orderedIndices = filterIndices.sorted features match { case SparseVector(size, indices, values) => - val newSize = orderedIndices.length + val newSize = filterIndices.length val newValues = new ArrayBuilder.ofDouble val newIndices = new ArrayBuilder.ofInt var i = 0 var j = 0 var indicesIdx = 0 var filterIndicesIdx = 0 - while (i < indices.length && j < orderedIndices.length) { + while (i < indices.length && j < filterIndices.length) { indicesIdx = indices(i) - filterIndicesIdx = orderedIndices(j) + filterIndicesIdx = filterIndices(j) if (indicesIdx == filterIndicesIdx) { newIndices += j newValues += values(i) @@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( Vectors.sparse(newSize, newIndices.result(), newValues.result()) case DenseVector(values) => val values = features.toArray - Vectors.dense(orderedIndices.map(i => values(i))) + Vectors.dense(filterIndices.map(i => values(i))) case other => throw new UnsupportedOperationException( s"Only sparse and dense vectors are supported but got ${other.getClass}.") @@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends 
Serializable { @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data) - .zipWithIndex.sortBy { case (res, _) => -res.statistic } val features = selectorType match { - case ChiSqSelector.KBest => chiSqTestResult - .take(numTopFeatures) - case ChiSqSelector.Percentile => chiSqTestResult - .take((chiSqTestResult.length * percentile).toInt) - case ChiSqSelector.FPR => chiSqTestResult - .filter{ case (res, _) => res.pValue < alpha } + case ChiSqSelector.KBest => + chiSqTestResult.zipWithIndex + .sortBy { case (res, _) => -res.statistic } + .take(numTopFeatures) + case ChiSqSelector.Percentile => + chiSqTestResult.zipWithIndex + .sortBy { case (res, _) => -res.statistic } + .take((chiSqTestResult.length * percentile).toInt) + case ChiSqSelector.FPR => + chiSqTestResult.zipWithIndex + .filter{ case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices } + val indices = features.map { case (_, indices) => indices }.sorted new ChiSqSelectorModel(indices) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8024fbd21bbfc..4db3edb733a56 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -817,9 +817,6 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") - ) ++ Seq( - // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") From b35b0dbbfa3dc1bdf5e2fa1e9677d06635142b22 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 29 Sep 2016 08:24:34 -0400 Subject: [PATCH 40/96] [SPARK-17614][SQL] sparkSession.read() .jdbc(***) use the sql syntax "where 1=0" that Cassandra does not support ## What changes were proposed in this pull request? Use dialect's table-exists query rather than hard-coded WHERE 1=0 query ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15196 from srowen/SPARK-17614. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 6 ++---- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index a7da29f9252b3..f10615ebe4bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -58,11 +58,11 @@ object JDBCRDD extends Logging { val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() try { - val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") + val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { val rs = statement.executeQuery() try { - return JdbcUtils.getSchema(rs, dialect) + 
JdbcUtils.getSchema(rs, dialect) } finally { rs.close() } @@ -72,8 +72,6 @@ object JDBCRDD extends Logging { } finally { conn.close() } - - throw new RuntimeException("This line is unreachable.") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 3a6d5b7f1ced6..8dd4b8f662713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Connection -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.types._ /** @@ -99,6 +99,19 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * The SQL query that should be used to discover the schema of a table. It only needs to + * ensure that the result set has the same schema as the table, such as by calling + * "SELECT * ...". Dialects can override this method to return a query that works best in a + * particular database. + * @param table The name of the table. + * @return The SQL query to use for discovering the schema. + */ + @Since("2.1.0") + def getSchemaQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. From b2e9731ca494c0c60d571499f68bb8306a3c9fe5 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 29 Sep 2016 08:26:03 -0400 Subject: [PATCH 41/96] [MINOR][DOCS] Fix th doc. of spark-streaming with kinesis ## What changes were proposed in this pull request? This pr is just to fix the document of `spark-kinesis-integration`. 
Since `SPARK-17418` stopped all Kinesis artifacts (including the Kinesis example code) from being published, `bin/run-example streaming.KinesisWordCountASL` and `bin/run-example streaming.JavaKinesisWordCountASL` do not work. Instead, the documented commands now fetch the Kinesis jar via `--packages`. Author: Takeshi YAMAMURO Closes #15260 from maropu/DocFixKinesis. --- docs/streaming-kinesis-integration.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 96198ddf537b6..6be0b548bc62b 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -166,10 +166,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download Spark source and follow the [instructions](building-spark.html) to build Spark with profile *-Pkinesis-asl*. - - mvn -Pkinesis-asl -DskipTests clean package - +- Download a Spark binary from the [download site](http://spark.apache.org/downloads.html). - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. @@ -180,12 +177,12 @@ To run the example,
- bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
- bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
From 958200497affb40f05e321c2b0e252d365ae02f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Hiram=20Soltren?= Date: Thu, 29 Sep 2016 10:18:56 -0700 Subject: [PATCH 42/96] [DOCS] Reorganize explanation of Accumulators and Broadcast Variables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The discussion of the interaction of Accumulators and Broadcast Variables should logically follow the discussion on Checkpointing. As currently written, this section discusses Checkpointing before it is formally introduced. To remedy this: - Rename this section to "Accumulators, Broadcast Variables, and Checkpoints", and - Move this section after "Checkpointing". ## How was this patch tested? Testing: ran $ SKIP_API=1 jekyll build , and verified changes in a Web browser pointed at docs/_site/index.html. Author: José Hiram Soltren Closes #15281 from jsoltren/doc-changes. --- docs/streaming-programming-guide.md | 328 ++++++++++++++-------------- 1 file changed, 164 insertions(+), 164 deletions(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 43f1cf3e31871..0b0315b366501 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1368,170 +1368,6 @@ Note that the connections in the pool should be lazily created on demand and tim *** -## Accumulators and Broadcast Variables - -[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. 
If you enable checkpointing and use [Accumulators](programming-guide.html#accumulators) or [Broadcast variables](programming-guide.html#broadcast-variables) as well, you'll have to create lazily instantiated singleton instances for [Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example. - -
-
-{% highlight scala %} - -object WordBlacklist { - - @volatile private var instance: Broadcast[Seq[String]] = null - - def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { - if (instance == null) { - synchronized { - if (instance == null) { - val wordBlacklist = Seq("a", "b", "c") - instance = sc.broadcast(wordBlacklist) - } - } - } - instance - } -} - -object DroppedWordsCounter { - - @volatile private var instance: LongAccumulator = null - - def getInstance(sc: SparkContext): LongAccumulator = { - if (instance == null) { - synchronized { - if (instance == null) { - instance = sc.longAccumulator("WordsInBlacklistCounter") - } - } - } - instance - } -} - -wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => - // Get or register the blacklist Broadcast - val blacklist = WordBlacklist.getInstance(rdd.sparkContext) - // Get or register the droppedWordsCounter Accumulator - val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) - // Use blacklist to drop words and use droppedWordsCounter to count them - val counts = rdd.filter { case (word, count) => - if (blacklist.value.contains(word)) { - droppedWordsCounter.add(count) - false - } else { - true - } - }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts -}) - -{% endhighlight %} - -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). -
-
-{% highlight java %} - -class JavaWordBlacklist { - - private static volatile Broadcast> instance = null; - - public static Broadcast> getInstance(JavaSparkContext jsc) { - if (instance == null) { - synchronized (JavaWordBlacklist.class) { - if (instance == null) { - List wordBlacklist = Arrays.asList("a", "b", "c"); - instance = jsc.broadcast(wordBlacklist); - } - } - } - return instance; - } -} - -class JavaDroppedWordsCounter { - - private static volatile LongAccumulator instance = null; - - public static LongAccumulator getInstance(JavaSparkContext jsc) { - if (instance == null) { - synchronized (JavaDroppedWordsCounter.class) { - if (instance == null) { - instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); - } - } - } - return instance; - } -} - -wordCounts.foreachRDD(new Function2, Time, Void>() { - @Override - public Void call(JavaPairRDD rdd, Time time) throws IOException { - // Get or register the blacklist Broadcast - final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); - // Get or register the droppedWordsCounter Accumulator - final LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); - // Use blacklist to drop words and use droppedWordsCounter to count them - String counts = rdd.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 wordCount) throws Exception { - if (blacklist.value().contains(wordCount._1())) { - droppedWordsCounter.add(wordCount._2()); - return false; - } else { - return true; - } - } - }).collect().toString(); - String output = "Counts at time " + time + " " + counts; - } -} - -{% endhighlight %} - -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). -
-
-{% highlight python %} -def getWordBlacklist(sparkContext): - if ("wordBlacklist" not in globals()): - globals()["wordBlacklist"] = sparkContext.broadcast(["a", "b", "c"]) - return globals()["wordBlacklist"] - -def getDroppedWordsCounter(sparkContext): - if ("droppedWordsCounter" not in globals()): - globals()["droppedWordsCounter"] = sparkContext.accumulator(0) - return globals()["droppedWordsCounter"] - -def echo(time, rdd): - # Get or register the blacklist Broadcast - blacklist = getWordBlacklist(rdd.context) - # Get or register the droppedWordsCounter Accumulator - droppedWordsCounter = getDroppedWordsCounter(rdd.context) - - # Use blacklist to drop words and use droppedWordsCounter to count them - def filterFunc(wordCount): - if wordCount[0] in blacklist.value: - droppedWordsCounter.add(wordCount[1]) - False - else: - True - - counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) - -wordCounts.foreachRDD(echo) - -{% endhighlight %} - -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py). - -
-
- -*** - ## DataFrame and SQL Operations You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. @@ -1877,6 +1713,170 @@ batch interval that is at least 10 seconds. It can be set by using *** +## Accumulators, Broadcast Variables, and Checkpoints + +[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. If you enable checkpointing and use [Accumulators](programming-guide.html#accumulators) or [Broadcast variables](programming-guide.html#broadcast-variables) as well, you'll have to create lazily instantiated singleton instances for [Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example. + +
+
+{% highlight scala %} + +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +object DroppedWordsCounter { + + @volatile private var instance: LongAccumulator = null + + def getInstance(sc: SparkContext): LongAccumulator = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.longAccumulator("WordsInBlacklistCounter") + } + } + } + instance + } +} + +wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter.add(count) + false + } else { + true + } + }.collect().mkString("[", ", ", "]") + val output = "Counts at time " + time + " " + counts +}) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +
+
+{% highlight java %} + +class JavaWordBlacklist { + + private static volatile Broadcast> instance = null; + + public static Broadcast> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +class JavaDroppedWordsCounter { + + private static volatile LongAccumulator instance = null; + + public static LongAccumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +wordCounts.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaPairRDD rdd, Time time) throws IOException { + // Get or register the blacklist Broadcast + final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 wordCount) throws Exception { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + } +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). +
+
+{% highlight python %} +def getWordBlacklist(sparkContext): + if ("wordBlacklist" not in globals()): + globals()["wordBlacklist"] = sparkContext.broadcast(["a", "b", "c"]) + return globals()["wordBlacklist"] + +def getDroppedWordsCounter(sparkContext): + if ("droppedWordsCounter" not in globals()): + globals()["droppedWordsCounter"] = sparkContext.accumulator(0) + return globals()["droppedWordsCounter"] + +def echo(time, rdd): + # Get or register the blacklist Broadcast + blacklist = getWordBlacklist(rdd.context) + # Get or register the droppedWordsCounter Accumulator + droppedWordsCounter = getDroppedWordsCounter(rdd.context) + + # Use blacklist to drop words and use droppedWordsCounter to count them + def filterFunc(wordCount): + if wordCount[0] in blacklist.value: + droppedWordsCounter.add(wordCount[1]) + False + else: + True + + counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) + +wordCounts.foreachRDD(echo) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py). + +
+
+ +*** + ## Deploying Applications This section discusses the steps to deploy a Spark Streaming application. From 7f779e7439127efa0e3611f7745e1c8423845198 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 29 Sep 2016 15:36:40 -0400 Subject: [PATCH 43/96] [SPARK-17648][CORE] TaskScheduler really needs offers to be an IndexedSeq ## What changes were proposed in this pull request? The Seq[WorkerOffer] is accessed by index, so it really should be an IndexedSeq, otherwise an O(n) operation becomes O(n^2). In practice this hasn't been an issue because, where these offers are generated, the call to `.toSeq` just happens to create an IndexedSeq anyway. I got bitten by this in performance tests I was doing, and it's better for the types to be more precise so that, e.g., a change in Scala doesn't destroy performance. ## How was this patch tested? Unit tests via jenkins. Author: Imran Rashid Closes #15221 from squito/SPARK-17648. --- .../spark/scheduler/TaskSchedulerImpl.scala | 4 +-- .../CoarseGrainedSchedulerBackend.scala | 4 +-- .../local/LocalSchedulerBackend.scala | 2 +- .../scheduler/SchedulerIntegrationSuite.scala | 7 ++-- .../scheduler/TaskSchedulerImplSuite.scala | 32 +++++++++---------- .../MesosFineGrainedSchedulerBackend.scala | 2 +- ...esosFineGrainedSchedulerBackendSuite.scala | 2 +- 7 files changed, 26 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 52a7186cbf45c..0ad4730fe20a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -252,7 +252,7 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { var launchedTask = false
for (i <- 0 until shuffledOffers.size) { val execId = shuffledOffers(i).executorId @@ -286,7 +286,7 @@ private[spark] class TaskSchedulerImpl( * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so * that tasks are balanced across the cluster. */ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { + def resourceOffers(offers: IndexedSeq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { // Mark each slave as alive and remember its hostname // Also track if new executor is added var newExecAvail = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index edc3c199376ef..2d0986316601f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -216,7 +216,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq + }.toIndexedSeq launchTasks(scheduler.resourceOffers(workOffers)) } @@ -233,7 +233,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Filter out executors under killing if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) - val workOffers = Seq( + val workOffers = IndexedSeq( new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) launchTasks(scheduler.resourceOffers(workOffers)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala 
index e386052814039..7a73e8ed8a38f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,7 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 14f52a6be9d1f..5cd548bbc72d9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -366,13 +366,13 @@ private[spark] abstract class MockBackend( */ def executorIdToExecutor: Map[String, ExecutorTaskStatus] - private def generateOffers(): Seq[WorkerOffer] = { + private def generateOffers(): IndexedSeq[WorkerOffer] = { executorIdToExecutor.values.filter { exec => exec.freeCores > 0 }.map { exec => WorkerOffer(executorId = exec.executorId, host = exec.host, cores = exec.freeCores) - }.toSeq + }.toIndexedSeq } /** @@ -381,8 +381,7 @@ private[spark] abstract class MockBackend( * scheduling. 
*/ override def reviveOffers(): Unit = { - val offers: Seq[WorkerOffer] = generateOffers() - val newTaskDescriptions = taskScheduler.resourceOffers(offers).flatten + val newTaskDescriptions = taskScheduler.resourceOffers(generateOffers()).flatten // get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual // tests from introducing a race if they need it val newTasks = taskScheduler.synchronized { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 100b15740ca92..61787b54f824f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -87,7 +87,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B test("Scheduler does not always schedule tasks on the same workers") { val taskScheduler = setupScheduler() val numFreeCores = 1 - val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores), new WorkerOffer("executor1", "host1", numFreeCores)) // Repeatedly try to schedule a 1-task job, and make sure that it doesn't always // get scheduled on the same executor. While there is a chance this test will fail @@ -112,7 +112,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskCpus = 2 val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) // Give zero core offers. 
Should not generate any tasks - val zeroCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", 0), + val zeroCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 0), new WorkerOffer("executor1", "host1", 0)) val taskSet = FakeTask.createTaskSet(1) taskScheduler.submitTasks(taskSet) @@ -121,7 +121,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // No tasks should run as we only have 1 core free. val numFreeCores = 1 - val singleCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), + val singleCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) taskDescriptions = taskScheduler.resourceOffers(singleCoreWorkerOffers).flatten @@ -129,7 +129,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // Now change the offers to have 2 cores in one executor and verify if it // is chosen. 
- val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), + val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten @@ -144,7 +144,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val numFreeCores = 1 val taskSet = new TaskSet( Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) - val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), + val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten @@ -184,7 +184,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskScheduler = setupScheduler() val numFreeCores = 1 - val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores)) val attempt1 = FakeTask.createTaskSet(10) // submit attempt 1, offer some resources, some tasks get scheduled @@ -216,7 +216,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskScheduler = setupScheduler() val numFreeCores = 10 - val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores)) val attempt1 = FakeTask.createTaskSet(10) // submit attempt 1, offer some resources, some tasks get scheduled @@ -254,8 +254,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B test("tasks are not re-scheduled while executor loss reason is pending") { val taskScheduler = 
setupScheduler() - val e0Offers = Seq(new WorkerOffer("executor0", "host0", 1)) - val e1Offers = Seq(new WorkerOffer("executor1", "host0", 1)) + val e0Offers = IndexedSeq(new WorkerOffer("executor0", "host0", 1)) + val e1Offers = IndexedSeq(new WorkerOffer("executor1", "host0", 1)) val attempt1 = FakeTask.createTaskSet(1) // submit attempt 1, offer resources, task gets scheduled @@ -296,7 +296,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.submitTasks(taskSet) val tsm = taskScheduler.taskSetManagerForAttempt(taskSet.stageId, taskSet.stageAttemptId).get - val firstTaskAttempts = taskScheduler.resourceOffers(Seq( + val firstTaskAttempts = taskScheduler.resourceOffers(IndexedSeq( new WorkerOffer("executor0", "host0", 1), new WorkerOffer("executor1", "host1", 1) )).flatten @@ -313,7 +313,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // on that executor, and make sure that the other task (not the failed one) is assigned there taskScheduler.executorLost("executor1", SlaveLost("oops")) val nextTaskAttempts = - taskScheduler.resourceOffers(Seq(new WorkerOffer("executor0", "host0", 1))).flatten + taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))).flatten // Note: Its OK if some future change makes this already realize the taskset has become // unschedulable at this point (though in the current implementation, we're sure it will not) assert(nextTaskAttempts.size === 1) @@ -323,7 +323,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // now we should definitely realize that our task set is unschedulable, because the only // task left can't be scheduled on any executors due to the blacklist - taskScheduler.resourceOffers(Seq(new WorkerOffer("executor0", "host0", 1))) + taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))) sc.listenerBus.waitUntilEmpty(100000) assert(tsm.isZombie) 
assert(failedTaskSet) @@ -348,7 +348,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.submitTasks(taskSet) val tsm = taskScheduler.taskSetManagerForAttempt(taskSet.stageId, taskSet.stageAttemptId).get - val offers = Seq( + val offers = IndexedSeq( // each offer has more than enough free cores for the entire task set, so when combined // with the locality preferences, we schedule all tasks on one executor new WorkerOffer("executor0", "host0", 4), @@ -380,7 +380,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2"))}: _* )) - val taskDescs = taskScheduler.resourceOffers(Seq( + val taskDescs = taskScheduler.resourceOffers(IndexedSeq( new WorkerOffer("executor0", "host0", 1), new WorkerOffer("executor1", "host1", 1) )).flatten @@ -396,7 +396,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // when executor2 is added, we should realize that we can run process-local tasks. // And we should know its alive on the host. 
val secondTaskDescs = taskScheduler.resourceOffers( - Seq(new WorkerOffer("executor2", "host0", 1))).flatten + IndexedSeq(new WorkerOffer("executor2", "host0", 1))).flatten assert(secondTaskDescs.size === 1) assert(mgr.myLocalityLevels.toSet === Set(TaskLocality.PROCESS_LOCAL, TaskLocality.NODE_LOCAL, TaskLocality.ANY)) @@ -406,7 +406,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // And even if we don't have anything left to schedule, another resource offer on yet another // executor should also update the set of live executors val thirdTaskDescs = taskScheduler.resourceOffers( - Seq(new WorkerOffer("executor3", "host1", 1))).flatten + IndexedSeq(new WorkerOffer("executor3", "host1", 1))).flatten assert(thirdTaskDescs.size === 0) assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index eb3b235949501..09a252f3c74ac 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -286,7 +286,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( o.getSlaveId.getValue, o.getHostname, cpus) - } + }.toIndexedSeq val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 7a706ab256f82..1d7a86f4b0904 100644 --- 
a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -283,7 +283,7 @@ class MesosFineGrainedSchedulerBackendSuite mesosOffers2.add(createOffer(1, minMem, minCpu)) reset(taskScheduler) reset(driver) - when(taskScheduler.resourceOffers(any(classOf[Seq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.resourceOffers(any(classOf[IndexedSeq[WorkerOffer]]))).thenReturn(Seq(Seq())) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) From cb87b3ced9453b5717fa8e8637b97a2f3f25fdd7 Mon Sep 17 00:00:00 2001 From: Gang Wu Date: Thu, 29 Sep 2016 15:51:05 -0400 Subject: [PATCH 44/96] [SPARK-17672] Spark 2.0 history server web Ui takes too long for a single application Added a new API getApplicationInfo(appId: String) in class ApplicationHistoryProvider and class SparkUI to get app info. In this change, FsHistoryProvider can directly fetch one app info in O(1) time complexity compared to O(n) before the change which used an Iterator.find() interface. Both ApplicationCache and OneApplicationResource classes adopt this new api. manual tests Author: Gang Wu Closes #15247 from wgtmac/SPARK-17671. 
--- .../spark/deploy/history/ApplicationHistoryProvider.scala | 5 +++++ .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 4 ++++ .../org/apache/spark/deploy/history/HistoryServer.scala | 4 ++++ .../org/apache/spark/status/api/v1/ApiRootResource.scala | 1 + .../apache/spark/status/api/v1/OneApplicationResource.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 4 ++++ 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 44661edfff90b..ba42b4862aa90 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -109,4 +109,9 @@ private[history] abstract class ApplicationHistoryProvider { @throws(classOf[SparkException]) def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + /** + * @return the [[ApplicationHistoryInfo]] for the appId if it exists. 
+ */ + def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6874aa5f938ac..d494ff0659bd2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -224,6 +224,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values + override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { + applications.get(appId) + } + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index c178917d8da3b..735aa43cfc994 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -182,6 +182,10 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + def getApplicationInfo(appId: String): Option[ApplicationInfo] = { + provider.getApplicationInfo(appId).map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + } + override def writeEventLogs( appId: String, attemptId: Option[String], diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index de927117e1f63..17bc04303fa8b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -222,6 
+222,7 @@ private[spark] object ApiRootResource { private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + def getApplicationInfo(appId: String): Option[ApplicationInfo] /** * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index d7e6a8b589953..18c3e2f407360 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -24,7 +24,7 @@ private[v1] class OneApplicationResource(uiRoot: UIRoot) { @GET def getApp(@PathParam("appId") appId: String): ApplicationInfo = { - val apps = uiRoot.getApplicationInfoList.find { _.id == appId } + val apps = uiRoot.getApplicationInfo(appId) apps.getOrElse(throw new NotFoundException("unknown app: " + appId)) } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 39155ff2649ec..ef71db89798f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -126,6 +126,10 @@ private[spark] class SparkUI private ( )) )) } + + def getApplicationInfo(appId: String): Option[ApplicationInfo] = { + getApplicationInfoList.find(_.id == appId) + } } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) From 027dea8f294504bc5cd8bfedde546d171cb78657 Mon Sep 17 00:00:00 2001 From: Brian Cho Date: Thu, 29 Sep 2016 15:59:17 -0400 Subject: [PATCH 45/96] [SPARK-17715][SCHEDULER] Make task launch logs DEBUG ## What changes were proposed in this pull request? Ramp down the task launch logs from INFO to DEBUG. 
Task launches can happen orders of magnitude more than executor registration so it makes the logs easier to handle if they are different log levels. For larger jobs, there can be 100,000s of task launches which makes the driver log huge. ## How was this patch tested? No tests, as this is a trivial change. Author: Brian Cho Closes #15290 from dafrista/ramp-down-task-logging. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2d0986316601f..0dae0e614e17d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -265,7 +265,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + s"${executorData.executorHost}.") executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) From fe33121a53384811a8e094ab6c05dc85b7c7ca87 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 29 Sep 2016 13:01:10 -0700 Subject: [PATCH 46/96] [SPARK-17699] Support for parsing JSON string columns Spark SQL has great support for reading text files that contain JSON data. However, in many cases the JSON data is just one column amongst others. This is particularly true when reading from sources such as Kafka. This PR adds a new function `from_json` that converts a string column into a nested `StructType` with a user-specified schema.
Example usage: ```scala val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("a", IntegerType) df.select(from_json($"value", schema) as 'json) // => [json: <a: int>] ``` This PR adds support for Java, Scala and Python. I leveraged our existing JSON parsing support by moving it into catalyst (so that we could define expressions using it). I left SQL out for now, because I'm not sure how users would specify a schema. Author: Michael Armbrust Closes #15274 from marmbrus/jsonParser. --- python/pyspark/sql/functions.py | 23 ++++++++ .../expressions/jsonExpressions.scala | 31 +++++++++- .../sql/catalyst}/json/JSONOptions.scala | 6 +- .../sql/catalyst}/json/JacksonParser.scala | 13 +++-- .../sql/catalyst}/json/JacksonUtils.scala | 4 +- .../catalyst/util}/CompressionCodecs.scala | 6 +- .../spark/sql/catalyst/util}/ParseModes.scala | 4 +- .../expressions/JsonExpressionsSuite.scala | 26 +++++++++ .../apache/spark/sql/DataFrameReader.scala | 5 +- .../datasources/csv/CSVFileFormat.scala | 1 + .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/json/InferSchema.scala | 3 +- .../datasources/json/JacksonGenerator.scala | 3 +- .../datasources/json/JsonFileFormat.scala | 2 + .../datasources/text/TextFileFormat.scala | 1 + .../org/apache/spark/sql/functions.scala | 58 +++++++++++++++++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 29 ++++++++++ .../json/JsonParsingOptionsSuite.scala | 1 + .../datasources/json/JsonSuite.scala | 3 +- 19 files changed, 198 insertions(+), 23 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JSONOptions.scala (95%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonParser.scala (97%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonUtils.scala
(92%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst/util}/CompressionCodecs.scala (93%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst/util}/ParseModes.scala (94%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 89b3c07c0740f..45d6bf944b702 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1706,6 +1706,29 @@ def json_tuple(col, *fields): return Column(jc) +@since(2.1) +def from_json(col, schema, options={}): + """ + Parses a column containing a JSON string into a [[StructType]] with the + specified schema. Returns `null`, in the case of an unparseable string. + + :param col: string column in json format + :param schema: a StructType to use when parsing the json column + :param options: options to control parsing. accepts the same options as the json datasource + + >>> from pyspark.sql.types import * + >>> data = [(1, '''{"a": 1}''')] + >>> schema = StructType([StructField("a", IntegerType())]) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.from_json(_to_java_column(col), schema.json(), options) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index c14a2fb122618..65dbd6a4e3f1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -23,10 +23,12 @@ import scala.util.parsing.combinator.RegexParsers 
import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions, SparkSQLJsonProcessingException} +import org.apache.spark.sql.catalyst.util.ParseModes +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -467,3 +469,28 @@ case class JsonTuple(children: Seq[Expression]) } } +/** + * Converts an json input string to a [[StructType]] with the specified schema. + */ +case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression) + extends Expression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + lazy val parser = + new JacksonParser( + schema, + "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`. 
+ new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE))) + + override def dataType: DataType = schema + override def children: Seq[Expression] = child :: Nil + + override def eval(input: InternalRow): Any = { + try parser.parse(child.eval(input).toString).head catch { + case _: SparkSQLJsonProcessingException => null + } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 02d211d04265e..aec18922ea6c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes} /** - * Options for the JSON data source. + * Options for parsing JSON data into Spark SQL rows. * * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 5ce1bf7432159..f80e6373d2f89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream @@ -28,19 +28,22 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.ParseModes.{DROP_MALFORMED_MODE, PERMISSIVE_MODE} -import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) +private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) +/** + * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. 
+ */ class JacksonParser( schema: StructType, columnNameOfCorruptRecord: String, options: JSONOptions) extends Logging { + import JacksonUtils._ + import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -65,7 +68,7 @@ class JacksonParser( private def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present if (options.failFast) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: $record") + throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: $record") } if (options.dropMalformed) { if (!isWarningPrintedForMalformedRecord) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index 005546f37dda0..c4d9abb2c07e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -15,11 +15,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} -private object JacksonUtils { +object JacksonUtils { /** * Advance the parser until a null or a specific token is found */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 41cff07472d1e..435fba9d8851c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources +package org.apache.spark.sql.catalyst.util import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.io.compress.{BZip2Codec, DeflateCodec, GzipCodec, Lz4Codec, SnappyCodec} +import org.apache.hadoop.io.compress._ import org.apache.spark.util.Utils -private[datasources] object CompressionCodecs { +object CompressionCodecs { private val shortCompressionCodecNames = Map( "none" -> null, "uncompressed" -> null, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala index 468228053c964..0e466962b4678 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources +package org.apache.spark.sql.catalyst.util -private[datasources] object ParseModes { +object ParseModes { val PERMISSIVE_MODE = "PERMISSIVE" val DROP_MALFORMED_MODE = "DROPMALFORMED" val FAIL_FAST_MODE = "FAILFAST" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7b754091f4714..84623934d95d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ParseModes +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -317,4 +319,28 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) } + + test("from_json") { + val jsonData = """{"a": 1}""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStruct(schema, Map.empty, Literal(jsonData)), + InternalRow.fromSeq(1 :: Nil) + ) + } + + test("from_json - invalid data") { + val jsonData = """{"a" 1}""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStruct(schema, Map.empty, 
Literal(jsonData)), + null + ) + + // Other modes should still return `null`. + checkEvaluation( + JsonToStruct(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData)), + null + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b10d2c86ac5ef..b84fb2fb95914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,14 +21,15 @@ import java.util.Properties import scala.collection.JavaConverters._ -import org.apache.spark.Partition import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.Partition import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.datasources.json.InferSchema import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9610746a81ef7..4e662a52a7bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -29,6 +29,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CompressionCodecs import 
org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index e7dcc22272192..014614eb997a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes} private[csv] class CSVOptions(@transient private val parameters: Map[String, String]) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 91c58d059d287..dc8bd817f2906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -23,7 +23,8 @@ import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil +import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 270e7fbd3c137..5b55b701862b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -21,8 +21,9 @@ import java.io.Writer import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 6882a6cdcac26..9fe38ccc9fdc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -32,6 +32,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index a875b01ec2d7a..9f96667311015 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47bf41a2da813..3bc1c5b90031d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try @@ -2818,6 +2819,63 @@ object functions { JsonTuple(json.expr +: fields.map(Literal.apply)) } + /** + * (Scala-specific) Parses a column containing a JSON string into a [[StructType]] with the + * specified schema. Returns `null`, in the case of an unparseable string. + * + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options as the + * json data source. + * @param e a string column containing JSON data. 
+ * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + JsonToStruct(schema, options, e.expr) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a [[StructType]] with the + * specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options as the + * json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = + from_json(e, schema, options.asScala.toMap) + + /** + * Parses a column containing a JSON string into a [[StructType]] with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: StructType): Column = + from_json(e, schema, Map.empty[String, String]) + + /** + * Parses a column containing a JSON string into a [[StructType]] with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string, given as a json string + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = + from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) + /** * Returns length of array or map. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 1391c9d57ff7c..518d6e92b2ff7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.functions.from_json import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -94,4 +96,31 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(expr, expected) } + + test("json_parser") { + val df = Seq("""{"a": 1}""").toDS() + val schema = new StructType().add("a", IntegerType) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(Row(1)) :: Nil) + } + + test("json_parser missing columns") { + val df = Seq("""{"a": 1}""").toDS() + val schema = new StructType().add("b", IntegerType) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(Row(null)) :: Nil) + } + + test("json_parser invalid json") { + val df = Seq("""{"a" 1}""").toDS() + val schema = new StructType().add("a", IntegerType) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(null) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index c31dffedbdf67..0b72da5f3759c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.QueryTest 
+import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.test.SharedSQLContext /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 3d533c14e18e7..456052f79afcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -26,9 +26,10 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException import org.apache.spark.rdd.RDD +import org.apache.spark.SparkException import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType From 566d7f28275f90f7b9bed6a75e90989ad0c59931 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Sep 2016 14:30:23 -0700 Subject: [PATCH 47/96] [SPARK-17653][SQL] Remove unnecessary distincts in multiple unions ## What changes were proposed in this pull request? Currently for `Union [Distinct]`, a `Distinct` operator is necessary to be on the top of `Union`. Once there are adjacent `Union [Distinct]`, there will be multiple `Distinct` in the query plan. 
E.g., For a query like: select 1 a union select 2 b union select 3 c Before this patch, its physical plan looks like: *HashAggregate(keys=[a#13], functions=[]) +- Exchange hashpartitioning(a#13, 200) +- *HashAggregate(keys=[a#13], functions=[]) +- Union :- *HashAggregate(keys=[a#13], functions=[]) : +- Exchange hashpartitioning(a#13, 200) : +- *HashAggregate(keys=[a#13], functions=[]) : +- Union : :- *Project [1 AS a#13] : : +- Scan OneRowRelation[] : +- *Project [2 AS b#14] : +- Scan OneRowRelation[] +- *Project [3 AS c#15] +- Scan OneRowRelation[] Only the top distinct should be necessary. After this patch, the physical plan looks like: *HashAggregate(keys=[a#221], functions=[], output=[a#221]) +- Exchange hashpartitioning(a#221, 5) +- *HashAggregate(keys=[a#221], functions=[], output=[a#221]) +- Union :- *Project [1 AS a#221] : +- Scan OneRowRelation[] :- *Project [2 AS b#222] : +- Scan OneRowRelation[] +- *Project [3 AS c#223] +- Scan OneRowRelation[] ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #15238 from viirya/remove-extra-distinct-union. 
--- .../sql/catalyst/optimizer/Optimizer.scala | 24 ++++++- .../sql/catalyst/planning/patterns.scala | 27 -------- .../optimizer/SetOperationSuite.scala | 68 +++++++++++++++++++ 3 files changed, 89 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4952ba3b2b99d..9df8ce1fa3b28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.java.function.FilterFunction @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe * Combines all adjacent [[Union]] operators into a single [[Union]]. 
*/ object CombineUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Unions(children) => Union(children) + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case u: Union => flattenUnion(u, false) + case Distinct(u: Union) => Distinct(flattenUnion(u, true)) + } + + private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = { + val stack = mutable.Stack[LogicalPlan](union) + val flattened = mutable.ArrayBuffer.empty[LogicalPlan] + while (stack.nonEmpty) { + stack.pop() match { + case Distinct(Union(children)) if flattenDistinct => + stack.pushAll(children.reverse) + case Union(children) => + stack.pushAll(children.reverse) + case child => + flattened += child + } + } + Union(flattened) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 41cabb8cb3390..bdae56881bf46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { } } - -/** - * A pattern that collects all adjacent unions and returns their children as a Seq. - */ -object Unions { - def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan])) - case _ => None - } - - // Doing a depth-first tree traversal to combine all the union children. 
- @tailrec - private def collectUnionChildren( - plans: mutable.Stack[LogicalPlan], - children: Seq[LogicalPlan]): Seq[LogicalPlan] = { - if (plans.isEmpty) children - else { - plans.pop match { - case Union(grandchildren) => - grandchildren.reverseMap(plans.push(_)) - collectUnionChildren(plans, children) - case other => collectUnionChildren(plans, children :+ other) - } - } - } -} - /** * An extractor used when planning the physical execution of an aggregation. Compared with a logical * aggregation, the following transformations are performed: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 7227706ab2b36..21b7f49e14bd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest { testRelation3.select('g) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } + + test("Remove unnecessary distincts in multiple unions") { + val query1 = OneRowRelation + .select(Literal(1).as('a)) + val query2 = OneRowRelation + .select(Literal(2).as('b)) + val query3 = OneRowRelation + .select(Literal(3).as('c)) + + // D - U - D - U - query1 + // | | + // query3 query2 + val unionQuery1 = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze + val optimized1 = 
Optimize.execute(unionQuery1) + val distinctUnionCorrectAnswer1 = + Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze + comparePlans(distinctUnionCorrectAnswer1, optimized1) + + // query1 + // | + // D - U - U - query2 + // | + // D - U - query2 + // | + // query3 + val unionQuery2 = Distinct(Union(Union(query1, query2), + Distinct(Union(query2, query3)))).analyze + val optimized2 = Optimize.execute(unionQuery2) + val distinctUnionCorrectAnswer2 = + Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze + comparePlans(distinctUnionCorrectAnswer2, optimized2) + } + + test("Keep necessary distincts in multiple unions") { + val query1 = OneRowRelation + .select(Literal(1).as('a)) + val query2 = OneRowRelation + .select(Literal(2).as('b)) + val query3 = OneRowRelation + .select(Literal(3).as('c)) + val query4 = OneRowRelation + .select(Literal(4).as('d)) + + // U - D - U - query1 + // | | + // query3 query2 + val unionQuery1 = Union(Distinct(Union(query1, query2)), query3).analyze + val optimized1 = Optimize.execute(unionQuery1) + val distinctUnionCorrectAnswer1 = + Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze + comparePlans(distinctUnionCorrectAnswer1, optimized1) + + // query1 + // | + // U - D - U - query2 + // | + // D - U - query3 + // | + // query4 + val unionQuery2 = + Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze + val optimized2 = Optimize.execute(unionQuery2) + val distinctUnionCorrectAnswer2 = + Union(Distinct(Union(query1 :: query2 :: Nil)), + Distinct(Union(query3 :: query4 :: Nil))).analyze + comparePlans(distinctUnionCorrectAnswer2, optimized2) + } } From 4ecc648ad713f9d618adf0406b5d39981779059d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 29 Sep 2016 15:30:18 -0700 Subject: [PATCH 48/96] [SPARK-17612][SQL] Support `DESCRIBE table PARTITION` SQL syntax ## What changes were proposed in this pull request? 
This PR implements `DESCRIBE table PARTITION` SQL Syntax again. It was supported until Spark 1.6.2, but was dropped since 2.0.0. **Spark 1.6.2** ```scala scala> sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") res1: org.apache.spark.sql.DataFrame = [result: string] scala> sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") res2: org.apache.spark.sql.DataFrame = [result: string] scala> sql("DESC partitioned_table PARTITION (c='Us', d=1)").show(false) +----------------------------------------------------------------+ |result | +----------------------------------------------------------------+ |a string | |b int | |c string | |d string | | | |# Partition Information | |# col_name data_type comment | | | |c string | |d string | +----------------------------------------------------------------+ ``` **Spark 2.0** - **Before** ```scala scala> sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") res0: org.apache.spark.sql.DataFrame = [] scala> sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") res1: org.apache.spark.sql.DataFrame = [] scala> sql("DESC partitioned_table PARTITION (c='Us', d=1)").show(false) org.apache.spark.sql.catalyst.parser.ParseException: Unsupported SQL statement ``` - **After** ```scala scala> sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") res0: org.apache.spark.sql.DataFrame = [] scala> sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") res1: org.apache.spark.sql.DataFrame = [] scala> sql("DESC partitioned_table PARTITION (c='Us', d=1)").show(false) +-----------------------+---------+-------+ |col_name |data_type|comment| +-----------------------+---------+-------+ |a |string |null | |b |int |null | |c |string |null | |d |string |null | |# Partition Information| | | |# col_name |data_type|comment| |c |string |null | |d |string |null | +-----------------------+---------+-------+ scala> 
sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)").show(100,false) +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+-------+ |col_name |data_type|comment| +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+-------+ |a |string |null | |b |int |null | |c |string |null | |d |string |null | |# Partition Information | | | |# col_name |data_type|comment| |c |string |null | |d |string |null | | | | | |Detailed Partition Information CatalogPartition( Partition Values: [Us, 1] Storage(Location: file:/Users/dhyun/SPARK-17612-DESC-PARTITION/spark-warehouse/partitioned_table/c=Us/d=1, InputFormat: org.apache.hadoop.mapred.TextInputFormat, OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, Serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Properties: [serialization.format=1]) Partition Parameters:{transient_lastDdlTime=1475001066})| | | 
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+-------+ scala> sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)").show(100,false) +--------------------------------+---------------------------------------------------------------------------------------+-------+ |col_name |data_type |comment| +--------------------------------+---------------------------------------------------------------------------------------+-------+ |a |string |null | |b |int |null | |c |string |null | |d |string |null | |# Partition Information | | | |# col_name |data_type |comment| |c |string |null | |d |string |null | | | | | |# Detailed Partition Information| | | |Partition Value: |[Us, 1] | | |Database: |default | | |Table: |partitioned_table | | |Location: |file:/Users/dhyun/SPARK-17612-DESC-PARTITION/spark-warehouse/partitioned_table/c=Us/d=1| | |Partition Parameters: | | | | transient_lastDdlTime |1475001066 | | | | | | |# Storage Information | | | |SerDe Library: |org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe | | |InputFormat: |org.apache.hadoop.mapred.TextInputFormat | | |OutputFormat: |org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat | | |Compressed: |No | | |Storage Desc Parameters: | | | | serialization.format |1 | | +--------------------------------+---------------------------------------------------------------------------------------+-------+ ``` ## How was this patch tested? Pass the Jenkins tests with a new testcase. Author: Dongjoon Hyun Closes #15168 from dongjoon-hyun/SPARK-17612. 
--- .../sql/catalyst/catalog/interface.scala | 13 ++- .../spark/sql/execution/SparkSqlParser.scala | 15 +++- .../spark/sql/execution/command/tables.scala | 83 ++++++++++++++--- .../resources/sql-tests/inputs/describe.sql | 27 ++++++ .../sql-tests/results/describe.sql.out | 90 +++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 77 +++++++++++++++- 6 files changed, 287 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e52251f960ff4..51326ca25e9cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -86,7 +86,18 @@ object CatalogStorageFormat { case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, - parameters: Map[String, String] = Map.empty) + parameters: Map[String, String] = Map.empty) { + + override def toString: String = { + val output = + Seq( + s"Partition Values: [${spec.values.mkString(", ")}]", + s"$storage", + s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + + output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") + } +} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 5359cedc80974..3f34d0f25393d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -276,13 +276,24 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a 
[[DescribeTableCommand]] logical plan. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { - // Describe partition and column are not supported yet. Return null and let the parser decide + // Describe column is not supported yet. Return null and let the parser decide // what to do with this (create an exception or pass it on to a different system). - if (ctx.describeColName != null || ctx.partitionSpec != null) { + if (ctx.describeColName != null) { null } else { + val partitionSpec = if (ctx.partitionSpec != null) { + // According to the syntax, visitPartitionSpec returns `Map[String, Option[String]]`. + visitPartitionSpec(ctx.partitionSpec).map { + case (key, Some(value)) => key -> value + case (key, _) => + throw new ParseException(s"PARTITION specification is incomplete: `$key`", ctx) + } + } else { + Map.empty[String, String] + } DescribeTableCommand( visitTableIdentifier(ctx.tableIdentifier), + partitionSpec, ctx.EXTENDED != null, ctx.FORMATTED != null) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 6a91c997bac63..08de6cd4242c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -390,10 +390,14 @@ case class 
TruncateTableCommand( /** * Command that looks like * {{{ - * DESCRIBE [EXTENDED|FORMATTED] table_name; + * DESCRIBE [EXTENDED|FORMATTED] table_name partitionSpec?; * }}} */ -case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isFormatted: Boolean) +case class DescribeTableCommand( + table: TableIdentifier, + partitionSpec: TablePartitionSpec, + isExtended: Boolean, + isFormatted: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = Seq( @@ -411,17 +415,25 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF val catalog = sparkSession.sessionState.catalog if (catalog.isTemporaryTable(table)) { + if (partitionSpec.nonEmpty) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}") + } describeSchema(catalog.lookupRelation(table).schema, result) } else { val metadata = catalog.getTableMetadata(table) describeSchema(metadata.schema, result) - if (isExtended) { - describeExtended(metadata, result) - } else if (isFormatted) { - describeFormatted(metadata, result) + describePartitionInfo(metadata, result) + + if (partitionSpec.isEmpty) { + if (isExtended) { + describeExtendedTableInfo(metadata, result) + } else if (isFormatted) { + describeFormattedTableInfo(metadata, result) + } } else { - describePartitionInfo(metadata, result) + describeDetailedPartitionInfo(catalog, metadata, result) } } @@ -436,16 +448,12 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } } - private def describeExtended(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - describePartitionInfo(table, buffer) - + private def describeExtendedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Table Information", table.toString, "") } - private def describeFormatted(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - describePartitionInfo(table, buffer) - + 
private def describeFormattedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Table Information", "", "") append(buffer, "Database:", table.database, "") @@ -499,6 +507,53 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } } + private def describeDetailedPartitionInfo( + catalog: SessionCatalog, + metadata: CatalogTable, + result: ArrayBuffer[Row]): Unit = { + if (metadata.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a view: ${table.identifier}") + } + if (DDLUtils.isDatasourceTable(metadata)) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a datasource table: ${table.identifier}") + } + val partition = catalog.getPartition(table, partitionSpec) + if (isExtended) { + describeExtendedDetailedPartitionInfo(table, metadata, partition, result) + } else if (isFormatted) { + describeFormattedDetailedPartitionInfo(table, metadata, partition, result) + describeStorageInfo(metadata, result) + } + } + + private def describeExtendedDetailedPartitionInfo( + tableIdentifier: TableIdentifier, + table: CatalogTable, + partition: CatalogTablePartition, + buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "Detailed Partition Information " + partition.toString, "", "") + } + + private def describeFormattedDetailedPartitionInfo( + tableIdentifier: TableIdentifier, + table: CatalogTable, + partition: CatalogTablePartition, + buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# Detailed Partition Information", "", "") + append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") + append(buffer, "Database:", table.database, "") + append(buffer, "Table:", tableIdentifier.table, "") + append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "") + append(buffer, "Partition Parameters:", "", "") + 
partition.parameters.foreach { case (key, value) => + append(buffer, s" $key", value, "") + } + } + private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql new file mode 100644 index 0000000000000..3f0ae902e0529 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -0,0 +1,27 @@ +CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING); + +ALTER TABLE t ADD PARTITION (c='Us', d=1); + +DESC t; + +-- Ignore these because there exist timestamp results, e.g., `Create Table`. +-- DESC EXTENDED t; +-- DESC FORMATTED t; + +DESC t PARTITION (c='Us', d=1); + +-- Ignore these because there exist timestamp results, e.g., transient_lastDdlTime. +-- DESC EXTENDED t PARTITION (c='Us', d=1); +-- DESC FORMATTED t PARTITION (c='Us', d=1); + +-- NoSuchPartitionException: Partition not found in table +DESC t PARTITION (c='Us', d=2); + +-- AnalysisException: Partition spec is invalid +DESC t PARTITION (c='Us'); + +-- ParseException: PARTITION specification is incomplete +DESC t PARTITION (c='Us', d); + +-- DROP TEST TABLE +DROP TABLE t; diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out new file mode 100644 index 0000000000000..37bf303f1bfe4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +ALTER TABLE t ADD PARTITION (c='Us', d=1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +DESC t +-- 
!query 2 schema +struct +-- !query 2 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 3 +DESC t PARTITION (c='Us', d=1) +-- !query 3 schema +struct +-- !query 3 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 4 +DESC t PARTITION (c='Us', d=2) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +Partition not found in table 't' database 'default': +c -> Us +d -> 2; + + +-- !query 5 +DESC t PARTITION (c='Us') +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; + + +-- !query 6 +DESC t PARTITION (c='Us', d) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.catalyst.parser.ParseException + +PARTITION specification is incomplete: `d`(line 1, pos 0) + +== SQL == +DESC t PARTITION (c='Us', d) +^^^ + + +-- !query 7 +DROP TABLE t +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index dc4d099f0f666..6c77a0deb52a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} import org.apache.spark.sql.catalyst.catalog.CatalogTableType import 
org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -341,6 +341,81 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("describe partition") { + withTable("partitioned_table") { + sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") + sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") + + checkKeywordsExist(sql("DESC partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name") + + checkKeywordsExist(sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name", + "Detailed Partition Information CatalogPartition(", + "Partition Values: [Us, 1]", + "Storage(Location:", + "Partition Parameters") + + checkKeywordsExist(sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name", + "# Detailed Partition Information", + "Partition Value:", + "Database:", + "Table:", + "Location:", + "Partition Parameters:", + "# Storage Information") + } + } + + test("describe partition - error handling") { + withTable("partitioned_table", "datasource_table") { + sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") + sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") + + val m = intercept[NoSuchPartitionException] { + sql("DESC partitioned_table PARTITION (c='Us', d=2)") + }.getMessage() + assert(m.contains("Partition not found in table")) + + val m2 = intercept[AnalysisException] { + sql("DESC partitioned_table PARTITION (c='Us')") + }.getMessage() + assert(m2.contains("Partition spec is invalid")) + + val m3 = intercept[ParseException] { + sql("DESC partitioned_table PARTITION (c='Us', d)") + }.getMessage() + assert(m3.contains("PARTITION specification is incomplete: `d`")) + + spark + .range(1).select('id as 'a, 'id as 'b, 'id as 
'c, 'id as 'd).write + .partitionBy("d") + .saveAsTable("datasource_table") + val m4 = intercept[AnalysisException] { + sql("DESC datasource_table PARTITION (d=2)") + }.getMessage() + assert(m4.contains("DESC PARTITION is not allowed on a datasource table")) + + val m5 = intercept[AnalysisException] { + spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") + sql("DESC view1 PARTITION (c='Us', d=1)") + }.getMessage() + assert(m5.contains("DESC PARTITION is not allowed on a temporary view")) + + withView("permanent_view") { + val m = intercept[AnalysisException] { + sql("CREATE VIEW permanent_view AS SELECT * FROM partitioned_table") + sql("DESC permanent_view PARTITION (c='Us', d=1)") + }.getMessage() + assert(m.contains("DESC PARTITION is not allowed on a view")) + } + } + } + test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") From 29396e7d1483d027960b9a1bed47008775c4253e Mon Sep 17 00:00:00 2001 From: Bjarne Fruergaard Date: Thu, 29 Sep 2016 15:39:57 -0700 Subject: [PATCH 49/96] [SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector ## What changes were proposed in this pull request? * changes the implementation of gemv with transposed SparseMatrix and SparseVector both in mllib-local and mllib (identical) * adds a test that was failing before this change, but succeeds with these changes. The problem in the previous implementation was that it only increments `i`, that is enumerating the columns of a row in the SparseMatrix, when the row-index of the vector matches the column-index of the SparseMatrix. In cases where a particular row of the SparseMatrix has non-zero values at column-indices lower than corresponding non-zero row-indices of the SparseVector, the non-zero values of the SparseVector are enumerated without ever matching the column-index at index `i` and the remaining column-indices i+1,...,indEnd-1 are never attempted. 
The test cases in this PR illustrate this issue. ## How was this patch tested? I have run the specific `gemv` tests in both mllib-local and mllib. I am currently still running `./dev/run-tests`. ## ___ As per instructions, I hereby state that this is my original work and that I license the work to the project (Apache Spark) under the project's open source license. Mentioning dbtsai, viirya and brkyvz whom I can see have worked/authored on these parts before. Author: Bjarne Fruergaard Closes #15296 from bwahlgreen/bugfix-spark-17721. --- .../scala/org/apache/spark/ml/linalg/BLAS.scala | 8 ++++++-- .../org/apache/spark/ml/linalg/BLASSuite.scala | 17 +++++++++++++++++ .../org/apache/spark/mllib/linalg/BLAS.scala | 8 ++++++-- .../apache/spark/mllib/linalg/BLASSuite.scala | 17 +++++++++++++++++ 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 41b0c6c89a647..4ca19f3387f07 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index 8a9f49792c1cd..6e72a5fff0a91 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class 
BLASSuite extends SparkMLFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 6a85608706974..0cd68a633c0b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 80da03cc2efeb..6e68c1c9d36c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 
1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = From 3993ebca23afa4b8770695051635933a6c9d2c11 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 29 Sep 2016 15:40:35 -0700 Subject: [PATCH 50/96] [SPARK-17676][CORE] FsHistoryProvider should ignore hidden files ## What changes were proposed in this pull request? FsHistoryProvider was writing a hidden file (to check the fs's clock). Even though it deleted the file immediately, sometimes another thread would try to scan the files on the fs in-between, and then there would be an error msg logged which was very misleading for the end-user. (The logged error was harmless, though.) ## How was this patch tested? I added one unit test, but to be clear, that test was passing before. The actual change in behavior in that test is just logging (after the change, there is no more logged error), which I just manually verified. Author: Imran Rashid Closes #15250 from squito/SPARK-17676. 
--- .../deploy/history/FsHistoryProvider.scala | 7 +++- .../history/FsHistoryProviderSuite.scala | 36 +++++++++++++++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index d494ff0659bd2..c5740e4737094 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -294,7 +294,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .filter { entry => try { val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && prevFileSize < entry.getLen() + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. 
+ !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 39c5857b13451..01bef0a11c124 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy.history -import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, - FileOutputStream, OutputStreamWriter} +import java.io._ import java.net.URI import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -394,6 +393,39 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("ignore hidden files") { + + // FsHistoryProvider should ignore hidden files. (It even writes out a hidden file itself + // that should be ignored). + + // write out one totally bogus hidden file + val hiddenGarbageFile = new File(testDir, ".garbage") + val out = new PrintWriter(hiddenGarbageFile) + // scalastyle:off println + out.println("GARBAGE") + // scalastyle:on println + out.close() + + // also write out one real event log file, but since its a hidden file, we shouldn't read it + val tmpNewAppFile = newLogFile("hidden", None, inProgress = false) + val hiddenNewAppFile = new File(tmpNewAppFile.getParentFile, "." 
+ tmpNewAppFile.getName) + tmpNewAppFile.renameTo(hiddenNewAppFile) + + // and write one real file, which should still get picked up just fine + val newAppComplete = newLogFile("real-app", None, inProgress = false) + writeFile(newAppComplete, true, None, + SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test", + None), + SparkListenerApplicationEnd(5L) + ) + + val provider = new FsHistoryProvider(createTestConf()) + updateAndCheck(provider) { list => + list.size should be (1) + list(0).name should be ("real-app") + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: From 39eb3bb1ec29aa993de13a6eba3ab27db6fc5371 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 29 Sep 2016 16:01:45 -0700 Subject: [PATCH 51/96] [SPARK-17412][DOC] All test should not be run by `root` or any admin user ## What changes were proposed in this pull request? `FsHistoryProviderSuite` fails if `root` user runs it. The test case **SPARK-3697: ignore directories that cannot be read** depends on `setReadable(false, false)` to make test data files and expects the number of accessible files is 1. But, `root` can access all files, so it returns 2. This PR adds the assumption explicitly on doc. `building-spark.md`. ## How was this patch tested? This is a documentation change. Author: Dongjoon Hyun Closes #15291 from dongjoon-hyun/SPARK-17412. --- docs/building-spark.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/building-spark.md b/docs/building-spark.md index 75c304a3ccecd..da7eeb8348378 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,6 +215,7 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub # Running Tests Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). +Note that tests should not be run as root or an admin user. 
Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: From 2f739567080d804a942cfcca0e22f91ab7cbea36 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Sep 2016 16:31:30 -0700 Subject: [PATCH 52/96] [SPARK-17697][ML] Fixed bug in summary calculations that pattern match against label without casting ## What changes were proposed in this pull request? In calling LogisticRegression.evaluate and GeneralizedLinearRegression.evaluate using a Dataset where the Label is not of a double type, calculations pattern match against a double and throw a MatchError. This fix casts the Label column to a DoubleType to ensure there is no MatchError. ## How was this patch tested? Added unit tests to call evaluate with a dataset that has Label as other numeric types. Author: Bryan Cutler Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697. --- .../classification/LogisticRegression.scala | 2 +- .../GeneralizedLinearRegression.scala | 11 ++++---- .../LogisticRegressionSuite.scala | 18 ++++++++++++- .../GeneralizedLinearRegressionSuite.scala | 25 +++++++++++++++++++ 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5ab63d1de95d3..329961a25d984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. 
@transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).rdd.map { + predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 02b27fb650979..bb9e150c49772 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( } else { link.unlink(0.0) } - predictions.select(col(model.getLabelCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { case Row(y: Double, weight: Double) => family.deviance(y, wtdmu, weight) }.sum() @@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") lazy val deviance: Double = { val w = weightCol - predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] ( lazy val aic: Double = { val w = weightCol val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) - val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { - case Row(label: Double, pred: Double, weight: Double) => - (label, pred, weight) + val t = predictions.select( + col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + (label, pred, 
weight) } family.aic(t, deviance, numInstances, weightSum) + 2 * rank } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8451e60144981..42b56754e0835 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.LongType class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -1776,6 +1777,21 @@ class LogisticRegressionSuite summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) } + test("evaluate with labels that are not doubles") { + // Evaluate a test set with Label that is a numeric type other than Double + val lr = new LogisticRegression() + .setMaxIter(1) + .setRegParam(1.0) + val model = lr.fit(smallBinaryDataset) + val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] + + val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType), + col(model.getFeaturesCol)) + val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary] + + assert(summary.areaUnderROC === longSummary.areaUnderROC) + } + test("statistics on training data") { // Test that loss is monotonically decreasing. 
val lr = new LogisticRegression() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 937aa7d3c2045..ac1ef5feb95ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.FloatType class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite idx += 1 } } + + test("evaluate with labels that are not doubles") { + // Evaulate with a dataset that contains Labels not as doubles to verify correct casting + val dataset = Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 1.0, Vectors.dense(3.0, 13.0)) + ).toDF() + + val trainer = new GeneralizedLinearRegression() + .setMaxIter(1) + val model = trainer.fit(dataset) + assert(model.hasSummary) + val summary = model.summary + + val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType), + col(model.getFeaturesCol)) + val evalSummary = model.evaluate(longLabelDataset) + // The calculations below involve pattern matching with Label as a double + assert(evalSummary.nullDeviance === summary.nullDeviance) + assert(evalSummary.deviance === summary.deviance) + assert(evalSummary.aic === summary.aic) + } } object GeneralizedLinearRegressionSuite { From 74ac1c43817c0b8da70342e540ec7638dd7d01bd Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 29 Sep 
2016 17:56:32 -0700 Subject: [PATCH 53/96] [SPARK-17717][SQL] Add exist/find methods to Catalog. ## What changes were proposed in this pull request? The current user facing catalog does not implement methods for checking object existence or finding objects. You could theoretically do this using the `list*` commands, but this is rather cumbersome and can actually be costly when there are many objects. This PR adds `exists*` and `find*` methods for Databases, Table and Functions. ## How was this patch tested? Added tests to `org.apache.spark.sql.internal.CatalogSuite` Author: Herman van Hovell Closes #15301 from hvanhovell/SPARK-17717. --- project/MimaExcludes.scala | 11 +- .../apache/spark/sql/catalog/Catalog.scala | 83 ++++++++++ .../spark/sql/internal/CatalogImpl.scala | 152 +++++++++++++++--- .../spark/sql/internal/CatalogSuite.scala | 118 ++++++++++++++ 4 files changed, 339 insertions(+), 25 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4db3edb733a56..2ffe0ac9bc982 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -46,7 +46,16 @@ object MimaExcludes { // [SPARK-16967] Move Mesos to Module ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), // [SPARK-16240] ML persistence backward compatibility for LDA - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), + // [SPARK-17717] Add Find and Exists method to Catalog. 
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findDatabase"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findTable"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findFunction"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.columnExists") ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 1aed245fdd332..b439022d227cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -101,6 +101,89 @@ abstract class Catalog { @throws[AnalysisException]("database or table does not exist") def listColumns(dbName: String, tableName: String): Dataset[Column] + /** + * Find the database with the specified name. This throws an AnalysisException when the database + * cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database does not exist") + def findDatabase(dbName: String): Database + + /** + * Find the table with the specified name. This table can be a temporary table or a table in the + * current database. This throws an AnalysisException when the table cannot be found. 
+ * + * @since 2.1.0 + */ + @throws[AnalysisException]("table does not exist") + def findTable(tableName: String): Table + + /** + * Find the table with the specified name in the specified database. This throws an + * AnalysisException when the table cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database or table does not exist") + def findTable(dbName: String, tableName: String): Table + + /** + * Find the function with the specified name. This function can be a temporary function or a + * function in the current database. This throws an AnalysisException when the function cannot + * be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("function does not exist") + def findFunction(functionName: String): Function + + /** + * Find the function with the specified name. This throws an AnalysisException when the function + * cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database or function does not exist") + def findFunction(dbName: String, functionName: String): Function + + /** + * Check if the database with the specified name exists. + * + * @since 2.1.0 + */ + def databaseExists(dbName: String): Boolean + + /** + * Check if the table with the specified name exists. This can either be a temporary table or a + * table in the current database. + * + * @since 2.1.0 + */ + def tableExists(tableName: String): Boolean + + /** + * Check if the table with the specified name exists in the specified database. + * + * @since 2.1.0 + */ + def tableExists(dbName: String, tableName: String): Boolean + + /** + * Check if the function with the specified name exists. This can either be a temporary function + * or a function in the current database. + * + * @since 2.1.0 + */ + def functionExists(functionName: String): Boolean + + /** + * Check if the function with the specified name exists in the specified database. 
+ * + * @since 2.1.0 + */ + def functionExists(dbName: String, functionName: String): Boolean + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index f252535765899..a1087edd03fdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -23,10 +23,10 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} -import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.types.StructType @@ -69,15 +69,18 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def listDatabases(): Dataset[Database] = { val databases = sessionCatalog.listDatabases().map { dbName => - val metadata = sessionCatalog.getDatabaseMetadata(dbName) - new Database( - name = metadata.name, - description = metadata.description, - locationUri = metadata.locationUri) + makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) } CatalogImpl.makeDataset(databases, sparkSession) } + private def makeDatabase(metadata: 
CatalogDatabase): Database = { + new Database( + name = metadata.name, + description = metadata.description, + locationUri = metadata.locationUri) + } + /** * Returns a list of tables in the current database. * This includes all temporary tables. @@ -94,18 +97,21 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { override def listTables(dbName: String): Dataset[Table] = { requireDatabaseExists(dbName) val tables = sessionCatalog.listTables(dbName).map { tableIdent => - val isTemp = tableIdent.database.isEmpty - val metadata = if (isTemp) None else Some(sessionCatalog.getTableMetadata(tableIdent)) - new Table( - name = tableIdent.identifier, - database = metadata.flatMap(_.identifier.database).orNull, - description = metadata.flatMap(_.comment).orNull, - tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), - isTemporary = isTemp) + makeTable(tableIdent, tableIdent.database.isEmpty) } CatalogImpl.makeDataset(tables, sparkSession) } + private def makeTable(tableIdent: TableIdentifier, isTemp: Boolean): Table = { + val metadata = if (isTemp) None else Some(sessionCatalog.getTableMetadata(tableIdent)) + new Table( + name = tableIdent.identifier, + database = metadata.flatMap(_.identifier.database).orNull, + description = metadata.flatMap(_.comment).orNull, + tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), + isTemporary = isTemp) + } + /** * Returns a list of functions registered in the current database. 
* This includes all temporary functions @@ -121,18 +127,22 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { @throws[AnalysisException]("database does not exist") override def listFunctions(dbName: String): Dataset[Function] = { requireDatabaseExists(dbName) - val functions = sessionCatalog.listFunctions(dbName).map { case (funcIdent, _) => - val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) - new Function( - name = funcIdent.identifier, - database = funcIdent.database.orNull, - description = null, // for now, this is always undefined - className = metadata.getClassName, - isTemporary = funcIdent.database.isEmpty) + val functions = sessionCatalog.listFunctions(dbName).map { case (functIdent, _) => + makeFunction(functIdent) } CatalogImpl.makeDataset(functions, sparkSession) } + private def makeFunction(funcIdent: FunctionIdentifier): Function = { + val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) + new Function( + name = funcIdent.identifier, + database = funcIdent.database.orNull, + description = null, // for now, this is always undefined + className = metadata.getClassName, + isTemporary = funcIdent.database.isEmpty) + } + /** * Returns a list of columns for the given table in the current database. */ @@ -167,6 +177,100 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(columns, sparkSession) } + /** + * Find the database with the specified name. This throws an [[AnalysisException]] when no + * [[Database]] can be found. + */ + override def findDatabase(dbName: String): Database = { + if (sessionCatalog.databaseExists(dbName)) { + makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) + } else { + throw new AnalysisException(s"The specified database $dbName does not exist.") + } + } + + /** + * Find the table with the specified name. This table can be a temporary table or a table in the + * current database. This throws an [[AnalysisException]] when no [[Table]] can be found. 
+ */ + override def findTable(tableName: String): Table = { + findTable(null, tableName) + } + + /** + * Find the table with the specified name in the specified database. This throws an + * [[AnalysisException]] when no [[Table]] can be found. + */ + override def findTable(dbName: String, tableName: String): Table = { + val tableIdent = TableIdentifier(tableName, Option(dbName)) + val isTemporary = sessionCatalog.isTemporaryTable(tableIdent) + if (isTemporary || sessionCatalog.tableExists(tableIdent)) { + makeTable(tableIdent, isTemporary) + } else { + throw new AnalysisException(s"The specified table $tableIdent does not exist.") + } + } + + /** + * Find the function with the specified name. This function can be a temporary function or a + * function in the current database. This throws an [[AnalysisException]] when no [[Function]] + * can be found. + */ + override def findFunction(functionName: String): Function = { + findFunction(null, functionName) + } + + /** + * Find the function with the specified name in the specified database. This throws an + * [[AnalysisException]] when no [[Function]] can be found. + */ + override def findFunction(dbName: String, functionName: String): Function = { + val functionIdent = FunctionIdentifier(functionName, Option(dbName)) + if (sessionCatalog.functionExists(functionIdent)) { + makeFunction(functionIdent) + } else { + throw new AnalysisException(s"The specified function $functionIdent does not exist.") + } + } + + /** + * Check if the database with the specified name exists. + */ + override def databaseExists(dbName: String): Boolean = { + sessionCatalog.databaseExists(dbName) + } + + /** + * Check if the table with the specified name exists. This can either be a temporary table or a + * table in the current database. + */ + override def tableExists(tableName: String): Boolean = { + tableExists(null, tableName) + } + + /** + * Check if the table with the specified name exists in the specified database. 
+ */ + override def tableExists(dbName: String, tableName: String): Boolean = { + val tableIdent = TableIdentifier(tableName, Option(dbName)) + sessionCatalog.isTemporaryTable(tableIdent) || sessionCatalog.tableExists(tableIdent) + } + + /** + * Check if the function with the specified name exists. This can either be a temporary function + * or a function in the current database. + */ + override def functionExists(functionName: String): Boolean = { + functionExists(null, functionName) + } + + /** + * Check if the function with the specified name exists in the specified database. + */ + override def functionExists(dbName: String, functionName: String): Boolean = { + sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) + } + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 3dc67ffafb048..783bf77f86b46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -340,6 +340,124 @@ class CatalogSuite } } + test("find database") { + intercept[AnalysisException](spark.catalog.findDatabase("db10")) + withTempDatabase { db => + assert(spark.catalog.findDatabase(db).name === db) + } + } + + test("find table") { + withTempDatabase { db => + withTable(s"tbl_x", s"$db.tbl_y") { + // Try to find non existing tables. + intercept[AnalysisException](spark.catalog.findTable("tbl_x")) + intercept[AnalysisException](spark.catalog.findTable("tbl_y")) + intercept[AnalysisException](spark.catalog.findTable(db, "tbl_y")) + + // Create objects. 
+ createTempTable("tbl_x") + createTable("tbl_y", Some(db)) + + // Find a temporary table + assert(spark.catalog.findTable("tbl_x").name === "tbl_x") + + // Find a qualified table + assert(spark.catalog.findTable(db, "tbl_y").name === "tbl_y") + + // Find an unqualified table using the current database + intercept[AnalysisException](spark.catalog.findTable("tbl_y")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.findTable("tbl_y").name === "tbl_y") + } + } + } + + test("find function") { + withTempDatabase { db => + withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { + // Try to find non existing functions. + intercept[AnalysisException](spark.catalog.findFunction("fn1")) + intercept[AnalysisException](spark.catalog.findFunction("fn2")) + intercept[AnalysisException](spark.catalog.findFunction(db, "fn2")) + + // Create objects. + createTempFunction("fn1") + createFunction("fn2", Some(db)) + + // Find a temporary function + assert(spark.catalog.findFunction("fn1").name === "fn1") + + // Find a qualified function + assert(spark.catalog.findFunction(db, "fn2").name === "fn2") + + // Find an unqualified function using the current database + intercept[AnalysisException](spark.catalog.findFunction("fn2")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.findFunction("fn2").name === "fn2") + } + } + } + + test("database exists") { + assert(!spark.catalog.databaseExists("db10")) + createDatabase("db10") + assert(spark.catalog.databaseExists("db10")) + dropDatabase("db10") + } + + test("table exists") { + withTempDatabase { db => + withTable(s"tbl_x", s"$db.tbl_y") { + // Try to find non existing tables. + assert(!spark.catalog.tableExists("tbl_x")) + assert(!spark.catalog.tableExists("tbl_y")) + assert(!spark.catalog.tableExists(db, "tbl_y")) + + // Create objects. 
+ createTempTable("tbl_x") + createTable("tbl_y", Some(db)) + + // Find a temporary table + assert(spark.catalog.tableExists("tbl_x")) + + // Find a qualified table + assert(spark.catalog.tableExists(db, "tbl_y")) + + // Find an unqualified table using the current database + assert(!spark.catalog.tableExists("tbl_y")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.tableExists("tbl_y")) + } + } + } + + test("function exists") { + withTempDatabase { db => + withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { + // Try to find non existing functions. + assert(!spark.catalog.functionExists("fn1")) + assert(!spark.catalog.functionExists("fn2")) + assert(!spark.catalog.functionExists(db, "fn2")) + + // Create objects. + createTempFunction("fn1") + createFunction("fn2", Some(db)) + + // Find a temporary function + assert(spark.catalog.functionExists("fn1")) + + // Find a qualified function + assert(spark.catalog.functionExists(db, "fn2")) + + // Find an unqualified function using the current database + assert(!spark.catalog.functionExists("fn2")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.functionExists("fn2")) + } + } + } + // TODO: add tests for the rest of them } From 1fad5596885aab8b32d2307c0edecbae50d5bd7a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 29 Sep 2016 23:55:42 -0700 Subject: [PATCH 54/96] [SPARK-14077][ML] Refactor NaiveBayes to support weighted instances ## What changes were proposed in this pull request? 1,support weighted data 2,use dataset/dataframe instead of rdd 3,make mllib as a wrapper to call ml ## How was this patch tested? local manual tests in spark-shell unit tests Author: Zheng RuiFeng Closes #12819 from zhengruifeng/weighted_nb. 
--- .../spark/ml/classification/NaiveBayes.scala | 154 +++++++++++++----- .../mllib/classification/NaiveBayes.scala | 99 +++-------- .../ml/classification/NaiveBayesSuite.scala | 50 +++++- 3 files changed, 191 insertions(+), 112 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index f939a1c6808e6..0d652aa4c65a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -19,23 +19,20 @@ package org.apache.spark.ml.classification import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} -import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} -import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType /** * Params for Naive Bayes Classifiers. */ -private[ml] trait NaiveBayesParams extends PredictorParams { +private[ml] trait NaiveBayesParams extends PredictorParams with HasWeightCol { /** * The smoothing parameter. @@ -56,7 +53,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { */ final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + "which is a string (case-sensitive). 
Supported options: multinomial (default) and bernoulli.", - ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) + ParamValidators.inArray[String](NaiveBayes.supportedModelTypes.toArray)) /** @group getParam */ final def getModelType: String = $(modelType) @@ -64,7 +61,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { /** * Naive Bayes Classifiers. - * It supports both Multinomial NB + * It supports Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) * which can handle finitely supported discrete data. For example, by converting documents into * TF-IDF vectors, it can be used for document classification. By making every vector a @@ -78,6 +75,8 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { + import NaiveBayes.{Bernoulli, Multinomial} + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -98,7 +97,17 @@ class NaiveBayes @Since("1.5.0") ( */ @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) - setDefault(modelType -> OldNaiveBayes.Multinomial) + setDefault(modelType -> NaiveBayes.Multinomial) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. 
+ * + * @group setParam + */ + @Since("2.1.0") + def setWeightCol(value: String): this.type = set(weightCol, value) override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val numClasses = getNumClasses(dataset) @@ -109,10 +118,89 @@ class NaiveBayes @Since("1.5.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val oldDataset: RDD[OldLabeledPoint] = - extractLabeledPoints(dataset).map(OldLabeledPoint.fromML) - val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) - NaiveBayesModel.fromOld(oldModel, this) + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + + val requireNonnegativeValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } + + val requireValues: Vector => Unit = { + $(modelType) match { + case Multinomial => + requireNonnegativeValues + case Bernoulli => + requireZeroOneBernoulliValues + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + } + + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + + // Aggregates term frequencies per label. + // TODO: Calling aggregateByKey and collect creates two stages, we can implement something + // TODO: similar to reduceByKeyLocally to save one stage. 
+ val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) + }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( + seqOp = { + case ((weightSum: Double, featureSum: DenseVector), (weight, features)) => + requireValues(features) + BLAS.axpy(weight, features, featureSum) + (weightSum + weight, featureSum) + }, + combOp = { + case ((weightSum1, featureSum1), (weightSum2, featureSum2)) => + BLAS.axpy(1.0, featureSum2, featureSum1) + (weightSum1 + weightSum2, featureSum1) + }).collect().sortBy(_._1) + + val numLabels = aggregated.length + val numDocuments = aggregated.map(_._2._1).sum + + val piArray = Array.fill[Double](numLabels)(0.0) + val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0) + + val lambda = $(smoothing) + val piLogDenom = math.log(numDocuments + numLabels * lambda) + var i = 0 + aggregated.foreach { case (label, (n, sumTermFreqs)) => + piArray(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = $(modelType) match { + case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + var j = 0 + while (j < numFeatures) { + thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + j += 1 + } + i += 1 + } + + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true) + new NaiveBayesModel(uid, pi, theta) } @Since("1.5.0") @@ -121,6 +209,14 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + /** String name for multinomial model type. */ + private[spark] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. 
*/ + private[spark] val Bernoulli: String = "bernoulli" + + /* Set of modelTypes that NaiveBayes supports */ + private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) @@ -140,7 +236,7 @@ class NaiveBayesModel private[ml] ( extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { - import OldNaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes.{Bernoulli, Multinomial} /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. @@ -175,10 +271,8 @@ class NaiveBayesModel private[ml] ( private def bernoulliCalculation(features: Vector) = { features.foreachActive((_, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") - } + require(value == 0.0 || value == 1.0, + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") ) val prob = thetaMinusNegTheta.get.multiply(features) BLAS.axpy(1.0, pi, prob) @@ -238,18 +332,6 @@ class NaiveBayesModel private[ml] ( @Since("1.6.0") object NaiveBayesModel extends MLReadable[NaiveBayesModel] { - /** Convert a model from the old API */ - private[ml] def fromOld( - oldModel: OldNaiveBayesModel, - parent: NaiveBayes): NaiveBayesModel = { - val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") - val labels = Vectors.dense(oldModel.labels) - val pi = Vectors.dense(oldModel.pi) - val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, - oldModel.theta.flatten, true) - new NaiveBayesModel(uid, pi, theta) - } - @Since("1.6.0") override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader @@ -280,11 +362,9 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = 
sparkSession.read.parquet(dataPath) - val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") - val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") - .select("pi", "theta") - .head() + val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 593a86f69ad51..32d6968a4e85f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -27,7 +27,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} +import org.apache.spark.ml.classification.{NaiveBayes => NewNaiveBayes} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -311,8 +312,6 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { - import NaiveBayes.{Bernoulli, Multinomial} - @Since("1.4.0") def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) @@ -355,79 +354,33 @@ class NaiveBayes private ( */ @Since("0.9.0") def run(data: RDD[LabeledPoint]): NaiveBayesModel = { - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - 
} - if (!values.forall(_ >= 0.0)) { - throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") - } - } + val spark = SparkSession + .builder() + .sparkContext(data.context) + .getOrCreate() - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - if (!values.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - } + import spark.implicits._ - // Aggregates term frequencies per label. - // TODO: Calling combineByKey and collect creates two stages, we can implement something - // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( - createCombiner = (v: Vector) => { - if (modelType == Bernoulli) { - requireZeroOneBernoulliValues(v) - } else { - requireNonnegativeValues(v) - } - (1L, v.copy.toDense) - }, - mergeValue = (c: (Long, DenseVector), v: Vector) => { - requireNonnegativeValues(v) - BLAS.axpy(1.0, v, c._2) - (c._1 + 1L, c._2) - }, - mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { - BLAS.axpy(1.0, c2._2, c1._2) - (c1._1 + c2._1, c1._2) - } - ).collect().sortBy(_._1) + val nb = new NewNaiveBayes() + .setModelType(modelType) + .setSmoothing(lambda) - val numLabels = aggregated.length - var numDocuments = 0L - aggregated.foreach { case (_, (n, _)) => - numDocuments += n - } - val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } - - val labels = new Array[Double](numLabels) - val pi = new Array[Double](numLabels) - val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) - - val piLogDenom = math.log(numDocuments + numLabels * lambda) - var i = 0 - aggregated.foreach { case (label, (n, sumTermFreqs)) => - labels(i) = label - pi(i) = math.log(n + lambda) - piLogDenom - val thetaLogDenom 
= modelType match { - case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) - case Bernoulli => math.log(n + 2.0 * lambda) - case _ => - // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") - } - var j = 0 - while (j < numFeatures) { - theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom - j += 1 - } - i += 1 + val labels = data.map(_.label).distinct().collect().sorted + + // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be + // in range [0, numClasses). + val dataset = data.map { + case LabeledPoint(label, features) => + (labels.indexOf(label).toDouble, features.asML) + }.toDF("label", "features") + + val newModel = nb.fit(dataset) + + val pi = newModel.pi.toArray + val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) + newModel.theta.foreachActive { + case (i, j, v) => + theta(i)(j) = v } new NaiveBayesModel(labels, pi, theta, modelType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 99099324284dc..597428d036c7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -23,13 +23,13 @@ import breeze.linalg.{DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.ml.classification.NaiveBayesSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import 
org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -152,6 +152,52 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateProbabilities(featureAndProbabilities, model, "multinomial") } + test("Naive Bayes Multinomial with weighted samples") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() + val (overSampledData, weightedData) = + MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, + "label", "features", 42L) + val nb = new NaiveBayes().setModelType("multinomial") + val unweightedModel = nb.fit(weightedData) + val overSampledModel = nb.fit(overSampledData) + val weightedModel = nb.setWeightCol("weight").fit(weightedData) + assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) + assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) + assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) + assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + } + + test("Naive Bayes Bernoulli with weighted samples") { + val nPoints = 10000 + val piArray = Array(0.5, 0.3, 0.2).map(math.log) + val thetaArray = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "bernoulli").toDF() + val (overSampledData, weightedData) = + 
MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, + "label", "features", 42L) + val nb = new NaiveBayes().setModelType("bernoulli") + val unweightedModel = nb.fit(weightedData) + val overSampledModel = nb.fit(overSampledData) + val weightedModel = nb.setWeightCol("weight").fit(weightedData) + assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) + assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) + assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) + assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + } + test("Naive Bayes Bernoulli") { val nPoints = 10000 val piArray = Array(0.5, 0.3, 0.2).map(math.log) From 8e491af52930886cbe0c54e7d67add3796ddb15f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 30 Sep 2016 08:18:48 -0700 Subject: [PATCH 55/96] [SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0 ## What changes were proposed in this pull request? Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0 ## How was this patch tested? local build Author: Zheng RuiFeng Closes #15313 from zhengruifeng/revert_save_load. 
--- .../apache/spark/ml/classification/NaiveBayes.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0d652aa4c65a1..6775745167b08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -25,7 +25,8 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ -import org.apache.spark.sql.Dataset +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType @@ -362,9 +363,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() - val pi = data.getAs[Vector](0) - val theta = data.getAs[Matrix](1) + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") + val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") + .select("pi", "theta") + .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) From f327e16863371076dbd2a7f22c8895ae07f8274b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 30 Sep 2016 09:59:12 -0700 Subject: [PATCH 56/96] [SPARK-17738] [SQL] fix ARRAY/MAP in columnar cache ## What changes were proposed in this pull request? The actualSize() of array and map is different from the actual size, the header is Int, rather than Long. 
## How was this patch tested? The flaky test should be fixed. Author: Davies Liu Closes #15305 from davies/fix_MAP. --- .../apache/spark/sql/execution/columnar/ColumnType.scala | 8 ++++---- .../spark/sql/execution/columnar/ColumnTypeSuite.scala | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index fa9619eb07fec..d27d8c362dd9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -589,7 +589,7 @@ private[columnar] case class STRUCT(dataType: StructType) private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { - override def defaultSize: Int = 16 + override def defaultSize: Int = 28 override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) @@ -601,7 +601,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeArray = getField(row, ordinal) - 8 + unsafeArray.getSizeInBytes + 4 + unsafeArray.getSizeInBytes } override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { @@ -628,7 +628,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { - override def defaultSize: Int = 32 + override def defaultSize: Int = 68 override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) @@ -640,7 +640,7 @@ private[columnar] case class MAP(dataType: MapType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeMap = getField(row, ordinal) - 8 + 
unsafeMap.getSizeInBytes + 4 + unsafeMap.getSizeInBytes } override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 0b93c633b2d93..805b5667287ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -38,7 +38,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val checks = Map( NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, - STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) + STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -73,8 +73,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) - checkActualSize(ARRAY_TYPE, Array[Any](1), 8 + 8 + 8 + 8) - checkActualSize(MAP_TYPE, Map(1 -> "a"), 8 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8)) + checkActualSize(ARRAY_TYPE, Array[Any](1), 4 + 8 + 8 + 8) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 4 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8)) checkActualSize(STRUCT_TYPE, Row("hello"), 28) } From 81455a9cd963098613bad10182e3fafc83a6e352 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 30 Sep 2016 17:31:59 -0700 Subject: [PATCH 57/96] [SPARK-17703][SQL] Add unnamed version of addReferenceObj for minor objects. ## What changes were proposed in this pull request? 
There are many minor objects in references, which are extracted to the generated class field, e.g. `errMsg` in `GetExternalRowField` or `ValidateExternalType`, but number of fields in class is limited so we should reduce the number. This pr adds unnamed version of `addReferenceObj` for these minor objects not to store the object into field but refer it from the `references` field at the time of use. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15276 from ueshin/issues/SPARK-17703. --- .../expressions/codegen/CodeGenerator.scala | 15 +++++++++++++++ .../spark/sql/catalyst/expressions/misc.scala | 5 ++++- .../catalyst/expressions/objects/objects.scala | 12 +++++++++--- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 33b9b804fc601..cb808e375a35f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -84,6 +84,21 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`. + * + * Returns the code to access it. + * + * This is for minor objects not to store the object into field but refer it from the references + * field at the time of use because number of fields in class is limited so we should reduce it. + */ + def addReferenceObj(obj: Any): String = { + val idx = references.length + references += obj + val clsName = obj.getClass.getName + s"(($clsName) references[$idx])" + } + /** * Add an object to `references`, create a class member to access it. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 92f8fb85fc0e2..dbb52a4bb18de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -517,7 +517,10 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the value is null or false. + val errMsgField = ctx.addReferenceObj(errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index faf8fecd79f4d..50e2ac3c36d93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -906,7 +906,9 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the value is null. 
+ val errMsgField = ctx.addReferenceObj(errMsg) val code = s""" ${childGen.code} @@ -941,7 +943,9 @@ case class GetExternalRowField( private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the field is null. + val errMsgField = ctx.addReferenceObj(errMsg) val row = child.genCode(ctx) val code = s""" ${row.code} @@ -979,7 +983,9 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the type doesn't match. + val errMsgField = ctx.addReferenceObj(errMsg) val input = child.genCode(ctx) val obj = input.value From a26afd52198523dbd51dc94053424494638c7de5 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Fri, 30 Sep 2016 18:24:39 -0700 Subject: [PATCH 58/96] [SPARK-15353][CORE] Making peer selection for block replication pluggable ## What changes were proposed in this pull request? This PR makes block replication strategies pluggable. It provides two traits that can be implemented, one that maps a host to its topology and is used in the master, and the second that helps prioritize a list of peers for block replication and would run in the executors. This patch contains default implementations of these traits that make sure current Spark behavior is unchanged. ## How was this patch tested? This patch should not change Spark behavior in any way, and was tested with unit tests for storage. 
Author: Shubham Chopra Closes #13152 from shubhamchopra/RackAwareBlockReplication. --- .../apache/spark/storage/BlockManager.scala | 167 +++++++++--------- .../apache/spark/storage/BlockManagerId.scala | 34 +++- .../spark/storage/BlockManagerMaster.scala | 16 +- .../storage/BlockManagerMasterEndpoint.scala | 32 +++- .../storage/BlockReplicationPolicy.scala | 112 ++++++++++++ .../apache/spark/storage/TopologyMapper.scala | 86 +++++++++ .../BlockManagerReplicationSuite.scala | 2 + .../storage/BlockReplicationPolicySuite.scala | 74 ++++++++ .../spark/storage/TopologyMapperSuite.scala | 68 +++++++ 9 files changed, 492 insertions(+), 99 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index aa29acfd70461..982b83324e0fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,7 +20,8 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable +import scala.collection.mutable.HashMap import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag @@ -44,6 +45,7 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer + /* Class for returning a fetched block and associated metrics. 
*/ private[spark] class BlockResult( val data: Iterator[Any], @@ -147,6 +149,8 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L + private var blockReplicationPolicy: BlockReplicationPolicy = _ + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -160,8 +164,24 @@ private[spark] class BlockManager( blockTransferService.init(this) shuffleClient.init(appId) - blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + blockReplicationPolicy = { + val priorityClass = conf.get( + "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) + val clazz = Utils.classForName(priorityClass) + val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + logInfo(s"Using $priorityClass for block replication policy") + ret + } + + val id = + BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None) + + val idFromMaster = master.registerBlockManager( + id, + maxMemory, + slaveEndpoint) + + blockManagerId = if (idFromMaster != null) idFromMaster else id shuffleServerId = if (externalShuffleServiceEnabled) { logInfo(s"external shuffle service port = $externalShuffleServicePort") @@ -170,12 +190,12 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) - // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { registerWithExternalShuffleServer() } + + logInfo(s"Initialized BlockManager: $blockManagerId") } private def registerWithExternalShuffleServer() { @@ -1111,7 +1131,7 @@ private[spark] class BlockManager( } /** - * Replicate block to another node. 
Not that this is a blocking call that returns after + * Replicate block to another node. Note that this is a blocking call that returns after * the block has been replicated. */ private def replicate( @@ -1119,101 +1139,78 @@ private[spark] class BlockManager( data: ChunkedByteBuffer, level: StorageLevel, classTag: ClassTag[_]): Unit = { + val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) - val numPeersToReplicateTo = level.replication - 1 - val peersForReplication = new ArrayBuffer[BlockManagerId] - val peersReplicatedTo = new ArrayBuffer[BlockManagerId] - val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] val tLevel = StorageLevel( useDisk = level.useDisk, useMemory = level.useMemory, useOffHeap = level.useOffHeap, deserialized = level.deserialized, replication = 1) - val startTime = System.currentTimeMillis - val random = new Random(blockId.hashCode) - - var replicationFailed = false - var failures = 0 - var done = false - - // Get cached list of peers - peersForReplication ++= getPeers(forceFetch = false) - - // Get a random peer. Note that this selection of a peer is deterministic on the block id. - // So assuming the list of peers does not change and no replication failures, - // if there are multiple attempts in the same node to replicate the same block, - // the same set of peers will be selected. 
- def getRandomPeer(): Option[BlockManagerId] = { - // If replication had failed, then force update the cached list of peers and remove the peers - // that have been already used - if (replicationFailed) { - peersForReplication.clear() - peersForReplication ++= getPeers(forceFetch = true) - peersForReplication --= peersReplicatedTo - peersForReplication --= peersFailedToReplicateTo - } - if (!peersForReplication.isEmpty) { - Some(peersForReplication(random.nextInt(peersForReplication.size))) - } else { - None - } - } - // One by one choose a random peer and try uploading the block to it - // If replication fails (e.g., target peer is down), force the list of cached peers - // to be re-fetched from driver and then pick another random peer for replication. Also - // temporarily black list the peer for which replication failed. - // - // This selection of a peer and replication is continued in a loop until one of the - // following 3 conditions is fulfilled: - // (i) specified number of peers have been replicated to - // (ii) too many failures in replicating to peers - // (iii) no peer left to replicate to - // - while (!done) { - getRandomPeer() match { - case Some(peer) => - try { - val onePeerStartTime = System.currentTimeMillis - logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") - blockTransferService.uploadBlockSync( - peer.host, - peer.port, - peer.executorId, - blockId, - new NettyManagedBuffer(data.toNetty), - tLevel, - classTag) - logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms" - .format(System.currentTimeMillis - onePeerStartTime)) - peersReplicatedTo += peer - peersForReplication -= peer - replicationFailed = false - if (peersReplicatedTo.size == numPeersToReplicateTo) { - done = true // specified number of peers have been replicated to - } - } catch { - case NonFatal(e) => - logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) - failures += 1 - replicationFailed = true - 
peersFailedToReplicateTo += peer - if (failures > maxReplicationFailures) { // too many failures in replicating to peers - done = true - } + val numPeersToReplicateTo = level.replication - 1 + + val startTime = System.nanoTime + + var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId] + var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] + var numFailures = 0 + + var peersForReplication = blockReplicationPolicy.prioritize( + blockManagerId, + getPeers(false), + mutable.HashSet.empty, + blockId, + numPeersToReplicateTo) + + while(numFailures <= maxReplicationFailures && + !peersForReplication.isEmpty && + peersReplicatedTo.size != numPeersToReplicateTo) { + val peer = peersForReplication.head + try { + val onePeerStartTime = System.nanoTime + logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") + blockTransferService.uploadBlockSync( + peer.host, + peer.port, + peer.executorId, + blockId, + new NettyManagedBuffer(data.toNetty), + tLevel, + classTag) + logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + + s" in ${(System.nanoTime - onePeerStartTime).toDouble / 1e6} ms") + peersForReplication = peersForReplication.tail + peersReplicatedTo += peer + } catch { + case NonFatal(e) => + logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e) + peersFailedToReplicateTo += peer + // we have a failed replication, so we get the list of peers again + // we don't want peers we have already replicated to and the ones that + // have failed previously + val filteredPeers = getPeers(true).filter { p => + !peersFailedToReplicateTo.contains(p) && !peersReplicatedTo.contains(p) } - case None => // no peer left to replicate to - done = true + + numFailures += 1 + peersForReplication = blockReplicationPolicy.prioritize( + blockManagerId, + filteredPeers, + peersReplicatedTo, + blockId, + numPeersToReplicateTo - peersReplicatedTo.size) } } - val timeTakeMs = (System.currentTimeMillis - startTime) + 
logDebug(s"Replicating $blockId of ${data.size} bytes to " + - s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms") + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { logWarning(s"Block $blockId replicated to only " + s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers") } + + logDebug(s"block $blockId replicated to ${peersReplicatedTo.mkString(", ")}") } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index f255f5be63fcf..c37a3604d28fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -37,10 +37,11 @@ import org.apache.spark.util.Utils class BlockManagerId private ( private var executorId_ : String, private var host_ : String, - private var port_ : Int) + private var port_ : Int, + private var topologyInfo_ : Option[String]) extends Externalizable { - private def this() = this(null, null, 0) // For deserialization only + private def this() = this(null, null, 0, None) // For deserialization only def executorId: String = executorId_ @@ -60,6 +61,8 @@ class BlockManagerId private ( def port: Int = port_ + def topologyInfo: Option[String] = topologyInfo_ + def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER || executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -69,24 +72,33 @@ class BlockManagerId private ( out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) + out.writeBoolean(topologyInfo_.isDefined) + // we only write topologyInfo if we have it + topologyInfo.foreach(out.writeUTF(_)) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() + val isTopologyInfoAvailable = in.readBoolean() + topologyInfo_ = if 
(isTopologyInfoAvailable) Option(in.readUTF()) else None } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString: String = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)" - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + override def hashCode: Int = + ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode override def equals(that: Any): Boolean = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host + executorId == id.executorId && + port == id.port && + host == id.host && + topologyInfo == id.topologyInfo case _ => false } @@ -101,10 +113,18 @@ private[spark] object BlockManagerId { * @param execId ID of the executor. * @param host Host name of the block manager. * @param port Port of the block manager. + * @param topologyInfo topology information for the blockmanager, if available + * This can be network topology information for use while choosing peers + * while replicating data blocks. More information available here: + * [[org.apache.spark.storage.TopologyMapper]] * @return A new [[org.apache.spark.storage.BlockManagerId]]. 
*/ - def apply(execId: String, host: String, port: Int): BlockManagerId = - getCachedBlockManagerId(new BlockManagerId(execId, host, port)) + def apply( + execId: String, + host: String, + port: Int, + topologyInfo: Option[String] = None): BlockManagerId = + getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo)) def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 8655cf10fc28f..7a600068912b1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -50,12 +50,20 @@ class BlockManagerMaster( logInfo("Removal of executor " + execId + " requested") } - /** Register the BlockManager's id with the driver. */ + /** + * Register the BlockManager's id with the driver. The input BlockManagerId does not contain + * topology information. This information is obtained from the master and we respond with an + * updated BlockManagerId fleshed out with this information. 
+ */ def registerBlockManager( - blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { + blockManagerId: BlockManagerId, + maxMemSize: Long, + slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) - logInfo(s"Registered BlockManager $blockManagerId") + val updatedId = driverEndpoint.askWithRetry[BlockManagerId]( + RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + logInfo(s"Registered BlockManager $updatedId") + updatedId } def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 8fa12150114db..145c434a4f0cf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -55,10 +55,21 @@ class BlockManagerMasterEndpoint( private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) + private val topologyMapper = { + val topologyMapperClassName = conf.get( + "spark.storage.replication.topologyMapper", classOf[DefaultTopologyMapper].getName) + val clazz = Utils.classForName(topologyMapperClassName) + val mapper = + clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper] + logInfo(s"Using $topologyMapperClassName for getting topology information") + mapper + } + + logInfo("BlockManagerMasterEndpoint up") + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - register(blockManagerId, maxMemSize, slaveEndpoint) - context.reply(true) + context.reply(register(blockManagerId, 
maxMemSize, slaveEndpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -298,7 +309,21 @@ class BlockManagerMasterEndpoint( ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { + /** + * Returns the BlockManagerId with topology information populated, if available. + */ + private def register( + idWithoutTopologyInfo: BlockManagerId, + maxMemSize: Long, + slaveEndpoint: RpcEndpointRef): BlockManagerId = { + // the dummy id is not expected to contain the topology information. + // we get that info here and respond back with a more fleshed out block manager id + val id = BlockManagerId( + idWithoutTopologyInfo.executorId, + idWithoutTopologyInfo.host, + idWithoutTopologyInfo.port, + topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host)) + val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -318,6 +343,7 @@ class BlockManagerMasterEndpoint( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + id } private def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala new file mode 100644 index 0000000000000..bf087af16a5b1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging + +/** + * ::DeveloperApi:: + * BlockReplicationPolicy provides logic for prioritizing a sequence of peers for + * replicating blocks. BlockManager will replicate to each peer returned in order until the + * desired number of replicas is reached. If a replication fails, prioritize() will be called + * again to get a fresh prioritization. + */ +@DeveloperApi +trait BlockReplicationPolicy { + + /** + * Method to prioritize a bunch of candidate peers of a block + * + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @param numReplicas Number of peers we need to replicate to + * @return A prioritized list of peers. Lower the index of a peer, higher its priority. + * This returns a list of size at most `numReplicas`. + */ + def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] +} + +@DeveloperApi +class RandomBlockReplicationPolicy + extends BlockReplicationPolicy + with Logging { + + /** + * Method to prioritize a bunch of candidate peers of a block. 
This is a basic implementation, + * that just makes sure we put blocks on different hosts, if possible + * + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @return A prioritized list of peers. Lower the index of a peer, higher its priority + */ + override def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] = { + val random = new Random(blockId.hashCode) + logDebug(s"Input peers : ${peers.mkString(", ")}") + val prioritizedPeers = if (peers.size > numReplicas) { + getSampleIds(peers.size, numReplicas, random).map(peers(_)) + } else { + if (peers.size < numReplicas) { + logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.") + } + random.shuffle(peers).toList + } + logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}") + prioritizedPeers + } + + /** + * Uses sampling algorithm by Robert Floyd. 
Finds a random sample in O(n) while + * minimizing space usage + * [[http://math.stackexchange.com/questions/178690/ + * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * + * @param n total number of indices + * @param m number of samples needed + * @param r random number generator + * @return list of m random unique indices + */ + private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { + val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => + val t = r.nextInt(i) + 1 + if (set.contains(t)) set + i else set + t + } + // we shuffle the result to ensure a random arrangement within the sample + // to avoid any bias from set implementations + r.shuffle(indices.map(_ - 1).toList) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala new file mode 100644 index 0000000000000..a0f0fdef8e948 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * ::DeveloperApi:: + * TopologyMapper provides topology information for a given host + * @param conf SparkConf to get required properties, if needed + */ +@DeveloperApi +abstract class TopologyMapper(conf: SparkConf) { + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return topology information for the given hostname. One can use a 'topology delimiter' + * to make this topology information nested. + * For example : ‘/myrack/myhost’, where ‘/’ is the topology delimiter, + * ‘myrack’ is the topology identifier, and ‘myhost’ is the individual host. + * This function only returns the topology information without the hostname. + * This information can be used when choosing executors for block replication + * to discern executors from a different rack than a candidate executor, for example. + * + * An implementation can choose to use empty strings or None in case topology info + * is not available. This would imply that all such executors belong to the same rack. + */ + def getTopologyForHost(hostname: String): Option[String] +} + +/** + * A TopologyMapper that assumes all nodes are in the same rack + */ +@DeveloperApi +class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + override def getTopologyForHost(hostname: String): Option[String] = { + logDebug(s"Got a request for $hostname") + None + } +} + +/** + * A simple file based topology mapper. This expects topology information provided as a + * [[java.util.Properties]] file. The name of the file is obtained from SparkConf property + * `spark.storage.replication.topologyFile`. 
To use this topology mapper, set the + * `spark.storage.replication.topologyMapper` property to + * [[org.apache.spark.storage.FileBasedTopologyMapper]] + * @param conf SparkConf object + */ +@DeveloperApi +class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + val topologyFile = conf.getOption("spark.storage.replication.topologyFile") + require(topologyFile.isDefined, "Please specify topology file via " + + "spark.storage.replication.topologyFile for FileBasedTopologyMapper.") + val topologyMap = Utils.getPropertiesFromFile(topologyFile.get) + + override def getTopologyForHost(hostname: String): Option[String] = { + val topology = topologyMap.get(hostname) + if (topology.isDefined) { + logDebug(s"$hostname -> ${topology.get}") + } else { + logWarning(s"$hostname does not have any topology information") + } + topology + } +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index e1c1787cbd15e..f4bfdc2fd69a9 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -346,6 +346,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite } } + + /** * Test replication of blocks with different storage levels (various combinations of * memory, disk & serialization). For each storage level, this function tests every store diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala new file mode 100644 index 0000000000000..800c3899f1a72 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark.{LocalSparkContext, SparkFunSuite} + +class BlockReplicationPolicySuite extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { + + // Implicitly convert strings to BlockIds for test clarity. 
+ private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + + /** + * Test if we get the required number of peers when using random sampling from + * RandomBlockReplicationPolicy + */ + test(s"block replication - random block replication policy") { + val numBlockManagers = 10 + val storeSize = 1000 + val blockManagers = (1 to numBlockManagers).map { i => + BlockManagerId(s"store-$i", "localhost", 1000 + i, None) + } + val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) + val replicationPolicy = new RandomBlockReplicationPolicy + val blockId = "test-block" + + (1 to 10).foreach {numReplicas => + logDebug(s"Num replicas : $numReplicas") + val randomPeers = replicationPolicy.prioritize( + candidateBlockManager, + blockManagers, + mutable.HashSet.empty[BlockManagerId], + blockId, + numReplicas + ) + logDebug(s"Random peers : ${randomPeers.mkString(", ")}") + assert(randomPeers.toSet.size === numReplicas) + + // choosing n peers out of n + val secondPass = replicationPolicy.prioritize( + candidateBlockManager, + randomPeers, + mutable.HashSet.empty[BlockManagerId], + blockId, + numReplicas + ) + logDebug(s"Random peers : ${secondPass.mkString(", ")}") + assert(secondPass.toSet.size === numReplicas) + } + + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala new file mode 100644 index 0000000000000..bbd252d7be7ea --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{File, FileOutputStream} + +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark._ +import org.apache.spark.util.Utils + +class TopologyMapperSuite extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { + + test("File based Topology Mapper") { + val numHosts = 100 + val numRacks = 4 + val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap + val propsFile = createPropertiesFile(props) + + val sparkConf = (new SparkConf(false)) + sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath) + val topologyMapper = new FileBasedTopologyMapper(sparkConf) + + props.foreach {case (host, topology) => + val obtainedTopology = topologyMapper.getTopologyForHost(host) + assert(obtainedTopology.isDefined) + assert(obtainedTopology.get === topology) + } + + // we get None for hosts not in the file + assert(topologyMapper.getTopologyForHost("host").isEmpty) + + cleanup(propsFile) + } + + def createPropertiesFile(props: Map[String, String]): File = { + val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile + val fileOS = new FileOutputStream(testFile) + props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)} + fileOS.close + testFile + } + + def cleanup(testFile: File): Unit = { + testFile.getParentFile.listFiles.filter { file => + file.getName.startsWith(testFile.getName) + }.foreach { _.delete() } + } + +} From aef506e39a41cfe7198162c324a11ef2f01136c3 Mon Sep 17 
00:00:00 2001 From: Dongjoon Hyun Date: Fri, 30 Sep 2016 21:05:06 -0700 Subject: [PATCH 59/96] [SPARK-17739][SQL] Collapse adjacent similar Window operators ## What changes were proposed in this pull request? Currently, Spark does not collapse adjacent windows with the same partitioning and sorting. This PR implements `CollapseWindow` optimizer to do the following: 1. If the partition specs and order specs are the same, collapse into the parent. 2. If the partition specs are the same and one order spec is a prefix of the other, collapse to the more specific one. For example: ```scala val df = spark.range(1000).select($"id" % 100 as "grp", $"id", rand() as "col1", rand() as "col2") // Add summary statistics for all columns import org.apache.spark.sql.expressions.Window val cols = Seq("id", "col1", "col2") val window = Window.partitionBy($"grp").orderBy($"id") val result = cols.foldLeft(df) { (base, name) => base.withColumn(s"${name}_avg", avg(col(name)).over(window)) .withColumn(s"${name}_stddev", stddev(col(name)).over(window)) .withColumn(s"${name}_min", min(col(name)).over(window)) .withColumn(s"${name}_max", max(col(name)).over(window)) } ``` **Before** ```scala scala> result.explain == Physical Plan == Window [max(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_max#234], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [min(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_min#216], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [stddev_samp(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_stddev#191], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [avg(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_avg#167], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [max(col1#18) 
windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_max#152], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [min(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_min#138], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [stddev_samp(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_stddev#117], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [avg(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_avg#97], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [max(id#14L) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_max#86L], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [min(id#14L) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_min#76L], [grp#17L], [id#14L ASC NULLS FIRST] +- *Project [grp#17L, id#14L, col1#18, col2#19, id_avg#26, id_stddev#42] +- Window [stddev_samp(_w0#59) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_stddev#42], [grp#17L], [id#14L ASC NULLS FIRST] +- *Project [grp#17L, id#14L, col1#18, col2#19, id_avg#26, cast(id#14L as double) AS _w0#59] +- Window [avg(id#14L) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_avg#26], [grp#17L], [id#14L ASC NULLS FIRST] +- *Sort [grp#17L ASC NULLS FIRST, id#14L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(grp#17L, 200) +- *Project [(id#14L % 100) AS grp#17L, id#14L, rand(-6329949029880411066) AS col1#18, rand(-7251358484380073081) AS col2#19] +- *Range (0, 1000, step=1, splits=Some(8)) ``` **After** ```scala scala> result.explain == Physical Plan == Window [max(col2#5) 
windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_max#220, min(col2#5) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_min#202, stddev_samp(col2#5) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_stddev#177, avg(col2#5) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_avg#153, max(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_max#138, min(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_min#124, stddev_samp(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_stddev#103, avg(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_avg#83, max(id#0L) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_max#72L, min(id#0L) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_min#62L], [grp#3L], [id#0L ASC NULLS FIRST] +- *Project [grp#3L, id#0L, col1#4, col2#5, id_avg#12, id_stddev#28] +- Window [stddev_samp(_w0#45) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_stddev#28], [grp#3L], [id#0L ASC NULLS FIRST] +- *Project [grp#3L, id#0L, col1#4, col2#5, id_avg#12, cast(id#0L as double) AS _w0#45] +- Window [avg(id#0L) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_avg#12], [grp#3L], [id#0L ASC NULLS FIRST] +- *Sort [grp#3L ASC NULLS FIRST, id#0L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(grp#3L, 200) +- *Project 
[(id#0L % 100) AS grp#3L, id#0L, rand(6537478539664068821) AS col1#4, rand(-8961093871295252795) AS col2#5] +- *Range (0, 1000, step=1, splits=Some(8)) ``` ## How was this patch tested? Pass the Jenkins tests with a newly added testsuite. Author: Dongjoon Hyun Closes #15317 from dongjoon-hyun/SPARK-17739. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../optimizer/CollapseWindowSuite.scala | 78 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9df8ce1fa3b28..e5e2cd7d27d15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -88,6 +88,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // Operator combine CollapseRepartition, CollapseProject, + CollapseWindow, CombineFilters, CombineLimits, CombineUnions, @@ -537,6 +538,17 @@ object CollapseRepartition extends Rule[LogicalPlan] { } } +/** + * Collapse Adjacent Window Expression. + * - If the partition specs and order specs are the same, collapse into the parent. 
+ */ +object CollapseWindow extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case w @ Window(we1, ps1, os1, Window(we2, ps2, os2, grandChild)) if ps1 == ps2 && os1 == os2 => + w.copy(windowExpressions = we1 ++ we2, child = grandChild) + } +} + /** * Generate a list of additional filters from an operator's existing constraint but remove those * that are either already part of the operator's condition or are part of the operator's child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala new file mode 100644 index 0000000000000..797076e55cfcc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class CollapseWindowSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("CollapseWindow", FixedPoint(10), + CollapseWindow) :: Nil + } + + val testRelation = LocalRelation('a.double, 'b.double, 'c.string) + val a = testRelation.output(0) + val b = testRelation.output(1) + val c = testRelation.output(2) + val partitionSpec1 = Seq(c) + val partitionSpec2 = Seq(c + 1) + val orderSpec1 = Seq(c.asc) + val orderSpec2 = Seq(c.desc) + + test("collapse two adjacent windows with the same partition/order") { + val query = testRelation + .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) + .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1) + .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1) + .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.window(Seq( + avg(b).as('avg_b), + sum(b).as('sum_b), + max(a).as('max_a), + min(a).as('min_a)), partitionSpec1, orderSpec1) + + comparePlans(optimized, correctAnswer) + } + + test("Don't collapse adjacent windows with different partitions or orders") { + val query1 = testRelation + .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) + .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2) + + val optimized1 = Optimize.execute(query1.analyze) + val correctAnswer1 = query1.analyze + + comparePlans(optimized1, correctAnswer1) + + val query2 = testRelation + .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) + .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1) + + val optimized2 = 
Optimize.execute(query2.analyze) + val correctAnswer2 = query2.analyze + + comparePlans(optimized2, correctAnswer2) + } +} From 15e9bbb49e00b3982c428d39776725d0dea2cdfa Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 30 Sep 2016 22:05:59 -0700 Subject: [PATCH 60/96] [MINOR][DOC] Add an up-to-date description for default serialization during shuffling ## What changes were proposed in this pull request? This PR aims to make the doc up-to-date. The documentation is generally correct, but after https://issues.apache.org/jira/browse/SPARK-13926, Spark starts to choose Kryo as a default serialization library during shuffling of simple types, arrays of simple types, or string type. ## How was this patch tested? This is a documentation update. Author: Dongjoon Hyun Closes #15315 from dongjoon-hyun/SPARK-DOC-SERIALIZER. --- docs/tuning.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tuning.md b/docs/tuning.md index cbf37213aa724..9c43b315bbb9e 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -45,6 +45,7 @@ and calling `conf.set("spark.serializer", "org.apache.spark.serializer.KryoSeria This setting configures the serializer used for not only shuffling data between worker nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom registration requirement, but we recommend trying it in any network-intensive application. +Since Spark 2.0.0, we internally use Kryo serializer when shuffling RDDs with simple types, arrays of simple types, or string type. Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library. 
From 4bcd9b728b8df74756d16b27725c2db7c523d4b2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 30 Sep 2016 23:51:36 -0700 Subject: [PATCH 61/96] [SPARK-17740] Spark tests should mock / interpose HDFS to ensure that streams are closed ## What changes were proposed in this pull request? As a followup to SPARK-17666, ensure filesystem connections are not leaked at least in unit tests. This is done here by intercepting filesystem calls as suggested by JoshRosen . At the end of each test, we assert no filesystem streams are left open. This applies to all tests using SharedSQLContext or SharedSparkContext. ## How was this patch tested? I verified that tests in sql and core are indeed using the filesystem backend, and fixed the detected leaks. I also checked that reverting https://github.com/apache/spark/pull/15245 causes many actual test failures due to connection leaks. Author: Eric Liang Author: Eric Liang Closes #15306 from ericl/sc-4672. --- .../org/apache/spark/DebugFilesystem.scala | 114 ++++++++++++++++++ .../org/apache/spark/SharedSparkContext.scala | 17 ++- .../parquet/ParquetEncodingSuite.scala | 1 + .../streaming/HDFSMetadataLogSuite.scala | 3 +- .../spark/sql/test/SharedSQLContext.scala | 19 ++- 5 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/DebugFilesystem.scala diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala new file mode 100644 index 0000000000000..fb8d701ebda8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.{FileDescriptor, InputStream} +import java.lang +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.hadoop.fs._ + +import org.apache.spark.internal.Logging + +object DebugFilesystem extends Logging { + // Stores the set of active streams and their creation sites. + private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]() + + def clearOpenStreams(): Unit = { + openStreams.clear() + } + + def assertNoOpenStreams(): Unit = { + val numOpen = openStreams.size() + if (numOpen > 0) { + for (exc <- openStreams.values().asScala) { + logWarning("Leaked filesystem connection created at:") + exc.printStackTrace() + } + throw new RuntimeException(s"There are $numOpen possibly leaked file streams.") + } + } +} + +/** + * DebugFilesystem wraps file open calls to track all open connections. This can be used in tests + * to check that connections are not leaked. 
+ */ +// TODO(ekl) we should consider always interposing this to expose num open conns as a metric +class DebugFilesystem extends LocalFileSystem { + import DebugFilesystem._ + + override def open(f: Path, bufferSize: Int): FSDataInputStream = { + val wrapped: FSDataInputStream = super.open(f, bufferSize) + openStreams.put(wrapped, new Throwable()) + + new FSDataInputStream(wrapped.getWrappedStream) { + override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) + + override def getWrappedStream: InputStream = wrapped.getWrappedStream + + override def getFileDescriptor: FileDescriptor = wrapped.getFileDescriptor + + override def getPos: Long = wrapped.getPos + + override def seekToNewSource(targetPos: Long): Boolean = wrapped.seekToNewSource(targetPos) + + override def seek(desired: Long): Unit = wrapped.seek(desired) + + override def setReadahead(readahead: lang.Long): Unit = wrapped.setReadahead(readahead) + + override def read(position: Long, buffer: Array[Byte], offset: Int, length: Int): Int = + wrapped.read(position, buffer, offset, length) + + override def read(buf: ByteBuffer): Int = wrapped.read(buf) + + override def readFully(position: Long, buffer: Array[Byte], offset: Int, length: Int): Unit = + wrapped.readFully(position, buffer, offset, length) + + override def readFully(position: Long, buffer: Array[Byte]): Unit = + wrapped.readFully(position, buffer) + + override def available(): Int = wrapped.available() + + override def mark(readlimit: Int): Unit = wrapped.mark(readlimit) + + override def skip(n: Long): Long = wrapped.skip(n) + + override def markSupported(): Boolean = wrapped.markSupported() + + override def close(): Unit = { + wrapped.close() + openStreams.remove(wrapped) + } + + override def read(): Int = wrapped.read() + + override def reset(): Unit = wrapped.reset() + + override def toString: String = wrapped.toString + + override def equals(obj: scala.Any): Boolean = wrapped.equals(obj) + + override def 
hashCode(): Int = wrapped.hashCode() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 858bc742e07cf..6aedcb1271ff6 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -17,11 +17,11 @@ package org.apache.spark -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.Suite /** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ -trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => +trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { self: Suite => @transient private var _sc: SparkContext = _ @@ -31,7 +31,8 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => override def beforeAll() { super.beforeAll() - _sc = new SparkContext("local[4]", "test", conf) + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) } override def afterAll() { @@ -42,4 +43,14 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => super.afterAll() } } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + DebugFilesystem.assertNoOpenStreams() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index c7541889f202e..00799301ca8d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -104,6 +104,7 @@ 
class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(column.getUTF8String(3 * i + 1).toString == i.toString) assert(column.getUTF8String(3 * i + 2).toString == i.toString) } + reader.close() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4259384f0bc61..9c1d26dcb2241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -203,13 +203,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } // Open and delete - fm.open(path) + val f1 = fm.open(path) fm.delete(path) assert(!fm.exists(path)) intercept[IOException] { fm.open(path) } fm.delete(path) // should not throw exception + f1.close() // Rename val path1 = new Path(s"$dir/file1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 79c37faa4e9ba..db24ee8b46dd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.test -import org.apache.spark.SparkConf +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. 
*/ -trait SharedSQLContext extends SQLTestUtils { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected val sparkConf = new SparkConf() @@ -52,7 +54,8 @@ trait SharedSQLContext extends SQLTestUtils { protected override def beforeAll(): Unit = { SparkSession.sqlListener.set(null) if (_spark == null) { - _spark = new TestSparkSession(sparkConf) + _spark = new TestSparkSession( + sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) } // Ensure we have initialized the context before calling parent code super.beforeAll() @@ -71,4 +74,14 @@ trait SharedSQLContext extends SQLTestUtils { super.afterAll() } } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + DebugFilesystem.assertNoOpenStreams() + } } From af6ece33d39cf305bd4a211d08a2f8e910c69bc1 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 1 Oct 2016 00:50:16 -0700 Subject: [PATCH 62/96] [SPARK-17717][SQL] Add Exist/find methods to Catalog [FOLLOW-UP] ## What changes were proposed in this pull request? We added find and exists methods for Databases, Tables and Functions to the user facing Catalog in PR https://github.com/apache/spark/pull/15301. However, it was brought up that the semantics of the `find` methods are more in line with a `get` method (get an object or else fail). So we rename these in this PR. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #15308 from hvanhovell/SPARK-17717-2. 
--- project/MimaExcludes.scala | 10 +-- .../apache/spark/sql/catalog/Catalog.scala | 31 +++---- .../spark/sql/internal/CatalogImpl.scala | 80 ++++++++----------- .../spark/sql/internal/CatalogSuite.scala | 38 ++++----- 4 files changed, 71 insertions(+), 88 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2ffe0ac9bc982..7362041428b1f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -48,14 +48,12 @@ object MimaExcludes { // [SPARK-16240] ML persistence backward compatibility for LDA ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), // [SPARK-17717] Add Find and Exists method to Catalog. - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findDatabase"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findTable"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findFunction"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getDatabase"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getTable"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getFunction"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.columnExists") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists") ) } diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index b439022d227cc..7f2762c7dac92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -102,50 +102,51 @@ abstract class Catalog { def listColumns(dbName: String, tableName: String): Dataset[Column] /** - * Find the database with the specified name. This throws an AnalysisException when the database + * Get the database with the specified name. This throws an AnalysisException when the database * cannot be found. * * @since 2.1.0 */ @throws[AnalysisException]("database does not exist") - def findDatabase(dbName: String): Database + def getDatabase(dbName: String): Database /** - * Find the table with the specified name. This table can be a temporary table or a table in the - * current database. This throws an AnalysisException when the table cannot be found. + * Get the table or view with the specified name. This table can be a temporary view or a + * table/view in the current database. This throws an AnalysisException when no Table + * can be found. * * @since 2.1.0 */ @throws[AnalysisException]("table does not exist") - def findTable(tableName: String): Table + def getTable(tableName: String): Table /** - * Find the table with the specified name in the specified database. This throws an - * AnalysisException when the table cannot be found. + * Get the table or view with the specified name in the specified database. This throws an + * AnalysisException when no Table can be found. * * @since 2.1.0 */ @throws[AnalysisException]("database or table does not exist") - def findTable(dbName: String, tableName: String): Table + def getTable(dbName: String, tableName: String): Table /** - * Find the function with the specified name. This function can be a temporary function or a + * Get the function with the specified name. 
This function can be a temporary function or a * function in the current database. This throws an AnalysisException when the function cannot * be found. * * @since 2.1.0 */ @throws[AnalysisException]("function does not exist") - def findFunction(functionName: String): Function + def getFunction(functionName: String): Function /** - * Find the function with the specified name. This throws an AnalysisException when the function + * Get the function with the specified name. This throws an AnalysisException when the function * cannot be found. * * @since 2.1.0 */ @throws[AnalysisException]("database or function does not exist") - def findFunction(dbName: String, functionName: String): Function + def getFunction(dbName: String, functionName: String): Function /** * Check if the database with the specified name exists. @@ -155,15 +156,15 @@ abstract class Catalog { def databaseExists(dbName: String): Boolean /** - * Check if the table with the specified name exists. This can either be a temporary table or a - * table in the current database. + * Check if the table or view with the specified name exists. This can either be a temporary + * view or a table/view in the current database. * * @since 2.1.0 */ def tableExists(tableName: String): Boolean /** - * Check if the table with the specified name exists in the specified database. + * Check if the table or view with the specified name exists in the specified database. * * @since 2.1.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index a1087edd03fdf..e412e1b4b302a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -68,13 +68,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Returns a list of databases available across all sessions. 
*/ override def listDatabases(): Dataset[Database] = { - val databases = sessionCatalog.listDatabases().map { dbName => - makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) - } + val databases = sessionCatalog.listDatabases().map(makeDatabase) CatalogImpl.makeDataset(databases, sparkSession) } - private def makeDatabase(metadata: CatalogDatabase): Database = { + private def makeDatabase(dbName: String): Database = { + val metadata = sessionCatalog.getDatabaseMetadata(dbName) new Database( name = metadata.name, description = metadata.description, @@ -96,20 +95,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { @throws[AnalysisException]("database does not exist") override def listTables(dbName: String): Dataset[Table] = { requireDatabaseExists(dbName) - val tables = sessionCatalog.listTables(dbName).map { tableIdent => - makeTable(tableIdent, tableIdent.database.isEmpty) - } + val tables = sessionCatalog.listTables(dbName).map(makeTable) CatalogImpl.makeDataset(tables, sparkSession) } - private def makeTable(tableIdent: TableIdentifier, isTemp: Boolean): Table = { - val metadata = if (isTemp) None else Some(sessionCatalog.getTableMetadata(tableIdent)) + private def makeTable(tableIdent: TableIdentifier): Table = { + val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val database = metadata.identifier.database new Table( - name = tableIdent.identifier, - database = metadata.flatMap(_.identifier.database).orNull, - description = metadata.flatMap(_.comment).orNull, - tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), - isTemporary = isTemp) + name = tableIdent.table, + database = database.orNull, + description = metadata.comment.orNull, + tableType = if (database.isEmpty) "TEMPORARY" else metadata.tableType.name, + isTemporary = database.isEmpty) } /** @@ -178,59 +176,45 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Find the database with the specified name. 
This throws an [[AnalysisException]] when no + * Get the database with the specified name. This throws an [[AnalysisException]] when no * [[Database]] can be found. */ - override def findDatabase(dbName: String): Database = { - if (sessionCatalog.databaseExists(dbName)) { - makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) - } else { - throw new AnalysisException(s"The specified database $dbName does not exist.") - } + override def getDatabase(dbName: String): Database = { + makeDatabase(dbName) } /** - * Find the table with the specified name. This table can be a temporary table or a table in the - * current database. This throws an [[AnalysisException]] when no [[Table]] can be found. + * Get the table or view with the specified name. This table can be a temporary view or a + * table/view in the current database. This throws an [[AnalysisException]] when no [[Table]] + * can be found. */ - override def findTable(tableName: String): Table = { - findTable(null, tableName) + override def getTable(tableName: String): Table = { + getTable(null, tableName) } /** - * Find the table with the specified name in the specified database. This throws an + * Get the table or view with the specified name in the specified database. This throws an * [[AnalysisException]] when no [[Table]] can be found. */ - override def findTable(dbName: String, tableName: String): Table = { - val tableIdent = TableIdentifier(tableName, Option(dbName)) - val isTemporary = sessionCatalog.isTemporaryTable(tableIdent) - if (isTemporary || sessionCatalog.tableExists(tableIdent)) { - makeTable(tableIdent, isTemporary) - } else { - throw new AnalysisException(s"The specified table $tableIdent does not exist.") - } + override def getTable(dbName: String, tableName: String): Table = { + makeTable(TableIdentifier(tableName, Option(dbName))) } /** - * Find the function with the specified name. This function can be a temporary function or a + * Get the function with the specified name. 
This function can be a temporary function or a * function in the current database. This throws an [[AnalysisException]] when no [[Function]] * can be found. */ - override def findFunction(functionName: String): Function = { - findFunction(null, functionName) + override def getFunction(functionName: String): Function = { + getFunction(null, functionName) } /** - * Find the function with the specified name. This returns [[None]] when no [[Function]] can be + * Get the function with the specified name. This returns [[None]] when no [[Function]] can be * found. */ - override def findFunction(dbName: String, functionName: String): Function = { - val functionIdent = FunctionIdentifier(functionName, Option(dbName)) - if (sessionCatalog.functionExists(functionIdent)) { - makeFunction(functionIdent) - } else { - throw new AnalysisException(s"The specified function $functionIdent does not exist.") - } + override def getFunction(dbName: String, functionName: String): Function = { + makeFunction(FunctionIdentifier(functionName, Option(dbName))) } /** @@ -241,15 +225,15 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the table with the specified name exists. This can either be a temporary table or a - * table in the current database. + * Check if the table or view with the specified name exists. This can either be a temporary + * view or a table/view in the current database. */ override def tableExists(tableName: String): Boolean = { tableExists(null, tableName) } /** - * Check if the table with the specified name exists in the specified database. + * Check if the table or view with the specified name exists in the specified database. 
*/ override def tableExists(dbName: String, tableName: String): Boolean = { val tableIdent = TableIdentifier(tableName, Option(dbName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 783bf77f86b46..214bc736bd4de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -340,61 +340,61 @@ class CatalogSuite } } - test("find database") { - intercept[AnalysisException](spark.catalog.findDatabase("db10")) + test("get database") { + intercept[AnalysisException](spark.catalog.getDatabase("db10")) withTempDatabase { db => - assert(spark.catalog.findDatabase(db).name === db) + assert(spark.catalog.getDatabase(db).name === db) } } - test("find table") { + test("get table") { withTempDatabase { db => withTable(s"tbl_x", s"$db.tbl_y") { // Try to find non existing tables. - intercept[AnalysisException](spark.catalog.findTable("tbl_x")) - intercept[AnalysisException](spark.catalog.findTable("tbl_y")) - intercept[AnalysisException](spark.catalog.findTable(db, "tbl_y")) + intercept[AnalysisException](spark.catalog.getTable("tbl_x")) + intercept[AnalysisException](spark.catalog.getTable("tbl_y")) + intercept[AnalysisException](spark.catalog.getTable(db, "tbl_y")) // Create objects. 
createTempTable("tbl_x") createTable("tbl_y", Some(db)) // Find a temporary table - assert(spark.catalog.findTable("tbl_x").name === "tbl_x") + assert(spark.catalog.getTable("tbl_x").name === "tbl_x") // Find a qualified table - assert(spark.catalog.findTable(db, "tbl_y").name === "tbl_y") + assert(spark.catalog.getTable(db, "tbl_y").name === "tbl_y") // Find an unqualified table using the current database - intercept[AnalysisException](spark.catalog.findTable("tbl_y")) + intercept[AnalysisException](spark.catalog.getTable("tbl_y")) spark.catalog.setCurrentDatabase(db) - assert(spark.catalog.findTable("tbl_y").name === "tbl_y") + assert(spark.catalog.getTable("tbl_y").name === "tbl_y") } } } - test("find function") { + test("get function") { withTempDatabase { db => withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { // Try to find non existing functions. - intercept[AnalysisException](spark.catalog.findFunction("fn1")) - intercept[AnalysisException](spark.catalog.findFunction("fn2")) - intercept[AnalysisException](spark.catalog.findFunction(db, "fn2")) + intercept[AnalysisException](spark.catalog.getFunction("fn1")) + intercept[AnalysisException](spark.catalog.getFunction("fn2")) + intercept[AnalysisException](spark.catalog.getFunction(db, "fn2")) // Create objects. 
createTempFunction("fn1") createFunction("fn2", Some(db)) // Find a temporary function - assert(spark.catalog.findFunction("fn1").name === "fn1") + assert(spark.catalog.getFunction("fn1").name === "fn1") // Find a qualified function - assert(spark.catalog.findFunction(db, "fn2").name === "fn2") + assert(spark.catalog.getFunction(db, "fn2").name === "fn2") // Find an unqualified function using the current database - intercept[AnalysisException](spark.catalog.findFunction("fn2")) + intercept[AnalysisException](spark.catalog.getFunction("fn2")) spark.catalog.setCurrentDatabase(db) - assert(spark.catalog.findFunction("fn2").name === "fn2") + assert(spark.catalog.getFunction("fn2").name === "fn2") } } } From b88cb63da39786c07cb4bfa70afed32ec5eb3286 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 1 Oct 2016 16:10:39 -0400 Subject: [PATCH 63/96] [SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Partial revert of #15277 to instead sort and store input to model rather than require sorted input ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15299 from srowen/SPARK-17704.2. --- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 22 +++++++++---------- python/pyspark/ml/feature.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 9c131a41850cc..d0385e220e1e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] ( import ChiSqSelectorModel._ - /** list of indices to select (filter). Must be ordered asc */ + /** list of indices to select (filter). 
*/ @Since("1.6.0") val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 706ce78f260a6..c305b36278e87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). Must be ordered asc + * @param selectedFeatures list of indices to select (filter). */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { - require(isSorted(selectedFeatures), "Array has to be sorted asc") + private val filterIndices = selectedFeatures.sorted + @deprecated("not intended for subclasses to use", "2.1.0") protected def isSorted(array: Array[Int]): Boolean = { var i = 1 val len = array.length @@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( */ @Since("1.3.0") override def transform(vector: Vector): Vector = { - compress(vector, selectedFeatures) + compress(vector) } /** @@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. 
* Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter, must be ordered asc */ - private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + private def compress(features: Vector): Vector = { features match { case SparseVector(size, indices, values) => val newSize = filterIndices.length @@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { */ @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { - val chiSqTestResult = Statistics.chiSqTest(data) + val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { case ChiSqSelector.KBest => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take(numTopFeatures) case ChiSqSelector.Percentile => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => - chiSqTestResult.zipWithIndex - .filter{ case (res, _) => res.pValue < alpha } + chiSqTestResult + .filter { case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices }.sorted + val indices = features.map { case (_, index) => index } new ChiSqSelectorModel(indices) } } diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 12a13849dc9bc..64b21caa616ec 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): @since("2.0.0") def selectedFeatures(self): """ - List of indices to select (filter). Must be ordered asc. + List of indices to select (filter). 
""" return self._call_java("selectedFeatures") From f8d7fade4b9a78ae87b6012e3d6f71eef3032b22 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Sun, 2 Oct 2016 15:47:36 -0700 Subject: [PATCH 64/96] =?UTF-8?q?[SPARK-17509][SQL]=20When=20wrapping=20ca?= =?UTF-8?q?talyst=20datatype=20to=20Hive=20data=20type=20avoid=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When wrapping catalyst datatypes to Hive data type, wrap function was doing an expensive pattern matching which was consuming around 11% of cpu time. Avoid the pattern matching by returning the wrapper only once and reuse it. ## How was this patch tested? Tested by running the job on cluster and saw around 8% cpu improvements. Author: Sital Kedia Closes #15064 from sitalkedia/skedia/hive_wrapper. --- .../spark/sql/hive/HiveInspectors.scala | 307 ++++++++---------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 15 +- 2 files changed, 145 insertions(+), 177 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index e4b963efeaf18..c3c4351cf58a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -238,102 +238,161 @@ private[hive] trait HiveInspectors { case c => throw new AnalysisException(s"Unsupported java type $c") } + private def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + /** * Wraps with Hive types based on object inspector. - * TODO: Consolidate all hive OI/data interface code. 
*/ protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { - case _: JavaHiveVarcharObjectInspector => - (o: Any) => - if (o != null) { - val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.length) - } else { - null - } - - case _: JavaHiveCharObjectInspector => - (o: Any) => - if (o != null) { - val s = o.asInstanceOf[UTF8String].toString - new HiveChar(s, s.length) - } else { - null - } - - case _: JavaHiveDecimalObjectInspector => - (o: Any) => - if (o != null) { - HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) - } else { - null - } - - case _: JavaDateObjectInspector => - (o: Any) => - if (o != null) { - DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) - } else { - null - } - - case _: JavaTimestampObjectInspector => + case x: ConstantObjectInspector => (o: Any) => - if (o != null) { - DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) - } else { - null + x.getWritableConstantValue + case x: PrimitiveObjectInspector => x match { + // TODO we don't support the HiveVarcharObjectInspector yet. 
+ case _: StringObjectInspector if x.preferWritable() => + withNullSafe(o => getStringWritable(o)) + case _: StringObjectInspector => + withNullSafe(o => o.asInstanceOf[UTF8String].toString()) + case _: IntObjectInspector if x.preferWritable() => + withNullSafe(o => getIntWritable(o)) + case _: IntObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Integer]) + case _: BooleanObjectInspector if x.preferWritable() => + withNullSafe(o => getBooleanWritable(o)) + case _: BooleanObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Boolean]) + case _: FloatObjectInspector if x.preferWritable() => + withNullSafe(o => getFloatWritable(o)) + case _: FloatObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Float]) + case _: DoubleObjectInspector if x.preferWritable() => + withNullSafe(o => getDoubleWritable(o)) + case _: DoubleObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Double]) + case _: LongObjectInspector if x.preferWritable() => + withNullSafe(o => getLongWritable(o)) + case _: LongObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Long]) + case _: ShortObjectInspector if x.preferWritable() => + withNullSafe(o => getShortWritable(o)) + case _: ShortObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Short]) + case _: ByteObjectInspector if x.preferWritable() => + withNullSafe(o => getByteWritable(o)) + case _: ByteObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Byte]) + case _: JavaHiveVarcharObjectInspector => + withNullSafe { o => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.length) } + case _: JavaHiveCharObjectInspector => + withNullSafe { o => + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.length) + } + case _: JavaHiveDecimalObjectInspector => + withNullSafe(o => + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: JavaDateObjectInspector => + withNullSafe(o => + 
DateTimeUtils.toJavaDate(o.asInstanceOf[Int])) + case _: JavaTimestampObjectInspector => + withNullSafe(o => + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])) + case _: HiveDecimalObjectInspector if x.preferWritable() => + withNullSafe(o => getDecimalWritable(o.asInstanceOf[Decimal])) + case _: HiveDecimalObjectInspector => + withNullSafe(o => + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: BinaryObjectInspector if x.preferWritable() => + withNullSafe(o => getBinaryWritable(o)) + case _: BinaryObjectInspector => + withNullSafe(o => o.asInstanceOf[Array[Byte]]) + case _: DateObjectInspector if x.preferWritable() => + withNullSafe(o => getDateWritable(o)) + case _: DateObjectInspector => + withNullSafe(o => DateTimeUtils.toJavaDate(o.asInstanceOf[Int])) + case _: TimestampObjectInspector if x.preferWritable() => + withNullSafe(o => getTimestampWritable(o)) + case _: TimestampObjectInspector => + withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])) + } case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) } - (o: Any) => { - if (o != null) { - val struct = soi.create() - val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { - case ((field, wrapper), i) => - soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) - } - struct - } else { - null + withNullSafe { o => + val struct = soi.create() + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) + } + struct + } + + case ssoi: SettableStructObjectInspector => + val structType = dataType.asInstanceOf[StructType] + val wrappers = 
ssoi.getAllStructFieldRefs.asScala.zip(structType).map { + case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType) + } + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + // 1. create the pojo (most likely) object + val result = ssoi.create() + ssoi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + val tpe = structType(i).dataType + ssoi.setStructFieldData( + result, + field, + wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) } + result + } + + case soi: StructObjectInspector => + val structType = dataType.asInstanceOf[StructType] + val wrappers = soi.getAllStructFieldRefs.asScala.zip(structType).map { + case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType) + } + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + val result = new java.util.ArrayList[AnyRef](wrappers.size) + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + val tpe = structType(i).dataType + result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) + } + result } case loi: ListObjectInspector => val elementType = dataType.asInstanceOf[ArrayType].elementType val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType) - (o: Any) => { - if (o != null) { - val array = o.asInstanceOf[ArrayData] - val values = new java.util.ArrayList[Any](array.numElements()) - array.foreach(elementType, (_, e) => values.add(wrapper(e))) - values - } else { - null - } + withNullSafe { o => + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(elementType, (_, e) => values.add(wrapper(e))) + values } case moi: MapObjectInspector => val mt = dataType.asInstanceOf[MapType] val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType) val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType) - - (o: Any) => { - if (o != null) { + withNullSafe { o => val map = 
o.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) map.foreach(mt.keyType, mt.valueType, (k, v) => jmap.put(keyWrapper(k), valueWrapper(v))) jmap - } else { - null } - } case _ => identity[Any] @@ -648,119 +707,19 @@ private[hive] trait HiveInspectors { (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value) } - /** - * Converts native catalyst types to the types expected by Hive - * @param a the value to be wrapped - * @param oi This ObjectInspector associated with the value returned by this function, and - * the ObjectInspector should also be consistent with those returned from - * toInspector: DataType => ObjectInspector and - * toInspector: Expression => ObjectInspector - * - * Strictly follows the following order in wrapping (constant OI has the higher priority): - * Constant object inspector => return the bundled value of Constant object inspector - * Check whether the `a` is null => return null if true - * If object inspector prefers writable object => return a Writable for the given data `a` - * Map the catalyst data to the boxed java primitive - * - * NOTICE: the complex data type requires recursive wrapping. - */ - def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { - case x: ConstantObjectInspector => x.getWritableConstantValue - case _ if a == null => null - case x: PrimitiveObjectInspector => x match { - // TODO we don't support the HiveVarcharObjectInspector yet. 
- case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) - case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) - case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) - case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) - case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) - case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) - case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) - case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] - case _: HiveDecimalObjectInspector if x.preferWritable() => - getDecimalWritable(a.asInstanceOf[Decimal]) - case _: HiveDecimalObjectInspector => - HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) - case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) - case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) - case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long]) - } - case x: SettableStructObjectInspector => - val fieldRefs = x.getAllStructFieldRefs - val structType = dataType.asInstanceOf[StructType] - val row = 
a.asInstanceOf[InternalRow] - // 1. create the pojo (most likely) object - val result = x.create() - var i = 0 - val size = fieldRefs.size - while (i < size) { - // 2. set the property for the pojo - val tpe = structType(i).dataType - x.setStructFieldData( - result, - fieldRefs.get(i), - wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) - i += 1 - } - - result - case x: StructObjectInspector => - val fieldRefs = x.getAllStructFieldRefs - val structType = dataType.asInstanceOf[StructType] - val row = a.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](fieldRefs.size) - var i = 0 - val size = fieldRefs.size - while (i < size) { - val tpe = structType(i).dataType - result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) - i += 1 - } - - result - case x: ListObjectInspector => - val list = new java.util.ArrayList[Object] - val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => - list.add(wrap(e, x.getListElementObjectInspector, tpe)) - ) - list - case x: MapObjectInspector => - val keyType = dataType.asInstanceOf[MapType].keyType - val valueType = dataType.asInstanceOf[MapType].valueType - val map = a.asInstanceOf[MapData] - - // Some UDFs seem to assume we pass in a HashMap. 
- val hashMap = new java.util.HashMap[Any, Any](map.numElements()) - - map.foreach(keyType, valueType, (k, v) => - hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), - wrap(v, x.getMapValueObjectInspector, valueType)) - ) - - hashMap + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = { + wrapperFor(oi, dataType)(a).asInstanceOf[AnyRef] } def wrap( row: InternalRow, - inspectors: Seq[ObjectInspector], + wrappers: Array[(Any) => Any], cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - val length = inspectors.length + val length = wrappers.length while (i < length) { - cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) + cache(i) = wrappers(i)(row.get(i, dataTypes(i))).asInstanceOf[AnyRef] i += 1 } cache @@ -768,13 +727,13 @@ private[hive] trait HiveInspectors { def wrap( row: Seq[Any], - inspectors: Seq[ObjectInspector], + wrappers: Array[(Any) => Any], cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - val length = inspectors.length + val length = wrappers.length while (i < length) { - cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) + cache(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef] i += 1 } cache diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 962dd5a52ebc0..d54913518bb33 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -70,6 +70,9 @@ private[hive] case class HiveSimpleUDF( override lazy val dataType = javaClassToDataType(method.getReturnType) + @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) @@ -82,7 +85,7 @@ private[hive] case class 
HiveSimpleUDF( // TODO: Finish input output types. override def eval(input: InternalRow): Any = { - val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes) + val inputs = wrap(children.map(_.eval(input)), wrappers, cached, inputDataTypes) val ret = FunctionRegistry.invoke( method, function, @@ -214,6 +217,9 @@ private[hive] case class HiveGenericUDTF( @transient private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + @transient private lazy val unwrapper = unwrapperFor(outputInspector) @@ -222,7 +228,7 @@ private[hive] case class HiveGenericUDTF( val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) + function.process(wrap(inputProjection(input), wrappers, udtInput, inputDataTypes)) collector.collectRows() } @@ -296,6 +302,9 @@ private[hive] case class HiveUDAFFunction( @transient private lazy val function = functionAndInspector._1 + @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + @transient private lazy val returnInspector = functionAndInspector._2 @@ -322,7 +331,7 @@ private[hive] case class HiveUDAFFunction( override def update(_buffer: MutableRow, input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) + function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { From 76dc2d9073e5e5c45c8b806a474beacb8415d506 Mon Sep 17 00:00:00 2001 From: Tao LI Date: Sun, 2 Oct 2016 16:01:02 -0700 Subject: [PATCH 65/96] [SPARK-14914][CORE][SQL] Skip/fix some test cases on Windows due to limitation of Windows ## What changes were proposed in this pull request? 
This PR proposes to fix/skip some tests that fail on Windows. This PR takes over https://github.com/apache/spark/pull/12696. **Before** - **SparkSubmitSuite** ``` [info] - launch simple application with spark-submit *** FAILED *** (202 milliseconds) [info] java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "C:\projects\spark"): CreateProcess error=2, The system cannot find the file specified [info] - includes jars passed in through --jars *** FAILED *** (1 second, 625 milliseconds) [info] java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "C:\projects\spark"): CreateProcess error=2, The system cannot find the file specified ``` - **DiskStoreSuite** ``` [info] - reads of memory-mapped and non memory-mapped files are equivalent *** FAILED *** (1 second, 78 milliseconds) [info] diskStoreMapped.remove(blockId) was false (DiskStoreSuite.scala:41) ``` **After** - **SparkSubmitSuite** ``` [info] - launch simple application with spark-submit (578 milliseconds) [info] - includes jars passed in through --jars (1 second, 875 milliseconds) ``` - **DiskStoreSuite** ``` [info] DiskStoreSuite: [info] - reads of memory-mapped and non memory-mapped files are equivalent !!! CANCELED !!! (766 milliseconds) ``` For `CreateTableAsSelectSuite` and `FsHistoryProviderSuite`, I could not reproduce the failures, as the Java version seems higher than the one that has the bugs about `setReadable(..)` and `setWritable(...)`, but as they are clearly reported bugs, it'd be sensible to skip those. We should revert both changes as soon as we drop support for Java 7. ## How was this patch tested? Manually tested via AppVeyor. Closes #12696 Author: Tao LI Author: U-FAREAST\tl Author: hyukjinkwon Closes #15320 from HyukjinKwon/SPARK-14914.
--- .../src/main/scala/org/apache/spark/util/Utils.scala | 12 ++---------- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 7 ++++++- .../deploy/history/FsHistoryProviderSuite.scala | 2 ++ .../org/apache/spark/storage/DiskStoreSuite.scala | 4 ++++ .../spark/sql/sources/CreateTableAsSelectSuite.scala | 3 ++- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f3493bd96b1ee..ef832756ce3b7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -23,7 +23,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -1014,15 +1014,7 @@ private[spark] object Utils extends Logging { * Check to see if file is a symbolic link. */ def isSymlink(file: File): Boolean = { - if (file == null) throw new NullPointerException("File must not be null") - if (isWindows) return false - val fileInCanonicalDir = if (file.getParent() == null) { - file - } else { - new File(file.getParentFile().getCanonicalFile(), file.getName()) - } - - !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()) + return Files.isSymbolicLink(Paths.get(file.toURI)) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 31c8fb26460df..732cbfaaeea46 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -649,8 +649,13 @@ class SparkSubmitSuite // NOTE: This is an expensive operation in terms of time (10 seconds+). 
Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val sparkSubmitFile = if (Utils.isWindows) { + new File("..\\bin\\spark-submit.cmd") + } else { + new File("../bin/spark-submit") + } val process = Utils.executeCommand( - Seq("./bin/spark-submit") ++ args, + Seq(sparkSubmitFile.getCanonicalPath) ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 01bef0a11c124..a5eda7b5a5a75 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -126,6 +126,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("SPARK-3697: ignore directories that cannot be read.") { + // setReadable(...) does not work on Windows. Please refer JDK-6728842. 
+ assume(!Utils.isWindows) val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9ed5016510d56..9e6b02b9eac4d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -22,10 +22,14 @@ import java.util.Arrays import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.io.ChunkedByteBuffer +import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + // It will cause an error when we try to re-open the filestore while the + // memory-mapped byte buffer to the file has not been GC'd on Windows. + assume(!Utils.isWindows) val confKey = "spark.storage.memoryMapThreshold" // Create a non-trivial (not all zeros) byte array diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 344d4aa6cfea4..c39005f6a1063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -83,6 +82,8 @@ class CreateTableAsSelectSuite } test("CREATE TABLE USING AS SELECT based on the file without write 
permission") { + // setWritable(...) does not work on Windows. Please refer JDK-6728842. + assume(!Utils.isWindows) val childPath = new File(path.toString, "child") path.mkdir() path.setWritable(false) From de3f71ed7a301387e870a38c14dad9508efc9743 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Mon, 3 Oct 2016 10:24:30 +0100 Subject: [PATCH 66/96] [SPARK-17598][SQL][WEB UI] User-friendly name for Spark Thrift Server in web UI ## What changes were proposed in this pull request? The name of Spark Thrift JDBC/ODBC Server in web UI reflects the name of the class, i.e. org.apache.spark.sql.hive.thrift.HiveThriftServer2. I changed it to Thrift JDBC/ODBC Server (like Spark shell for spark-shell) as recommended by jaceklaskowski. Note the user can still change the name adding `--name "App Name"` parameter to the start script as before ## How was this patch tested? By running the script with various parameters and checking the web ui ![screen shot 2016-09-27 at 12 19 12 pm](https://cloud.githubusercontent.com/assets/13952758/18888329/aebca47c-84ac-11e6-93d0-6e98684977c5.png) Author: Alex Bozarth Closes #15268 from ajbozarth/spark17598. 
--- sbin/start-thriftserver.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index ad7e7c5277eb1..f02f31793e346 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -53,4 +53,4 @@ fi export SUBMIT_USAGE_FUNCTION=usage -exec "${SPARK_HOME}"/sbin/spark-daemon.sh submit $CLASS 1 "$@" +exec "${SPARK_HOME}"/sbin/spark-daemon.sh submit $CLASS 1 --name "Thrift JDBC/ODBC Server" "$@" From a27033c0bbaae8f31db9b91693947ed71738ed11 Mon Sep 17 00:00:00 2001 From: Jagadeesan Date: Mon, 3 Oct 2016 10:46:38 +0100 Subject: [PATCH 67/96] =?UTF-8?q?[SPARK-17736][DOCUMENTATION][SPARKR]=20Up?= =?UTF-8?q?date=20R=20README=20for=20rmarkdown,=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? To build R docs (which are built when R tests are run), users need to install pandoc and rmarkdown. This was done for Jenkins in ~~[SPARK-17420](https://issues.apache.org/jira/browse/SPARK-17420)~~ … pandoc] Author: Jagadeesan Closes #15309 from jagadeesanas2/SPARK-17736. --- docs/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 8b515e187379c..ffd3b5712b618 100644 --- a/docs/README.md +++ b/docs/README.md @@ -19,8 +19,8 @@ installed. 
Also install the following libraries: $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs - $ sudo pip install sphinx - $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat"), repos="http://cran.stat.ucla.edu/")' + $ sudo pip install sphinx pypandoc + $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' ``` (Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) From 7bf92127643570e4eb3610fa3ffd36839eba2718 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 3 Oct 2016 10:12:02 -0700 Subject: [PATCH 68/96] [SPARK-17073][SQL] generate column-level statistics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Generate basic column statistics for all the atomic types: - numeric types: max, min, num of nulls, ndv (number of distinct values) - date/timestamp types: they are also represented as numbers internally, so they have the same stats as above. - string: avg length, max length, num of nulls, ndv - binary: avg length, max length, num of nulls - boolean: num of nulls, num of trues, num of falses Also support storing and loading these statistics. One thing to notice: We support analyzing columns independently, e.g.: sql1: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key;` sql2: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS value;` when running sql2 to collect column stats for `value`, we don’t remove the stats of column `key`, which was analyzed in sql1 and not in sql2. As a result, **users need to guarantee consistency** between sql1 and sql2. 
If the table has been changed before sql2, users should re-analyze column `key` when they want to analyze column `value`: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value;` ## How was this patch tested? add unit tests Author: Zhenhua Wang Closes #15090 from wzhfy/colStats. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../catalyst/plans/logical/Statistics.scala | 69 +++- .../spark/sql/execution/SparkSqlParser.scala | 18 +- .../command/AnalyzeColumnCommand.scala | 175 +++++++++ .../command/AnalyzeTableCommand.scala | 112 +++--- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../spark/sql/internal/SessionState.scala | 8 +- .../spark/sql/StatisticsColumnSuite.scala | 334 ++++++++++++++++++ .../apache/spark/sql/StatisticsSuite.scala | 16 +- .../org/apache/spark/sql/StatisticsTest.scala | 129 +++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 28 +- .../spark/sql/hive/StatisticsSuite.scala | 119 +++++-- .../sql/hive/execution/SQLViewSuite.scala | 1 + 13 files changed, 906 insertions(+), 114 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index de2f9ee6bc7a2..1284681fe80b4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -86,7 +86,7 @@ statement | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS - (identifier | FOR COLUMNS identifierSeq?)? #analyze + (identifier | FOR COLUMNS identifierSeq)? 
#analyze | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable | ALTER (TABLE | VIEW) tableIdentifier diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 3cf20385dd712..43455c989c0f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,6 +17,12 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.commons.codec.binary.Base64 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types._ + /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override @@ -32,12 +38,15 @@ package org.apache.spark.sql.catalyst.plans.logical * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. + * @param colStats Column-level statistics. * @param isBroadcastable If true, output is small enough to be used in a broadcast join. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, + colStats: Map[String, ColumnStat] = Map.empty, isBroadcastable: Boolean = false) { + override def toString: String = "Statistics(" + simpleString + ")" /** Readable string representation for the Statistics. 
*/ @@ -45,6 +54,64 @@ case class Statistics( Seq(s"sizeInBytes=$sizeInBytes", if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", s"isBroadcastable=$isBroadcastable" - ).filter(_.nonEmpty).mkString("", ", ", "") + ).filter(_.nonEmpty).mkString(", ") + } +} + +/** + * Statistics for a column. + */ +case class ColumnStat(statRow: InternalRow) { + + def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { + NumericColumnStat(statRow, dataType) + } + def forString: StringColumnStat = StringColumnStat(statRow) + def forBinary: BinaryColumnStat = BinaryColumnStat(statRow) + def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow) + + override def toString: String = { + // use Base64 for encoding + Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes) } } + +object ColumnStat { + def apply(numFields: Int, str: String): ColumnStat = { + // use Base64 for decoding + val bytes = Base64.decodeBase64(str) + val unsafeRow = new UnsafeRow(numFields) + unsafeRow.pointTo(bytes, bytes.length) + ColumnStat(unsafeRow) + } +} + +case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) { + // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`. + val numNulls: Long = statRow.getLong(0) + val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType] + val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType] + val ndv: Long = statRow.getLong(3) +} + +case class StringColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. + val numNulls: Long = statRow.getLong(0) + val avgColLen: Double = statRow.getDouble(1) + val maxColLen: Long = statRow.getLong(2) + val ndv: Long = statRow.getLong(3) +} + +case class BinaryColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. 
+ val numNulls: Long = statRow.getLong(0) + val avgColLen: Double = statRow.getDouble(1) + val maxColLen: Long = statRow.getLong(2) +} + +case class BooleanColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`. + val numNulls: Long = statRow.getLong(0) + val numTrues: Long = statRow.getLong(1) + val numFalses: Long = statRow.getLong(2) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3f34d0f25393d..7f1e23e665eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -87,19 +87,27 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create an [[AnalyzeTableCommand]] command. This currently only implements the NOSCAN - * option (other options are passed on to Hive) e.g.: + * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command. 
+ * Example SQL for analyzing table : * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; + * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN]; + * }}} + * Example SQL for analyzing columns : + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2; * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { if (ctx.partitionSpec == null && ctx.identifier != null && ctx.identifier.getText.toLowerCase == "noscan") { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + } else if (ctx.identifierSeq() == null) { + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) } else { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false) + AnalyzeColumnCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitIdentifierSeq(ctx.identifierSeq())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala new file mode 100644 index 0000000000000..7066378279971 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.types._ + + +/** + * Analyzes the given columns of the given table to generate statistics, which will be used in + * query optimizations. 
+ */ +case class AnalyzeColumnCommand( + tableIdent: TableIdentifier, + columnNames: Seq[String]) extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) + + relation match { + case catalogRel: CatalogRelation => + updateStats(catalogRel.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) + + case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => + updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + + case otherRelation => + throw new AnalysisException("ANALYZE TABLE is not supported for " + + s"${otherRelation.nodeName}.") + } + + def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { + val (rowCount, columnStats) = computeColStats(sparkSession, relation) + val statistics = Statistics( + sizeInBytes = newTotalSize, + rowCount = Some(rowCount), + colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + // Refresh the cached data source table in the catalog. 
+ sessionState.catalog.refreshTable(tableIdentWithDB) + } + + Seq.empty[Row] + } + + def computeColStats( + sparkSession: SparkSession, + relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + + // check correctness of column names + val attributesToAnalyze = mutable.MutableList[Attribute]() + val duplicatedColumns = mutable.MutableList[String]() + val resolver = sparkSession.sessionState.conf.resolver + columnNames.foreach { col => + val exprOption = relation.output.find(attr => resolver(attr.name, col)) + val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + // do deduplication + if (!attributesToAnalyze.contains(expr)) { + attributesToAnalyze += expr + } else { + duplicatedColumns += col + } + } + if (duplicatedColumns.nonEmpty) { + logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + + s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + } + + // Collect statistics per column. + // The first element in the result will be the overall row count, the following elements + // will be structs containing all column stats. + // The layout of each struct follows the layout of the ColumnStats. 
+ val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError + val expressions = Count(Literal(1)).toAggregateExpression() +: + attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .queryExecution.toRdd.collect().head + + // unwrap the result + val rowCount = statsRow.getLong(0) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => + val numFields = ColumnStatStruct.numStatFields(expr.dataType) + (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) + }.toMap + (rowCount, columnStats) + } +} + +object ColumnStatStruct { + val zero = Literal(0, LongType) + val one = Literal(1, LongType) + + def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + def max(e: Expression): Expression = Max(e) + def min(e: Expression): Expression = Min(e) + def ndv(e: Expression, relativeSD: Double): Expression = { + // the approximate ndv should never be larger than the number of rows + Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) + } + def avgLength(e: Expression): Expression = Average(Length(e)) + def maxLength(e: Expression): Expression = Max(Length(e)) + def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + + def getStruct(exprs: Seq[Expression]): CreateStruct = { + CreateStruct(exprs.map { expr: Expression => + expr.transformUp { + case af: AggregateFunction => af.toAggregateExpression() + } + }) + } + + def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) + } + + def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) + } + + def binaryColumnStat(e: Expression): Seq[Expression] = { + 
Seq(numNulls(e), avgLength(e), maxLength(e)) + } + + def booleanColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), numTrues(e), numFalses(e)) + } + + def numStatFields(dataType: DataType): Int = { + dataType match { + case BinaryType | BooleanType => 3 + case _ => 4 + } + } + + def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + // Use aggregate functions to compute statistics we need. + case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) + case StringType => getStruct(stringColumnStat(e, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(e)) + case BooleanType => getStruct(booleanColumnStat(e)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${e.name} of data type: ${e.dataType}.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 40aecafecf5bb..7b0e49b665f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -21,81 +21,40 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SessionState /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query 
optimizations. + * Analyzes the given table to generate statistics, which will be used in query optimizations. */ -case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extends RunnableCommand { +case class AnalyzeTableCommand( + tableIdent: TableIdentifier, + noscan: Boolean = true) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB)) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) relation match { case relation: CatalogRelation => - val catalogTable: CatalogTable = relation.catalogTable - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. 
- val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateTableSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - val newTotalSize = - catalogTable.storage.locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - calculateTableSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - - updateTableStats(catalogTable, newTotalSize) + updateTableStats(relation.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable)) // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) case otherRelation => - throw new AnalysisException(s"ANALYZE TABLE is not supported for " + + throw new AnalysisException("ANALYZE TABLE is not supported for " + s"${otherRelation.nodeName}.") } @@ -125,10 +84,57 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend if (newStats.isDefined) { sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) // Refresh the cached data source table in the catalog. 
- sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) } } Seq.empty[Row] } } + +object AnalyzeTableCommand extends Logging { + + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // calculateTableSize to count the table size. + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + catalogTable.storage.locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sessionState.newHadoopConf()) + calculateTableSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${catalogTable.identifier.table} in the " + + s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e67140fefef9a..fecdf792fd14a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -581,6 +581,13 @@ object SQLConf { 
.timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val NDV_MAX_ERROR = + SQLConfigBuilder("spark.sql.statistics.ndv.maxError") + .internal() + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doubleConf + .createWithDefault(0.05) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -757,6 +764,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + + def ndvMaxError: Double = getConf(NDV_MAX_ERROR) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c899773b6b36f..9f7d0019c6b92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -188,11 +189,8 @@ private[sql] class SessionState(sparkSession: SparkSession) { /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. - * - * Right now, it only supports catalog tables and it only updates the size of a catalog table - * in the external catalog. 
*/ - def analyze(tableName: String, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableName, noscan).run(sparkSession) + def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { + AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala new file mode 100644 index 0000000000000..0ee0547c45591 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand +import org.apache.spark.sql.test.SQLTestData.ArrayData +import org.apache.spark.sql.types._ + +class StatisticsColumnSuite extends StatisticsTest { + import testImplicits._ + + test("parse analyze column commands") { + val tableName = "tbl" + + // we need to specify column names + intercept[ParseException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") + } + + val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value" + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) + val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) + comparePlans(parsed, expected) + } + + test("analyzing columns of non-atomic types is not supported") { + val tableName = "tbl" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + val err = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err.message.contains("Analyzing columns is not supported")) + } + } + + test("check correctness of columns") { + val table = "tbl" + val colName1 = "abc" + val colName2 = "x.yz" + withTable(table) { + sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") + + val invalidColError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") + } + assert(invalidColError.message == "Invalid column name: key.") + + withSQLConf("spark.sql.caseSensitive" -> "true") { + val invalidErr = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE 
STATISTICS FOR COLUMNS ${colName1.toUpperCase}") + } + assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") + } + + withSQLConf("spark.sql.caseSensitive" -> "false") { + val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + assert(columnStats.contains(colName1)) + assert(columnStats.contains(colName2)) + // check deduplication + assert(columnStats.size == 2) + assert(!columnStats.contains(colName2.toUpperCase)) + } + } + } + + private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { + values.filter(_.isDefined).map(_.get) + } + + test("column-level statistics for integral type columns") { + val values = (0 to 5).map { i => + if (i % 2 == 0) None else Some(i) + } + val data = values.map { i => + (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) + } + + val df = data.toDF("c1", "c2", "c3", "c4") + val nonNullValues = getNonNullValues[Int](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.max, + nonNullValues.min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for fractional type columns") { + val values: Seq[Option[Decimal]] = (0 to 5).map { i => + if (i == 0) None else Some(Decimal(i + i * 0.01)) + } + val data = values.map { i => + (i.map(_.toFloat), i.map(_.toDouble), i) + } + + val df = data.toDF("c1", "c2", "c3") + val nonNullValues = getNonNullValues[Decimal](values) + val numNulls = values.count(_.isEmpty).toLong + val ndv = nonNullValues.distinct.length.toLong + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case floatType: 
FloatType => + ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, + ndv)) + case doubleType: DoubleType => + ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, + ndv)) + case decimalType: DecimalType => + ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) + } + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for string column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[String](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for binary column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Array[Byte]](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for boolean column") { + val values = Seq(None, Some(true), Some(false), Some(true)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Boolean](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.count(_.equals(true)).toLong, + nonNullValues.count(_.equals(false)).toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } 
+ + test("column-level statistics for date column") { + val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Date](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, DateType is represented as the number of days from 1970-01-01. + nonNullValues.map(DateTimeUtils.fromJavaDate).max, + nonNullValues.map(DateTimeUtils.fromJavaDate).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for timestamp column") { + val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => + i.map(Timestamp.valueOf) + } + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Timestamp](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, TimestampType is represented as the number of microseconds since the epoch. + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for null columns") { + val values = Seq(None, None) + val data = values.map { i => + (i.map(_.toString), i.map(_.toString.toInt)) + } + val df = data.toDF("c1", "c2") + val expectedColStatsSeq = df.schema.map { f => + (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for columns with different types") { + val intSeq = Seq(1, 2) + val doubleSeq = Seq(1.01d, 2.02d) + val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", "bb").map(_.getBytes) + val booleanSeq = Seq(true, false) + val dateSeq = 
Seq("1970-01-01", "1970-02-02").map(Date.valueOf) + val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) + val longSeq = Seq(5L, 4L) + + val data = intSeq.indices.map { i => + (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), + timestampSeq(i), longSeq(i)) + } + val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case DoubleType => + ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, + doubleSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + binarySeq.map(_.length).max.toLong)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) + case DateType => + ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, + dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) + case TimestampType => + ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, + timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, + timestampSeq.distinct.length.toLong)) + case LongType => + ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) + } + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("update table-level stats while collecting column-level stats") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int) USING PARQUET") + sql(s"INSERT INTO $table SELECT 1") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + 
checkTableStats(tableName = table, expectedRowCount = Some(1)) + + // update table-level stats between analyze table and analyze column commands + sql(s"INSERT INTO $table SELECT 1") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) + + val colStat = fetchedStats.get.colStats("c1") + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = colStat, + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) + } + } + + test("analyze column stats independently") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) + assert(fetchedStats1.get.colStats.size == 1) + val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) + val rsd = spark.sessionState.conf.ndvMaxError + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats1.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) + + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) + // column c1 is kept in the stats + assert(fetchedStats2.get.colStats.size == 2) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats2.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) + val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) + StatisticsTest.checkColStat( + dataType = LongType, + colStat = fetchedStats2.get.colStats("c2"), + expectedColStat = expected2, + rsd = rsd) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index 264a2ffbebebd..8cf42e9248c2a 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class StatisticsSuite extends QueryTest with SharedSQLContext { +class StatisticsSuite extends StatisticsTest { import testImplicits._ test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { @@ -77,20 +75,10 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { } test("test table-level statistics for data source table created in InMemoryCatalog") { - def checkTableStats(tableName: String, expectedRowCount: Option[BigInt]): Unit = { - val df = sql(s"SELECT * FROM $tableName") - val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.isDefined) - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel - } - assert(relations.size === 1) - } - val tableName = "tbl" withTable(tableName) { sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala new file mode 100644 index 0000000000000..5134ac0e7e5b3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +trait StatisticsTest extends QueryTest with SharedSQLContext { + + def checkColStats( + df: DataFrame, + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { + val table = "tbl" + withTable(table) { + df.write.format("json").saveAsTable(table) + val columns = expectedColStatsSeq.map(_._1) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = columnStats(field.name) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = colStat, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + + // check if we get the same 
colStat after encoding and decoding + val encodedCS = colStat.toString + val numFields = ColumnStatStruct.numStatFields(field.dataType) + val decodedCS = ColumnStat(numFields, encodedCS) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = decodedCS, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + } + } + + def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } +} + +object StatisticsTest { + def checkColStat( + dataType: DataType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { + dataType match { + case StringType => + val cs = colStat.forString + val expectedCS = expectedColStat.forString + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) + case BinaryType => + val cs = colStat.forBinary + val expectedCS = expectedColStat.forBinary + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + case BooleanType => + val cs = colStat.forBoolean + val expectedCS = expectedColStat.forBoolean + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.numTrues == expectedCS.numTrues) + assert(cs.numFalses == expectedCS.numFalses) + case atomicType: AtomicType => + checkNumericColStats( + dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) + } + } + + private def checkNumericColStats( + dataType: AtomicType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { + val cs = 
colStat.forNumeric(dataType) + val expectedCS = expectedColStat.forNumeric(dataType) + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.max == expectedCS.max) + assert(cs.min == expectedCS.min) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) + } + + private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { + // ndv is an approximate value, so we make sure we have the value, and it should be + // within 3*SD's of the given rsd. + if (expectedNdv == 0) { + assert(ndv == 0) + } else if (expectedNdv > 0) { + assert(ndv > 0) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index d35a681b67e38..261cc6feff090 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} +import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe @@ -401,7 +401,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var statsProperties: Map[String, String] = Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) if (stats.rowCount.isDefined) { - 
statsProperties += (STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()) + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + stats.colStats.foreach { case (colName, colStat) => + statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -473,15 +476,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } // construct Spark's statistics from information in Hive metastore - if (catalogTable.properties.contains(STATISTICS_TOTAL_SIZE)) { - val totalSize = BigInt(catalogTable.properties.get(STATISTICS_TOTAL_SIZE).get) - // TODO: we will compute "estimatedSize" when we have column stats: - // average size of row * number of rows + val statsProps = catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + if (statsProps.nonEmpty) { + val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) + .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } + val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { + case f if colStatsProps.contains(f.name) => + val numFields = ColumnStatStruct.numStatFields(f.dataType) + (f.name, ColumnStat(numFields, colStatsProps(f.name))) + }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), stats = Some(Statistics( - sizeInBytes = totalSize, - rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_))))) + sizeInBytes = BigInt(catalogTable.properties(STATISTICS_TOTAL_SIZE)), + rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats))) } else { catalogTable } @@ -693,6 +702,7 @@ object HiveExternalCatalog { val STATISTICS_PREFIX = "spark.sql.statistics." 
val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" + val STATISTICS_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9956706929cd1..99dd080683d40 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,16 +21,16 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -171,7 +171,27 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } - private def checkStats( + test("analyzing views is not supported") { + def assertAnalyzeUnsupported(analyzeCommand: String): 
Unit = { + val err = intercept[AnalysisException] { + sql(analyzeCommand) + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + + val tableName = "tbl" + withTable(tableName) { + spark.range(10).write.saveAsTable(tableName) + val viewName = "view" + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + } + } + + private def checkTableStats( stats: Option[Statistics], hasSizeInBytes: Boolean, expectedRowCounts: Option[Int]): Unit = { @@ -184,7 +204,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - private def checkStats( + private def checkTableStats( tableName: String, isDataSourceTable: Boolean, hasSizeInBytes: Boolean, @@ -192,12 +212,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val df = sql(s"SELECT * FROM $tableName") val stats = df.queryExecution.analyzed.collect { case rel: MetastoreRelation => - checkStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a data source table, but got a Hive serde table") + checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") rel.catalogTable.stats case rel: LogicalRelation => - checkStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) - assert(isDataSourceTable, "Expected a Hive serde table, but got a data source table") + checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") rel.catalogTable.get.stats } assert(stats.size == 1) @@ -210,13 +230,13 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // Currently 
Spark's statistics are self-contained, we don't have statistics until we use // the `ANALYZE TABLE` command. sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") - checkStats( + checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") - checkStats( + checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = false, @@ -224,12 +244,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // noscan won't count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } @@ -241,19 +261,19 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // when the total size is not changed, the old row count is kept - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") 
sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // update total size and remove the old and invalid row count - val fetchedStats3 = checkStats( + val fetchedStats3 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) } @@ -271,20 +291,20 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) } withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkStats( + checkTableStats( orcTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkStats( + checkTableStats( orcTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) } } @@ -298,23 +318,23 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils assert(DDLUtils.isDatasourceTable(catalogTable)) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") sql(s"ANALYZE TABLE $parquetTable COMPUTE 
STATISTICS noscan") - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkStats( + val fetchedStats3 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, @@ -330,7 +350,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) dfNoCols.write.format("json").saveAsTable(table_no_cols) sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkStats( + checkTableStats( table_no_cols, isDataSourceTable = true, hasSizeInBytes = true, @@ -338,6 +358,53 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + test("generate column-level statistics and load them from hive metastore") { + import testImplicits._ + + val intSeq = Seq(1, 2) + val stringSeq = Seq("a", "bb") + val booleanSeq = Seq(true, false) + + val data = intSeq.indices.map { i => + (intSeq(i), stringSeq(i), booleanSeq(i)) + } + val tableName = "table" + withTable(tableName) { + val df = data.toDF("c1", "c2", "c3") + df.write.format("parquet").saveAsTable(tableName) + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) + } + (f, colStat) + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR 
COLUMNS c1, c2, c3") + val readback = spark.table(tableName) + val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => + val columnStats = rel.catalogTable.get.stats.get.colStats + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = columnStats(field.name) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = colStat, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + rel + } + assert(relations.size == 1) + } + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index a215c70da0c52..f5c605fe5e2fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -123,6 +123,7 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assertNoSuchTable(s"SHOW CREATE TABLE $viewName") assertNoSuchTable(s"SHOW PARTITIONS $viewName") assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") } } From 1dd68d3827133d203e85294405400b04904879e0 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 3 Oct 2016 18:09:36 +0000 Subject: [PATCH 69/96] [SPARK-17718][DOCS][MLLIB] Make loss function formulation label note clearer in MLlib docs ## What changes were proposed in this pull request? Move note about labels being +1/-1 in formulation only to be just under the table of formulations. ## How was this patch tested? Doc build Author: Sean Owen Closes #15330 from srowen/SPARK-17718. 
--- docs/mllib-linear-methods.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 6fcd3ae85700c..816bdf1317000 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -78,6 +78,11 @@ methods `spark.mllib` supports: +Note that, in the mathematical formulation above, a binary label $y$ is denoted as either +$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with +multiclass labeling. + ### Regularizers The purpose of the @@ -136,10 +141,6 @@ multiclass classification problems. For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. -Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either -$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with -multiclass labeling. ### Linear Support Vector Machines (SVMs) From 1f31bdaef670dd43999613deae3620f4ddcd1fbf Mon Sep 17 00:00:00 2001 From: Jason White Date: Mon, 3 Oct 2016 14:12:03 -0700 Subject: [PATCH 70/96] [SPARK-17679] [PYSPARK] remove unnecessary Py4J ListConverter patch ## What changes were proposed in this pull request? This PR removes a patch on ListConverter from https://github.com/apache/spark/pull/5570, as it is no longer necessary. The underlying issue in Py4J https://github.com/bartdag/py4j/issues/160 was patched in https://github.com/bartdag/py4j/commit/224b94b6665e56a93a064073886e1d803a4969d2 and is present in 0.10.3, the version currently in use in Spark. ## How was this patch tested? 
The original test added in https://github.com/apache/spark/pull/5570 remains. Author: Jason White Closes #15254 from JasonMWhite/remove_listconverter_patch. --- python/pyspark/java_gateway.py | 9 --------- python/pyspark/ml/common.py | 4 ++-- python/pyspark/mllib/common.py | 4 ++-- python/pyspark/rdd.py | 13 ++----------- 4 files changed, 6 insertions(+), 24 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 527ca82d31f1b..f76cadcf62438 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,18 +29,9 @@ xrange = range from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.java_collections import ListConverter - from pyspark.serializers import read_int -# patching ListConverter, or it will convert bytearray into Java ArrayList -def can_convert_list(self, obj): - return isinstance(obj, (list, tuple, xrange)) - -ListConverter.can_convert = can_convert_list - - def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index aec860fca7057..387c5d7309dea 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -23,7 +23,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import ListConverter, JavaArray, JavaList +from py4j.java_collections import JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -76,7 +76,7 @@ def _py2java(sc, obj): elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): - obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, long, float, bool, bytes, unicode)): diff --git 
a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 21f0e09ea7742..bac8f350563ec 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -23,7 +23,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import ListConverter, JavaArray, JavaList +from py4j.java_collections import JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -78,7 +78,7 @@ def _py2java(sc, obj): elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): - obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, long, float, bool, bytes, unicode)): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 5fb10f86f4692..ed81eb16df3cd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -52,8 +52,6 @@ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from py4j.java_collections import ListConverter, MapConverter - __all__ = ["RDD"] @@ -2317,16 +2315,9 @@ def _prepare_for_python_RDD(sc, command): # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # There is a bug in py4j.java_gateway.JavaClass with auto_convert - # https://github.com/bartdag/py4j/issues/161 - # TODO: use auto_convert once py4j fix the bug - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in sc._pickled_broadcast_vars], - sc._gateway._gateway_client) + broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_vars] sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) - includes = ListConverter().convert(sc._python_includes, 
sc._gateway._gateway_client) - return pickled_command, broadcast_vars, env, includes + return pickled_command, broadcast_vars, sc.environment, sc._python_includes def _wrap_function(sc, func, deserializer, serializer, profiler=None): From d8399b600cef706c22d381b01fab19c610db439a Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 3 Oct 2016 17:57:54 -0700 Subject: [PATCH 71/96] [SPARK-17587][PYTHON][MLLIB] SparseVector __getitem__ should follow __getitem__ contract ## What changes were proposed in this pull request? Replaces `ValueError` with `IndexError` when index passed to `ml` / `mllib` `SparseVector.__getitem__` is out of range. This ensures correct iteration behavior. Replaces `ValueError` with `IndexError` for `DenseMatrix` and `SparseMatrix` in `ml` / `mllib`. ## How was this patch tested? PySpark `ml` / `mllib` unit tests. Additional unit tests to prove that the problem has been resolved. Author: zero323 Closes #15144 from zero323/SPARK-17587. --- python/pyspark/ml/linalg/__init__.py | 10 +++++----- python/pyspark/ml/tests.py | 16 +++++++++++++--- python/pyspark/mllib/linalg/__init__.py | 10 +++++----- python/pyspark/mllib/tests.py | 16 +++++++++++++--- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 05c0ac862fb7f..a5df727fdb418 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -713,7 +713,7 @@ def __getitem__(self, index): "Indices must be of type integer, got type %s" % type(index)) if index >= self.size or index < -self.size: - raise ValueError("Index %d out of bounds." % index) + raise IndexError("Index %d out of bounds."
% index) if index < 0: index += self.size @@ -960,10 +960,10 @@ def toSparse(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j >= self.numCols or j < 0: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) if self.isTransposed: @@ -1090,10 +1090,10 @@ def __reduce__(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j < 0 or j >= self.numCols: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) # If a CSR matrix is given, then the row index should be searched diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6886ed321ee82..e233549850888 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1316,7 +1316,7 @@ def test_sparse_vector_indexing(self): self.assertEqual(sv[-3], 0.) self.assertEqual(sv[-5], 0.) 
for ind in [5, -6]: - self.assertRaises(ValueError, sv.__getitem__, ind) + self.assertRaises(IndexError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) @@ -1324,11 +1324,15 @@ def test_sparse_vector_indexing(self): self.assertEqual(zeros[0], 0.0) self.assertEqual(zeros[3], 0.0) for ind in [4, -5]: - self.assertRaises(ValueError, zeros.__getitem__, ind) + self.assertRaises(IndexError, zeros.__getitem__, ind) empty = SparseVector(0, {}) for ind in [-1, 0, 1]: - self.assertRaises(ValueError, empty.__getitem__, ind) + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -1337,6 +1341,9 @@ def test_matrix_indexing(self): for j in range(2): self.assertEqual(mat[i, j], expected[i][j]) + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) self.assertTrue( @@ -1408,6 +1415,9 @@ def test_sparse_matrix(self): self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + # Test conversion to dense and sparse. 
smnew = sm1.toDense().toSparse() self.assertEqual(sm1.numRows, smnew.numRows) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 9672dbde823f2..d37e715c8d8ec 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -802,7 +802,7 @@ def __getitem__(self, index): "Indices must be of type integer, got type %s" % type(index)) if index >= self.size or index < -self.size: - raise ValueError("Index %d out of bounds." % index) + raise IndexError("Index %d out of bounds." % index) if index < 0: index += self.size @@ -1115,10 +1115,10 @@ def asML(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j >= self.numCols or j < 0: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) if self.isTransposed: @@ -1245,10 +1245,10 @@ def __reduce__(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j < 0 or j >= self.numCols: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) # If a CSR matrix is given, then the row index should be searched diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3f3dfd186c10d..c519883cdd73b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -260,7 +260,7 @@ def test_sparse_vector_indexing(self): self.assertEqual(sv[-3], 0.) self.assertEqual(sv[-5], 0.) 
for ind in [5, -6]: - self.assertRaises(ValueError, sv.__getitem__, ind) + self.assertRaises(IndexError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) @@ -268,11 +268,15 @@ def test_sparse_vector_indexing(self): self.assertEqual(zeros[0], 0.0) self.assertEqual(zeros[3], 0.0) for ind in [4, -5]: - self.assertRaises(ValueError, zeros.__getitem__, ind) + self.assertRaises(IndexError, zeros.__getitem__, ind) empty = SparseVector(0, {}) for ind in [-1, 0, 1]: - self.assertRaises(ValueError, empty.__getitem__, ind) + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -281,6 +285,9 @@ def test_matrix_indexing(self): for j in range(2): self.assertEqual(mat[i, j], expected[i][j]) + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) self.assertTrue( @@ -352,6 +359,9 @@ def test_sparse_matrix(self): self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + # Test conversion to dense and sparse. smnew = sm1.toDense().toSparse() self.assertEqual(sm1.numRows, smnew.numRows) From 2bbecdec2023143fd144e4242ff70822e0823986 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 3 Oct 2016 19:32:59 -0700 Subject: [PATCH 72/96] [SPARK-17753][SQL] Allow a complex expression as the input to a value based case statement ## What changes were proposed in this pull request? We currently only allow relatively simple expressions as the input for a value based case statement.
Expressions like `case (a > 1) or (b = 2) when true then 1 when false then 0 end` currently fail. This PR adds support for such expressions. ## How was this patch tested? Added a test to the ExpressionParserSuite. Author: Herman van Hovell Closes #15322 from hvanhovell/SPARK-17753. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 12 ++++++------ .../spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/parser/ExpressionParserSuite.scala | 4 ++++ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1284681fe80b4..c336a0c8eab7a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -527,16 +527,16 @@ valueExpression ; primaryExpression - : constant #constantDefault - | name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star | '(' expression (',' expression)+ ')' #rowConstructor - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall | '(' query ')' #subqueryExpression - | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase - | CAST '(' expression AS dataType ')' #cast + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? 
#functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 12a70b7769ef6..cd0c70a49150d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1138,7 +1138,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * }}} */ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { - val e = expression(ctx.valueExpression) + val e = expression(ctx.value) val branches = ctx.whenClause.asScala.map { wCtx => (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index f319215f05681..3718ac5f1e77b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -292,6 +292,10 @@ class ExpressionParserSuite extends PlanTest { test("case when") { assertEqual("case a when 1 then b when 2 then c else d end", CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case (a or b) when true then c when false then d else e end", + CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e))) + assertEqual("case 'a'='a' when true then 1 end", + CaseKeyWhen("a" === "a", Seq(true, 1))) assertEqual("case when a = 1 then b when a = 2 then c else d end", CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) } From c571cfb2d0e1e224107fc3f0c672730cae9804cb Mon Sep 17 00:00:00 
2001 From: Dongjoon Hyun Date: Mon, 3 Oct 2016 21:28:16 -0700 Subject: [PATCH 73/96] [SPARK-17112][SQL] "select null" via JDBC triggers IllegalArgumentException in Thriftserver ## What changes were proposed in this pull request? Currently, Spark Thrift Server raises `IllegalArgumentException` for queries whose column types are `NullType`, e.g., `SELECT null` or `SELECT if(true,null,null)`. This PR fixes that by returning `void` like Hive 1.2. **Before** ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select null" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ Error: java.lang.IllegalArgumentException: Unrecognized type name: null (state=,code=0) Closing: 0: jdbc:hive2://localhost:10000 $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select if(true,null,null)" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ Error: java.lang.IllegalArgumentException: Unrecognized type name: null (state=,code=0) Closing: 0: jdbc:hive2://localhost:10000 ``` **After** ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select null" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ +-------+--+ | NULL | +-------+--+ | NULL | +-------+--+ 1 row selected (3.242 seconds) Beeline version 1.2.1.spark2 by Apache Hive Closing: 0: jdbc:hive2://localhost:10000 $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select if(true,null,null)" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ +-------------------------+--+ | (IF(true, NULL, NULL)) | 
+-------------------------+--+ | NULL | +-------------------------+--+ 1 row selected (0.201 seconds) Beeline version 1.2.1.spark2 by Apache Hive Closing: 0: jdbc:hive2://localhost:10000 ``` ## How was this patch tested? * Pass the Jenkins test with a new testsuite. * Also, Manually, after starting Spark Thrift Server, run the following command. ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select null" $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select if(true,null,null)" ``` **Hive 1.2** ```sql hive> create table null_table as select null; hive> desc null_table; OK _c0 void ``` Author: Dongjoon Hyun Closes #15325 from dongjoon-hyun/SPARK-17112. --- .../SparkExecuteStatementOperation.scala | 19 +++++++---- .../SparkExecuteStatementOperationSuite.scala | 33 +++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e555ebd623f72..aeabd6a15881d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -56,14 +56,11 @@ private[hive] class SparkExecuteStatementOperation( private var statementId: String = _ private lazy val resultSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { + if (result == null || result.schema.isEmpty) { new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map 
{ attr => - new FieldSchema(attr.name, attr.dataType.catalogString, "") - } - new TableSchema(schema.asJava) + logInfo(s"Result Schema: ${result.schema}") + SparkExecuteStatementOperation.getTableSchema(result.schema) } } @@ -282,3 +279,13 @@ private[hive] class SparkExecuteStatementOperation( } } } + +object SparkExecuteStatementOperation { + def getTableSchema(structType: StructType): TableSchema = { + val schema = structType.map { field => + val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString + new FieldSchema(field.name, attrTypeString, "") + } + new TableSchema(schema.asJava) + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala new file mode 100644 index 0000000000000..32ded0d254ef8 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{NullType, StructField, StructType} + +class SparkExecuteStatementOperationSuite extends SparkFunSuite { + test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { + val field1 = StructField("NULL", NullType) + val field2 = StructField("(IF(true, NULL, NULL))", NullType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) + } +} From b1b47274bfeba17a9e4e9acebd7385289f31f6c8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 3 Oct 2016 21:48:58 -0700 Subject: [PATCH 74/96] [SPARK-17702][SQL] Code generation including too many mutable states exceeds JVM size limit. ## What changes were proposed in this pull request? Code generation including too many mutable states exceeds JVM size limit to extract values from `references` into fields in the constructor. We should split the generated extractions in the constructor into smaller functions. ## How was this patch tested? I added some tests to check if the generated codes for the expressions exceed or not. Author: Takuya UESHIN Closes #15275 from ueshin/issues/SPARK-17702. 
--- .../expressions/codegen/CodeGenerator.scala | 18 +++++++++++----- .../codegen/GenerateMutableProjection.scala | 3 ++- .../codegen/GenerateOrdering.scala | 3 ++- .../codegen/GeneratePredicate.scala | 4 +++- .../codegen/GenerateSafeProjection.scala | 4 +++- .../codegen/GenerateUnsafeProjection.scala | 3 ++- .../expressions/CodeGenerationSuite.scala | 21 ++++++++++++++++++- .../sql/execution/WholeStageCodegenExec.scala | 4 +++- 8 files changed, 48 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cb808e375a35f..574943d3d21f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -178,7 +178,10 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - mutableStates.distinct.map(_._3).mkString("\n") + val initCodes = mutableStates.distinct.map(_._3 + "\n") + // The generated initialization code may exceed 64kb function size limit in JVM if there are too + // many mutable states, so split it into multiple functions. + splitExpressions(initCodes, "init", Nil) } /** @@ -604,6 +607,11 @@ class CodegenContext { // Cannot split these expressions because they are not created from a row object. 
return expressions.mkString("\n") } + splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil) + } + + private def splitExpressions( + expressions: Seq[String], funcName: String, arguments: Seq[(String, String)]): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() for (code <- expressions) { @@ -623,11 +631,11 @@ class CodegenContext { // inline execution if only one block blocks.head } else { - val apply = freshName("apply") + val func = freshName(funcName) val functions = blocks.zipWithIndex.map { case (body, i) => - val name = s"${apply}_$i" + val name = s"${func}_$i" val code = s""" - |private void $name(InternalRow $row) { + |private void $name(${arguments.map { case (t, name) => s"$t $name" }.mkString(", ")}) { | $body |} """.stripMargin @@ -635,7 +643,7 @@ class CodegenContext { name } - functions.map(name => s"$name($row);").mkString("\n") + functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")});").mkString("\n") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 0f82d2e613c73..13d61af1c9b40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -104,7 +104,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP private Object[] references; private MutableRow mutableRow; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificMutableProjection(Object[] references) { this.references = references; @@ -112,6 +111,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public 
${classOf[BaseMutableProjection].getName} target(MutableRow row) { mutableRow = row; return this; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f1c30ef6c7fb8..1cef95654a17b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -133,13 +133,14 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificOrdering(Object[] references) { this.references = references; ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public int compare(InternalRow a, InternalRow b) { InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated. 
$comparisons diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 106bb27964cab..39aa7b17de6c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,6 +40,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) + val codeBody = s""" public SpecificPredicate generate(Object[] references) { return new SpecificPredicate(references); @@ -48,13 +49,14 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificPredicate(Object[] references) { this.references = references; ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public boolean eval(InternalRow ${ctx.INPUT_ROW}) { ${eval.code} return !${eval.isNull} && ${eval.value}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b891f94673752..1c98c9ed10705 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -155,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] """ } val allExpressions = 
ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) + val codeBody = s""" public java.lang.Object generate(Object[] references) { return new SpecificSafeProjection(references); @@ -165,7 +166,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private Object[] references; private MutableRow mutableRow; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificSafeProjection(Object[] references) { this.references = references; @@ -173,6 +173,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 75bb6936b49e0..7cc45372daa5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -374,13 +374,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificUnsafeProjection(Object[] references) { this.references = references; ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + // Scala.Function1 need this public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 45dcfcaf23132..5588b4429164c 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.SparkFunSuite import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.Row @@ -24,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ThreadUtils @@ -164,6 +166,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-17702: split wide constructor into blocks due to JVM code size limit") { + val length = 5000 + val expressions = Seq.fill(length) { + ToUTCTimestamp( + Literal.create(Timestamp.valueOf("2015-07-24 00:00:00"), TimestampType), + Literal.create("PST", StringType)) + } + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val expected = Seq.fill(length)( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00"))) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fb57ed7692de4..62bf6f4a81eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -316,14 +316,16 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; + private scala.collection.Iterator[] inputs; ${ctx.declareMutableStates()} public GeneratedIterator(Object[] references) { this.references = references; } - public void init(int index, scala.collection.Iterator inputs[]) { + public void init(int index, scala.collection.Iterator[] inputs) { partitionIndex = index; + this.inputs = inputs; ${ctx.initMutableStates()} } From d2dc8c4a162834818190ffd82894522c524ca3e5 Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Mon, 3 Oct 2016 23:28:39 -0700 Subject: [PATCH 75/96] [SPARK-17773] Input/Output] Add VoidObjectInspector ## What changes were proposed in this pull request? Added VoidObjectInspector to the list of PrimitiveObjectInspectors ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Executing following query was failing. select SOME_UDAF*(a.arr) from ( select Array(null) as arr from dim_one_row ) a After the fix, I am getting the correct output: res0: Array[org.apache.spark.sql.Row] = Array([null]) Author: Ergin Seyfe Closes #15337 from seyfe/add_void_object_inspector. 
--- .../main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 2 ++ .../scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 1 + 2 files changed, 3 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c3c4351cf58a9..fe34caa0a3e48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -319,6 +319,8 @@ private[hive] trait HiveInspectors { withNullSafe(o => getTimestampWritable(o)) case _: TimestampObjectInspector => withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])) + case _: VoidObjectInspector => + (_: Any) => null // always be null for void object inspector } case soi: StandardStructObjectInspector => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index bc51bcb07ec2a..3de1f4aeb74dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -81,6 +81,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val data = Literal(true) :: + Literal(null) :: Literal(0.asInstanceOf[Byte]) :: Literal(0.asInstanceOf[Short]) :: Literal(0) :: From 126baa8d32bc0e7bf8b43f9efa84f2728f02347d Mon Sep 17 00:00:00 2001 From: ding Date: Tue, 4 Oct 2016 00:00:10 -0700 Subject: [PATCH 76/96] [SPARK-17559][MLLIB] persist edges if their storage level is none in PeriodicGraphCheckpointer ## What changes were proposed in this pull request? When PeriodicGraphCheckpointer is used to persist a graph, the edges are sometimes not persisted, because currently the graph is persisted only when the vertices' storage level is none. However, the vertices' storage level may be something other than none while the edges' storage level is none.
E.g., for a graph created by an outerJoinVertices operation, the vertices are automatically cached while the edges are not. In that case, the edges will not be persisted if we use PeriodicGraphCheckpointer to persist the graph. We need to check the edges' storage level separately and persist the edges if it is none. ## How was this patch tested? manual tests Author: ding Closes #15124 from dding3/spark-persisitEdge. --- .../apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 20db6084d0e0d..80074897567eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -87,7 +87,10 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.persist() + data.vertices.persist() + } + if (data.edges.getStorageLevel == StorageLevel.NONE) { + data.edges.persist() } } From 8e8de0073d71bb00baeb24c612d7841b6274f652 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 4 Oct 2016 10:29:22 +0100 Subject: [PATCH 77/96] [SPARK-17671][WEBUI] Spark 2.0 history server summary page is slow even set spark.history.ui.maxApplications ## What changes were proposed in this pull request? Return Iterator of applications internally in history server, for consistency and performance. See https://github.com/apache/spark/pull/15248 for some back-story. The code called by and calling HistoryServer.getApplicationList wants an Iterator, but this method materializes an Iterable, which potentially causes a performance problem. It's simpler too to make this internal method also pass through an Iterator. ## How was this patch tested? Existing tests.
Author: Sean Owen Closes #15321 from srowen/SPARK-17671. --- .../history/ApplicationHistoryProvider.scala | 2 +- .../deploy/history/FsHistoryProvider.scala | 2 +- .../spark/deploy/history/HistoryPage.scala | 5 +-- .../spark/deploy/history/HistoryServer.scala | 4 +- .../api/v1/ApplicationListResource.scala | 38 +++++++------------ .../deploy/history/HistoryServerSuite.scala | 4 +- project/MimaExcludes.scala | 2 + 7 files changed, 22 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index ba42b4862aa90..ad7a0972ef9d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -77,7 +77,7 @@ private[history] abstract class ApplicationHistoryProvider { * * @return List of all know applications. */ - def getListing(): Iterable[ApplicationHistoryInfo] + def getListing(): Iterator[ApplicationHistoryInfo] /** * Returns the Spark UI for a specific application. 
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index c5740e4737094..3c2d169f3270e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -222,7 +222,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values + override def getListing(): Iterator[FsApplicationHistoryInfo] = applications.values.iterator override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { applications.get(appId) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index b4f5a6114f3de..95b72224e0f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -29,10 +29,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val requestedIncomplete = Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean - val allApps = parent.getApplicationList() - .filter(_.completed != requestedIncomplete) - val allAppsSize = allApps.size - + val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) val providerConfig = parent.getProviderConfig() val content =
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 735aa43cfc994..087c69e6489dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -174,12 +174,12 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList(): Iterable[ApplicationHistoryInfo] = { + def getApplicationList(): Iterator[ApplicationHistoryInfo] = { provider.getListing() } def getApplicationInfoList: Iterator[ApplicationInfo] = { - getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } def getApplicationInfo(appId: String): Option[ApplicationInfo] = { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 075b9ba37dc84..76779290d45e6 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import java.util.{Arrays, Date, List => JList} +import java.util.{Date, List => JList} import javax.ws.rs.{DefaultValue, GET, Produces, QueryParam} import javax.ws.rs.core.MediaType @@ -32,33 +32,21 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam, @QueryParam("limit") limit: Integer) : Iterator[ApplicationInfo] = { - val allApps = uiRoot.getApplicationInfoList - val adjStatus = { - if (status.isEmpty) { - Arrays.asList(ApplicationStatus.values(): _*) - } else { - status - } - } - val includeCompleted = 
adjStatus.contains(ApplicationStatus.COMPLETED) - val includeRunning = adjStatus.contains(ApplicationStatus.RUNNING) - val appList = allApps.filter { app => + + val numApps = Option(limit).map(_.toInt).getOrElse(Integer.MAX_VALUE) + val includeCompleted = status.isEmpty || status.contains(ApplicationStatus.COMPLETED) + val includeRunning = status.isEmpty || status.contains(ApplicationStatus.RUNNING) + + uiRoot.getApplicationInfoList.filter { app => val anyRunning = app.attempts.exists(!_.completed) - // if any attempt is still running, we consider the app to also still be running - val statusOk = (!anyRunning && includeCompleted) || - (anyRunning && includeRunning) + // if any attempt is still running, we consider the app to also still be running; // keep the app if *any* attempts fall in the right time window - val dateOk = app.attempts.exists { attempt => - attempt.startTime.getTime >= minDate.timestamp && - attempt.startTime.getTime <= maxDate.timestamp + ((!anyRunning && includeCompleted) || (anyRunning && includeRunning)) && + app.attempts.exists { attempt => + val start = attempt.startTime.getTime + start >= minDate.timestamp && start <= maxDate.timestamp } - statusOk && dateOk - } - if (limit != null) { - appList.take(limit) - } else { - appList - } + }.take(numApps) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index ae3f5d9c012ea..5b316b2f6b4b7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -447,7 +447,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(4 === getNumJobsRestful(), s"two jobs back-to-back not updated, server=$server\n") } val jobcount = getNumJobs("/jobs") - assert(!provider.getListing().head.completed) + assert(!provider.getListing().next.completed) 
listApplications(false) should contain(appId) @@ -455,7 +455,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers resetSparkContext() // check the app is now found as completed eventually(stdTimeout, stdInterval) { - assert(provider.getListing().head.completed, + assert(provider.getListing().next.completed, s"application never completed, server=$server\n") } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7362041428b1f..163e3f2fdea40 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,6 +37,8 @@ object MimaExcludes { // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { Seq( + // [SPARK-17671] Spark 2.0 history server summary page is slow even set spark.history.ui.maxApplications + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.deploy.history.HistoryServer.getApplicationList"), // [SPARK-14743] Improve delegation token handling in secure cluster ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTimeFromNowToRenewal"), // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter From 7d5160883542f3d9dcb3babda92880985398e9af Mon Sep 17 00:00:00 2001 From: sumansomasundar Date: Tue, 4 Oct 2016 10:31:56 +0100 Subject: [PATCH 78/96] [SPARK-16962][CORE][SQL] Fix misaligned record accesses for SPARC architectures ## What changes were proposed in this pull request? Made changes to record length offsets to make them uniform throughout various areas of Spark core and unsafe ## How was this patch tested? This change affects only SPARC architectures and was tested on X86 architectures as well for regression. Author: sumansomasundar Closes #14762 from sumansomasundar/master. 
--- .../spark/unsafe/UnsafeAlignedOffset.java | 58 +++++++++++++++++++ .../spark/unsafe/array/ByteArrayMethods.java | 31 +++++++--- .../spark/unsafe/map/BytesToBytesMap.java | 57 +++++++++--------- .../unsafe/sort/UnsafeExternalSorter.java | 19 +++--- .../unsafe/sort/UnsafeInMemorySorter.java | 14 +++-- .../CompressibleColumnBuilder.scala | 11 +++- 6 files changed, 144 insertions(+), 46 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java new file mode 100644 index 0000000000000..be62e40412f83 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +/** + * Class to make changes to record length offsets uniform through out + * various areas of Apache Spark core and unsafe. The SPARC platform + * requires this because using a 4 byte Int for record lengths causes + * the entire record of 8 byte Items to become misaligned by 4 bytes. 
+ * Using a 8 byte long for record length keeps things 8 byte aligned. + */ +public class UnsafeAlignedOffset { + + private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8; + + public static int getUaoSize() { + return UAO_SIZE; + } + + public static int getSize(Object object, long offset) { + switch (UAO_SIZE) { + case 4: + return Platform.getInt(object, offset); + case 8: + return (int)Platform.getLong(object, offset); + default: + throw new AssertionError("Illegal UAO_SIZE"); + } + } + + public static void putSize(Object object, long offset, int value) { + switch (UAO_SIZE) { + case 4: + Platform.putInt(object, offset, value); + break; + case 8: + Platform.putLong(object, offset, value); + break; + default: + throw new AssertionError("Illegal UAO_SIZE"); + } + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index cf42877bf9fd4..9c551ab19e9aa 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -40,6 +40,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } } + private static final boolean unaligned = Platform.unaligned(); /** * Optimized byte array equality check for byte arrays. 
* @return true if the arrays are equal, false otherwise @@ -47,17 +48,33 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - while (i <= length - 8) { - if (Platform.getLong(leftBase, leftOffset + i) != - Platform.getLong(rightBase, rightOffset + i)) { - return false; + + // check if stars align and we can get both offsets to be aligned + if ((leftOffset % 8) == (rightOffset % 8)) { + while ((leftOffset + i) % 8 != 0 && i < length) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { + return false; + } + i += 1; + } + } + // for architectures that suport unaligned accesses, chew it up 8 bytes at a time + if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { + while (i <= length - 8) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { + return false; + } + i += 8; } - i += 8; } + // this will finish off the unaligned comparisons, or do the entire aligned + // comparison whichever is needed. 
while (i < length) { if (Platform.getByte(leftBase, leftOffset + i) != - Platform.getByte(rightBase, rightOffset + i)) { - return false; + Platform.getByte(rightBase, rightOffset + i)) { + return false; } i += 1; } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index e4289818f1e75..d2fcdea4f2cee 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -35,6 +35,7 @@ import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -273,8 +274,8 @@ private void advanceToNextPage() { currentPage = dataPages.get(nextIdx); pageBaseObject = currentPage.getBaseObject(); offsetInPage = currentPage.getBaseOffset(); - recordsInPage = Platform.getInt(pageBaseObject, offsetInPage); - offsetInPage += 4; + recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); + offsetInPage += UnsafeAlignedOffset.getUaoSize(); } else { currentPage = null; if (reader != null) { @@ -321,10 +322,10 @@ public Location next() { } numRecords--; if (currentPage != null) { - int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + int totalLength = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); loc.with(currentPage, offsetInPage); // [total size] [key size] [key] [value] [pointer to next] - offsetInPage += 4 + totalLength + 8; + offsetInPage += UnsafeAlignedOffset.getUaoSize() + totalLength + 8; recordsInPage --; return loc; } else { @@ -367,14 +368,15 @@ public long spill(long numBytes) throws IOException { Object base = block.getBaseObject(); long offset = 
block.getBaseOffset(); - int numRecords = Platform.getInt(base, offset); - offset += 4; + int numRecords = UnsafeAlignedOffset.getSize(base, offset); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + offset += uaoSize; final UnsafeSorterSpillWriter writer = new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); while (numRecords > 0) { - int length = Platform.getInt(base, offset); - writer.write(base, offset + 4, length, 0); - offset += 4 + length + 8; + int length = UnsafeAlignedOffset.getSize(base, offset); + writer.write(base, offset + uaoSize, length, 0); + offset += uaoSize + length + 8; numRecords--; } writer.close(); @@ -530,13 +532,14 @@ private void updateAddressesAndSizes(long fullKeyAddress) { private void updateAddressesAndSizes(final Object base, long offset) { baseObject = base; - final int totalLength = Platform.getInt(base, offset); - offset += 4; - keyLength = Platform.getInt(base, offset); - offset += 4; + final int totalLength = UnsafeAlignedOffset.getSize(base, offset); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + offset += uaoSize; + keyLength = UnsafeAlignedOffset.getSize(base, offset); + offset += uaoSize; keyOffset = offset; valueOffset = offset + keyLength; - valueLength = totalLength - keyLength - 4; + valueLength = totalLength - keyLength - uaoSize; } private Location with(int pos, int keyHashcode, boolean isDefined) { @@ -565,10 +568,11 @@ private Location with(Object base, long offset, int length) { this.isDefined = true; this.memoryPage = null; baseObject = base; - keyOffset = offset + 4; - keyLength = Platform.getInt(base, offset); - valueOffset = offset + 4 + keyLength; - valueLength = length - 4 - keyLength; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + keyOffset = offset + uaoSize; + keyLength = UnsafeAlignedOffset.getSize(base, offset); + valueOffset = offset + uaoSize + keyLength; + valueLength = length - uaoSize - keyLength; return this; } @@ -699,9 +703,10 @@ public boolean append(Object 
kbase, long koff, int klen, Object vbase, long voff // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) - final long recordLength = 8 + klen + vlen + 8; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final long recordLength = (2 * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { - if (!acquireNewPage(recordLength + 4L)) { + if (!acquireNewPage(recordLength + uaoSize)) { return false; } } @@ -710,9 +715,9 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff final Object base = currentPage.getBaseObject(); long offset = currentPage.getBaseOffset() + pageCursor; final long recordOffset = offset; - Platform.putInt(base, offset, klen + vlen + 4); - Platform.putInt(base, offset + 4, klen); - offset += 8; + UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize); + UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen); + offset += (2 * uaoSize); Platform.copyMemory(kbase, koff, base, offset, klen); offset += klen; Platform.copyMemory(vbase, voff, base, offset, vlen); @@ -722,7 +727,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // --- Update bookkeeping data structures ---------------------------------------------------- offset = currentPage.getBaseOffset(); - Platform.putInt(base, offset, Platform.getInt(base, offset) + 1); + UnsafeAlignedOffset.putSize(base, offset, UnsafeAlignedOffset.getSize(base, offset) + 1); pageCursor += recordLength; final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( currentPage, recordOffset); @@ -757,8 +762,8 @@ private boolean acquireNewPage(long required) { return false; } dataPages.add(currentPage); - Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); - pageCursor = 4; + 
UnsafeAlignedOffset.putSize(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); + pageCursor = UnsafeAlignedOffset.getUaoSize(); return true; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8ca29a58f8f64..428ff72e71a43 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.TaskCompletionListener; @@ -392,14 +393,15 @@ public void insertRecord( } growPointerArrayIfNecessary(); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 bytes to store the record length. 
- final int required = length + 4; + final int required = length + uaoSize; acquireNewPageIfNecessary(required); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, length); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); @@ -418,15 +420,16 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, throws IOException { growPointerArrayIfNecessary(); - final int required = keyLen + valueLen + 4 + 4; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final int required = keyLen + valueLen + (2 * uaoSize); acquireNewPageIfNecessary(required); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, keyLen + valueLen + 4); - pageCursor += 4; - Platform.putInt(base, pageCursor, keyLen); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, keyLen + valueLen + uaoSize); + pageCursor += uaoSize; + UnsafeAlignedOffset.putSize(base, pageCursor, keyLen); + pageCursor += uaoSize; Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen); pageCursor += keyLen; Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 8ecd20910ab73..2a71e68adafad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -25,6 +25,7 @@ import 
org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -56,11 +57,14 @@ private static final class SortComparator implements Comparator Date: Tue, 4 Oct 2016 06:54:48 -0700 Subject: [PATCH 79/96] [SPARK-17744][ML] Parity check between the ml and mllib test suites for NB ## What changes were proposed in this pull request? 1,parity check and add missing test suites for ml's NB 2,remove some unused imports ## How was this patch tested? manual tests in spark-shell Author: Zheng RuiFeng Closes #15312 from zhengruifeng/nb_test_parity. --- .../spark/ml/feature/LabeledPoint.scala | 2 +- .../ml/feature/QuantileDiscretizer.scala | 2 +- .../org/apache/spark/ml/python/MLSerDe.scala | 5 -- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/LinearRegression.scala | 1 - .../ml/classification/NaiveBayesSuite.scala | 69 ++++++++++++++++++- python/pyspark/ml/classification.py | 1 - 7 files changed, 70 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala index 6cefa7086c881..7d8e4adcc2259 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.Vector /** * :: Experimental :: * - * Class that represents the features and labels of a data point. + * Class that represents the features and label of a data point. * * @param label Label for this data point. * @param features List of features for this data point. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1e59d71a70955..05e034d90f6a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.StructType /** * Params for [[QuantileDiscretizer]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala index 4b805e145482a..da62f8518e363 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala @@ -19,17 +19,12 @@ package org.apache.spark.ml.python import java.io.OutputStream import java.nio.{ByteBuffer, ByteOrder} -import java.util.{ArrayList => JArrayList} - -import scala.collection.JavaConverters._ import net.razorvine.pickle._ -import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.ml.linalg._ import org.apache.spark.mllib.api.python.SerDeBase -import org.apache.spark.rdd.RDD /** * SerDe utility functions for pyspark.ml. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index ce355938ec1c7..bb01f9d5a364c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -21,7 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.feature.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7fddfd9b10f84..536c58f998080 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -37,7 +37,6 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 597428d036c7a..e934e5ea42b16 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -22,10 +22,10 @@ import scala.util.Random import breeze.linalg.{DenseVector => 
BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.ml.classification.NaiveBayesSuite._ -import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -106,6 +106,11 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } + test("model types") { + assert(Multinomial === "multinomial") + assert(Bernoulli === "bernoulli") + } + test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -228,6 +233,66 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateProbabilities(featureAndProbabilities, model, "bernoulli") } + test("detect negative values") { + val dense = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(-1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0)))) + intercept[SparkException] { + new NaiveBayes().fit(dense) + } + val sparse = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))) + intercept[SparkException] { + new NaiveBayes().fit(sparse) + } + val nan = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + 
LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))) + intercept[SparkException] { + new NaiveBayes().fit(nan) + } + } + + test("detect non zero or one values in Bernoulli") { + val badTrain = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0)))) + + intercept[SparkException] { + new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(badTrain) + } + + val okTrain = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)))) + + val model = new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(okTrain) + + val badPredict = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0)))) + + intercept[SparkException] { + model.transform(badPredict).collect() + } + } + test("read/write") { def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { assert(model.pi === model2.pi) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 505e7bffd1763..ea60fab029582 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,7 +16,6 @@ # import operator -import warnings from pyspark import since, keyword_only from pyspark.ml import Estimator, Model From 068c198e956346b90968a4d74edb7bc820c4be28 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 4 Oct 2016 09:22:26 -0700 Subject: [PATCH 80/96] [SPARKR][DOC] minor formatting and output cleanup for R vignettes ## What changes were proposed in this pull request? 
Clean up output, format table, truncate long example output, hide warnings (new - Left; existing - Right) ![image](https://cloud.githubusercontent.com/assets/8969467/19064018/5dcde4d0-89bc-11e6-857b-052df3f52a4e.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064034/6db09956-89bc-11e6-8e43-232d5c3fe5e6.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064058/88f09590-89bc-11e6-9993-61639e29dfdd.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064066/95ccbf64-89bc-11e6-877f-45af03ddcadc.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064082/a8445404-89bc-11e6-8532-26d8bc9b206f.png) ## How was this patch tested? Run create-doc.sh manually Author: Felix Cheung Closes #15340 from felixcheung/vignettes. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 31 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index aea52db8b8556..80e876027bddb 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -26,7 +26,7 @@ library(SparkR) We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). 
-```{r, message=FALSE} +```{r, message=FALSE, results="hide"} sparkR.session() ``` @@ -114,10 +114,12 @@ In particular, the following Spark driver properties can be set in `sparkConfig` Property Name | Property group | spark-submit equivalent ---------------- | ------------------ | ---------------------- -spark.driver.memory | Application Properties | --driver-memory -spark.driver.extraClassPath | Runtime Environment | --driver-class-path -spark.driver.extraJavaOptions | Runtime Environment | --driver-java-options -spark.driver.extraLibraryPath | Runtime Environment | --driver-library-path +`spark.driver.memory` | Application Properties | `--driver-memory` +`spark.driver.extraClassPath` | Runtime Environment | `--driver-class-path` +`spark.driver.extraJavaOptions` | Runtime Environment | `--driver-java-options` +`spark.driver.extraLibraryPath` | Runtime Environment | `--driver-library-path` +`spark.yarn.keytab` | Application Properties | `--keytab` +`spark.yarn.principal` | Application Properties | `--principal` **For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. @@ -161,7 +163,7 @@ head(df) ### Data Sources SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. 
These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session'.` +The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. ```{r, eval=FALSE} sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") @@ -406,10 +408,17 @@ class(model.summaries) ``` -To avoid lengthy display, we only present the result of the second fitted model. You are free to inspect other models as well. +To avoid lengthy display, we only present the partial result of the second fitted model. You are free to inspect other models as well. +```{r, include=FALSE} +ops <- options() +options(max.print=40) +``` ```{r} print(model.summaries[[2]]) ``` +```{r, include=FALSE} +options(ops) +``` ### SQL Queries @@ -544,7 +553,7 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. 
Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. -```{r} +```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) @@ -678,7 +687,7 @@ MLPC employs backpropagation for learning the model. We use the logistic loss fu * `tol`: convergence tolerance of iterations. -* `stepSize`: step size for `"gd"`. +* `stepSize`: step size for `"gd"`. * `seed`: seed parameter for weights initialization. @@ -763,8 +772,8 @@ We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in t ### Model Persistence The following example shows how to save/load an ML model by SparkR. -```{r} -irisDF <- suppressWarnings(createDataFrame(iris)) +```{r, warning=FALSE} +irisDF <- createDataFrame(iris) gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") # Save and then load a fitted MLlib model From 8d969a2125d915da1506c17833aa98da614a257f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 4 Oct 2016 09:38:44 -0700 Subject: [PATCH 81/96] [SPARK-17549][SQL] Only collect table size stat in driver for cached relation. This reverts commit 9ac68dbc5720026ea92acc61d295ca64d0d3d132. Turns out the original fix was correct. Original change description: The existing code caches all stats for all columns for each partition in the driver; for a large relation, this causes extreme memory usage, which leads to gc hell and application failures. It seems that only the size in bytes of the data is actually used in the driver, so instead just colllect that. In executors, the full stats are still kept, but that's not a big problem; we expect the data to be distributed and thus not really incur in too much memory pressure in each individual executor. 
There are also potential improvements on the executor side, since the data being stored currently is very wasteful (e.g. storing boxed types vs. primitive types for stats). But that's a separate issue. Author: Marcelo Vanzin Closes #15304 from vanzin/SPARK-17549.2. --- .../execution/columnar/InMemoryRelation.scala | 24 +++++-------------- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +++++++++++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 479934a7afc75..56bd5c1891e8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.columnar -import scala.collection.JavaConverters._ - import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils @@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.CollectionAccumulator +import org.apache.spark.util.LongAccumulator object InMemoryRelation { @@ -63,8 +61,7 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: CollectionAccumulator[InternalRow] = - child.sqlContext.sparkContext.collectionAccumulator[InternalRow]) + val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) @@ -74,21 +71,12 @@ case class InMemoryRelation( @transient val partitionStatistics 
= new PartitionStatistics(output) override lazy val statistics: Statistics = { - if (batchStats.value.isEmpty) { + if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) } else { - // Underlying columnar RDD has been materialized, required information has also been - // collected via the `batchStats` accumulator. - val sizeOfRow: Expression = - BindReferences.bindReference( - output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add), - partitionStatistics.schema) - - val sizeInBytes = - batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum - Statistics(sizeInBytes = sizeInBytes) + Statistics(sizeInBytes = batchStats.value.longValue) } } @@ -139,10 +127,10 @@ case class InMemoryRelation( rowCount += 1 } + batchStats.add(totalSize) + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) .flatMap(_.values)) - - batchStats.add(stats) CachedBatch(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 937839644ad5f..0daa29b666f62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -232,4 +232,18 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val columnTypes2 = List.fill(length2)(IntegerType) val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) } + + test("SPARK-17549: cached table size should be correctly calculated") { + val data = spark.sparkContext.parallelize(1 to 
10, 5).toDF() + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) + + // Materialize the data. + val expectedAnswer = data.collect() + checkAnswer(cached, expectedAnswer) + + // Check that the right size was calculated. + assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + } + } From a99743d053e84f695dc3034550939555297b0a05 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Tue, 4 Oct 2016 18:59:31 -0700 Subject: [PATCH 82/96] [SPARK-17495][SQL] Add Hash capability semantically equivalent to Hive's ## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-17495 Spark internally uses Murmur3Hash for partitioning. This is different from the one used by Hive. For queries which use bucketing this leads to different results if one tries the same query on both engines. For us, we want users to have backward compatibility to that one can switch parts of applications across the engines without observing regressions. This PR includes `HiveHash`, `HiveHashFunction`, `HiveHasher` which mimics Hive's hashing at https://github.com/apache/hive/blob/master/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java#L638 I am intentionally not introducing any usages of this hash function in rest of the code to keep this PR small. My eventual goal is to have Hive bucketing support in Spark. Once this PR gets in, I will make hash function pluggable in relevant areas (eg. `HashPartitioning`'s `partitionIdExpression` has Murmur3 hardcoded : https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala#L265) ## How was this patch tested? Added `HiveHashSuite` Author: Tejas Patil Closes #15047 from tejasapatil/SPARK-17495_hive_hash. 
--- .../sql/catalyst/expressions/HiveHasher.java | 49 +++ .../spark/sql/catalyst/expressions/misc.scala | 391 +++++++++++++++--- .../catalyst/expressions/HiveHasherSuite.java | 128 ++++++ .../org/apache/spark/sql/HashBenchmark.scala | 93 +++-- .../spark/sql/HashByteArrayBenchmark.scala | 118 +++--- .../expressions/MiscFunctionsSuite.scala | 3 +- 6 files changed, 631 insertions(+), 151 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java new file mode 100644 index 0000000000000..c7ea9085eba66 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; + +/** + * Simulates Hive's hashing function at + * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() + */ +public class HiveHasher { + + @Override + public String toString() { + return HiveHasher.class.getSimpleName(); + } + + public static int hashInt(int input) { + return input; + } + + public static int hashLong(long input) { + return (int) ((input >>> 32) ^ input); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int result = 0; + for (int i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) Platform.getByte(base, offset + i); + } + return result; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index dbb52a4bb18de..138ef2a1dcc01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -259,7 +259,7 @@ abstract class HashExpression[E] extends Expression { $childrenHash""") } - private def nullSafeElementHash( + protected def nullSafeElementHash( input: String, index: String, nullable: Boolean, @@ -276,76 +276,127 @@ abstract class HashExpression[E] extends Expression { } } - @tailrec - private def computeHash( + protected def genHashInt(i: String, result: String): String = + s"$result = $hasherClassName.hashInt($i, $result);" + + protected def genHashLong(l: String, result: String): String = + s"$result = $hasherClassName.hashLong($l, $result);" + + protected def genHashBytes(b: String, result: String): String = { + val offset = "Platform.BYTE_ARRAY_OFFSET" + s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);" + } + + protected 
def genHashBoolean(input: String, result: String): String = + genHashInt(s"$input ? 1 : 0", result) + + protected def genHashFloat(input: String, result: String): String = + genHashInt(s"Float.floatToIntBits($input)", result) + + protected def genHashDouble(input: String, result: String): String = + genHashLong(s"Double.doubleToLongBits($input)", result) + + protected def genHashDecimal( + ctx: CodegenContext, + d: DecimalType, input: String, - dataType: DataType, - result: String, - ctx: CodegenContext): String = { - val hasher = hasherClassName - - def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" - def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" - def hashBytes(b: String): String = - s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" - - dataType match { - case NullType => "" - case BooleanType => hashInt(s"$input ? 1 : 0") - case ByteType | ShortType | IntegerType | DateType => hashInt(input) - case LongType | TimestampType => hashLong(input) - case FloatType => hashInt(s"Float.floatToIntBits($input)") - case DoubleType => hashLong(s"Double.doubleToLongBits($input)") - case d: DecimalType => - if (d.precision <= Decimal.MAX_LONG_DIGITS) { - hashLong(s"$input.toUnscaledLong()") - } else { - val bytes = ctx.freshName("bytes") - s""" + result: String): String = { + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + genHashLong(s"$input.toUnscaledLong()", result) + } else { + val bytes = ctx.freshName("bytes") + s""" final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); - ${hashBytes(bytes)} + ${genHashBytes(bytes, result)} """ + } + } + + protected def genHashCalendarInterval(input: String, result: String): String = { + val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)" + s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);" + } + + protected def genHashString(input: String, result: String): String = 
{ + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + } + + protected def genHashForMap( + ctx: CodegenContext, + input: String, + result: String, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): String = { + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, keyType, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, valueType, result, ctx)} } - case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" - s"$result = $hasher.hashInt($input.months, $microsecondsHash);" - case BinaryType => hashBytes(input) - case StringType => - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - - case ArrayType(et, containsNull) => - val index = ctx.freshName("index") - s""" - for (int $index = 0; $index < $input.numElements(); $index++) { - ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} - } - """ - - case MapType(kt, vt, valueContainsNull) => - val index = ctx.freshName("index") - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") - s""" - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - ${nullSafeElementHash(keys, index, false, kt, result, ctx)} - ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} - } - """ + """ + } + + 
protected def genHashForArray( + ctx: CodegenContext, + input: String, + result: String, + elementType: DataType, + containsNull: Boolean): String = { + val index = ctx.freshName("index") + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)} + } + """ + } - case StructType(fields) => - fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) - }.mkString("\n") + protected def genHashForStruct( + ctx: CodegenContext, + input: String, + result: String, + fields: Array[StructField]): String = { + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + }.mkString("\n") + } - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) - } + @tailrec + private def computeHashWithTailRec( + input: String, + dataType: DataType, + result: String, + ctx: CodegenContext): String = dataType match { + case NullType => "" + case BooleanType => genHashBoolean(input, result) + case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result) + case LongType | TimestampType => genHashLong(input, result) + case FloatType => genHashFloat(input, result) + case DoubleType => genHashDouble(input, result) + case d: DecimalType => genHashDecimal(ctx, d, input, result) + case CalendarIntervalType => genHashCalendarInterval(input, result) + case BinaryType => genHashBytes(input, result) + case StringType => genHashString(input, result) + case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull) + case MapType(kt, vt, valueContainsNull) => + genHashForMap(ctx, input, result, kt, vt, valueContainsNull) + case StructType(fields) => genHashForStruct(ctx, input, result, fields) + case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx) } + 
protected def computeHash( + input: String, + dataType: DataType, + result: String, + ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx) + protected def hasherClassName: String } @@ -568,3 +619,217 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def foldable: Boolean = true override def nullable: Boolean = false } + +/** + * Simulates Hive's hashing function at + * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive + * + * We should use this hash function for both shuffle and bucket of Hive tables, so that + * we can guarantee shuffle and bucketing have same data distribution + * + * TODO: Support Decimal and date related types + */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.") +case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { + override val seed = 0 + + override def dataType: DataType = IntegerType + + override def prettyName: String = "hive-hash" + + override protected def hasherClassName: String = classOf[HiveHasher].getName + + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { + HiveHashFunction.hash(value, dataType, seed).toInt + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.isNull = "false" + val childHash = ctx.freshName("childHash") + val childrenHash = children.map { child => + val childGen = child.genCode(ctx) + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, childHash, ctx) + } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" + }.mkString(s"int $childHash = 0;", s"\n$childHash = 0;\n", "") + + ev.copy(code = s""" + ${ctx.javaType(dataType)} ${ev.value} = $seed; + $childrenHash""") + } + + override def eval(input: InternalRow): Int = { + var hash = seed + var i = 0 + val len = children.length + while (i < len) { + hash = (31 * hash) + 
computeHash(children(i).eval(input), children(i).dataType, hash) + i += 1 + } + hash + } + + override protected def genHashInt(i: String, result: String): String = + s"$result = $hasherClassName.hashInt($i);" + + override protected def genHashLong(l: String, result: String): String = + s"$result = $hasherClassName.hashLong($l);" + + override protected def genHashBytes(b: String, result: String): String = + s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + + override protected def genHashCalendarInterval(input: String, result: String): String = { + s""" + $result = (31 * $hasherClassName.hashInt($input.months)) + + $hasherClassName.hashLong($input.microseconds);" + """ + } + + override protected def genHashString(input: String, result: String): String = { + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + } + + override protected def genHashForArray( + ctx: CodegenContext, + input: String, + result: String, + elementType: DataType, + containsNull: Boolean): String = { + val index = ctx.freshName("index") + val childResult = ctx.freshName("childResult") + s""" + int $childResult = 0; + for (int $index = 0; $index < $input.numElements(); $index++) { + $childResult = 0; + ${nullSafeElementHash(input, index, containsNull, elementType, childResult, ctx)}; + $result = (31 * $result) + $childResult; + } + """ + } + + override protected def genHashForMap( + ctx: CodegenContext, + input: String, + result: String, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): String = { + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val keyResult = ctx.freshName("keyResult") + val valueResult = ctx.freshName("valueResult") + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData 
$values = $input.valueArray(); + int $keyResult = 0; + int $valueResult = 0; + for (int $index = 0; $index < $input.numElements(); $index++) { + $keyResult = 0; + ${nullSafeElementHash(keys, index, false, keyType, keyResult, ctx)} + $valueResult = 0; + ${nullSafeElementHash(values, index, valueContainsNull, valueType, valueResult, ctx)} + $result += $keyResult ^ $valueResult; + } + """ + } + + override protected def genHashForStruct( + ctx: CodegenContext, + input: String, + result: String, + fields: Array[StructField]): String = { + val localResult = ctx.freshName("localResult") + val childResult = ctx.freshName("childResult") + fields.zipWithIndex.map { case (field, index) => + s""" + $childResult = 0; + ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType, + childResult, ctx)} + $localResult = (31 * $localResult) + $childResult; + """ + }.mkString( + s""" + int $localResult = 0; + int $childResult = 0; + """, + "", + s"$result = (31 * $result) + $localResult;" + ) + } +} + +object HiveHashFunction extends InterpretedHashFunction { + override protected def hashInt(i: Int, seed: Long): Long = { + HiveHasher.hashInt(i) + } + + override protected def hashLong(l: Long, seed: Long): Long = { + HiveHasher.hashLong(l) + } + + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + HiveHasher.hashUnsafeBytes(base, offset, len) + } + + override def hash(value: Any, dataType: DataType, seed: Long): Long = { + value match { + case null => 0 + case array: ArrayData => + val elementType = dataType match { + case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType + case ArrayType(et, _) => et + } + + var result = 0 + var i = 0 + val length = array.numElements() + while (i < length) { + result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt + i += 1 + } + result + + case map: MapData => + val (kt, vt) = dataType match { + case udt: UserDefinedType[_] => + val 
mapType = udt.sqlType.asInstanceOf[MapType] + mapType.keyType -> mapType.valueType + case MapType(_kt, _vt, _) => _kt -> _vt + } + val keys = map.keyArray() + val values = map.valueArray() + + var result = 0 + var i = 0 + val length = map.numElements() + while (i < length) { + result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt + i += 1 + } + result + + case struct: InternalRow => + val types: Array[DataType] = dataType match { + case udt: UserDefinedType[_] => + udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray + case StructType(fields) => fields.map(_.dataType) + } + + var result = 0 + var i = 0 + val length = struct.numFields + while (i < length) { + result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt + i += 1 + } + result + + case _ => super.hash(value, dataType, seed) + } + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java new file mode 100644 index 0000000000000..67a5eb0c7fe8f --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +public class HiveHasherSuite { + private final static HiveHasher hasher = new HiveHasher(); + + @Test + public void testKnownIntegerInputs() { + int[] inputs = {0, Integer.MIN_VALUE, Integer.MAX_VALUE, 593689054, -189366624}; + for (int input : inputs) { + Assert.assertEquals(input, HiveHasher.hashInt(input)); + } + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(0, HiveHasher.hashLong(0L)); + Assert.assertEquals(41, HiveHasher.hashLong(-42L)); + Assert.assertEquals(42, HiveHasher.hashLong(42L)); + Assert.assertEquals(-2147483648, HiveHasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2147483648, HiveHasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void testKnownStringAndIntInputs() { + int[] inputs = {84, 19, 8}; + int[] expected = {-823832826, -823835053, 111972242}; + + for (int i = 0; i < inputs.length; i++) { + UTF8String s = UTF8String.fromString("val_" + inputs[i]); + int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); + } + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. 
+ Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(HiveHasher.hashInt(vint), HiveHasher.hashInt(vint)); + Assert.assertEquals(HiveHasher.hashLong(lint), HiveHasher.hashLong(lint)); + + hashcodes.add(HiveHasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(HiveHasher.hashUnsafeBytes( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(HiveHasher.hashUnsafeBytes( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. 
+ Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index c6a1a2be0d071..2d94b66a1e122 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -42,8 +42,8 @@ object HashBenchmark { val benchmark = new Benchmark("Hash For " + name, iters * numRows) benchmark.addCase("interpreted version") { _: Int => + var sum = 0 for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numRows) { sum += rows(i).hashCode() @@ -54,8 +54,8 @@ object HashBenchmark { val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) benchmark.addCase("codegen version") { _: Int => + var sum = 0 for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numRows) { sum += getHashCode(rows(i)).getInt(0) @@ -66,8 +66,8 @@ object HashBenchmark { val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) benchmark.addCase("codegen version 64-bit") { _: Int => + var sum = 0 for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numRows) { sum += getHashCode64b(rows(i)).getInt(0) @@ -76,30 +76,44 @@ object HashBenchmark { } } + val getHiveHashCode = UnsafeProjection.create(new HiveHash(attrs) :: Nil, attrs) + benchmark.addCase("codegen HiveHash version") { _: Int => + var sum = 0 + for (_ <- 0L until iters) { + var i = 0 + while (i < numRows) { + sum += getHiveHashCode(rows(i)).getInt(0) + i += 1 + } + } + } + benchmark.run() } def main(args: Array[String]): Unit = { val singleInt = new StructType().add("i", IntegerType) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1006 / 1011 133.4 7.5 1.0X - 
codegen version 1835 / 1839 73.1 13.7 0.5X - codegen version 64-bit 1627 / 1628 82.5 12.1 0.6X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 3262 / 3267 164.6 6.1 1.0X + codegen version 6448 / 6718 83.3 12.0 0.5X + codegen version 64-bit 6088 / 6154 88.2 11.3 0.5X + codegen HiveHash version 4732 / 4745 113.5 8.8 0.7X + */ test("single ints", singleInt, 1 << 15, 1 << 14) val singleLong = new StructType().add("i", LongType) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1196 / 1209 112.2 8.9 1.0X - codegen version 2178 / 2181 61.6 16.2 0.5X - codegen version 64-bit 1752 / 1753 76.6 13.1 0.7X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 3716 / 3726 144.5 6.9 1.0X + codegen version 7706 / 7732 69.7 14.4 0.5X + codegen version 64-bit 6370 / 6399 84.3 11.9 0.6X + codegen HiveHash version 4924 / 5026 109.0 9.2 0.8X + */ test("single longs", singleLong, 1 << 15, 1 << 14) val normal = new StructType() @@ -118,13 +132,14 @@ object HashBenchmark { .add("date", DateType) .add("timestamp", TimestampType) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 2713 / 2715 0.8 1293.5 1.0X - codegen version 2015 / 2018 1.0 960.9 1.3X - codegen version 64-bit 735 / 738 2.9 350.7 3.7X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash 
For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 2985 / 3013 0.7 1423.4 1.0X + codegen version 2422 / 2434 0.9 1155.1 1.2X + codegen version 64-bit 856 / 920 2.5 408.0 3.5X + codegen HiveHash version 4501 / 4979 0.5 2146.4 0.7X + */ test("normal", normal, 1 << 10, 1 << 11) val arrayOfInt = ArrayType(IntegerType) @@ -132,13 +147,14 @@ object HashBenchmark { .add("array", arrayOfInt) .add("arrayOfArray", ArrayType(arrayOfInt)) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1498 / 1499 0.1 11432.1 1.0X - codegen version 2642 / 2643 0.0 20158.4 0.6X - codegen version 64-bit 2421 / 2424 0.1 18472.5 0.6X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 3100 / 3555 0.0 23651.8 1.0X + codegen version 5779 / 5865 0.0 44088.4 0.5X + codegen version 64-bit 4738 / 4821 0.0 36151.7 0.7X + codegen HiveHash version 2200 / 2246 0.1 16785.9 1.4X + */ test("array", array, 1 << 8, 1 << 9) val mapOfInt = MapType(IntegerType, IntegerType) @@ -146,13 +162,14 @@ object HashBenchmark { .add("map", mapOfInt) .add("mapOfMap", MapType(IntegerType, mapOfInt)) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1612 / 1618 0.0 393553.4 1.0X - codegen version 149 / 150 0.0 36381.2 10.8X - codegen version 64-bit 144 / 145 0.0 35122.1 11.2X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For map: Best/Avg Time(ms) Rate(M/s) 
Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 0 / 0 48.1 20.8 1.0X + codegen version 257 / 275 0.0 62768.7 0.0X + codegen version 64-bit 226 / 240 0.0 55224.5 0.0X + codegen HiveHash version 89 / 96 0.0 21708.8 0.0X + */ test("map", map, 1 << 6, 1 << 6) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 53f21a8442429..2a753a0c84ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.util.Random -import org.apache.spark.sql.catalyst.expressions.XXH64 +import org.apache.spark.sql.catalyst.expressions.{HiveHasher, XXH64} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.util.Benchmark @@ -38,8 +38,8 @@ object HashByteArrayBenchmark { val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) benchmark.addCase("Murmur3_x86_32") { _: Int => + var sum = 0L for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numArrays) { sum += Murmur3_x86_32.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) @@ -49,8 +49,8 @@ object HashByteArrayBenchmark { } benchmark.addCase("xxHash 64-bit") { _: Int => + var sum = 0L for (_ <- 0L until iters) { - var sum = 0L var i = 0 while (i < numArrays) { sum += XXH64.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) @@ -59,90 +59,110 @@ object HashByteArrayBenchmark { } } + benchmark.addCase("HiveHasher") { _: Int => + var sum = 0L + for (_ <- 0L until iters) { + var i = 0 + while (i < numArrays) { + sum += HiveHasher.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length) + i += 1 + } + } + } + 
benchmark.run() } def main(args: Array[String]): Unit = { /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 11 / 12 185.1 5.4 1.0X - xxHash 64-bit 17 / 18 120.0 8.3 0.6X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 12 / 16 174.3 5.7 1.0X + xxHash 64-bit 17 / 22 120.0 8.3 0.7X + HiveHasher 13 / 15 162.1 6.2 0.9X */ test(8, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 18 / 18 118.6 8.4 1.0X - xxHash 64-bit 20 / 21 102.5 9.8 0.9X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 19 / 22 107.6 9.3 1.0X + xxHash 64-bit 20 / 24 104.6 9.6 1.0X + HiveHasher 24 / 28 87.0 11.5 0.8X */ test(16, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 24 / 24 86.6 11.5 1.0X - xxHash 64-bit 23 / 23 93.2 10.7 1.1X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 28 / 32 74.8 13.4 1.0X + xxHash 
64-bit 24 / 29 87.3 11.5 1.2X + HiveHasher 36 / 41 57.7 17.3 0.8X */ test(24, 42L, 1 << 10, 1 << 11) // Add 31 to all arrays to create worse case alignment for xxHash. /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 38 / 39 54.7 18.3 1.0X - xxHash 64-bit 33 / 33 64.4 15.5 1.2X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 41 / 45 51.1 19.6 1.0X + xxHash 64-bit 36 / 44 58.8 17.0 1.2X + HiveHasher 49 / 54 42.6 23.5 0.8X */ test(31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 91 / 94 22.9 43.6 1.0X - xxHash 64-bit 68 / 69 30.6 32.7 1.3X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 100 / 110 21.0 47.7 1.0X + xxHash 64-bit 74 / 78 28.2 35.5 1.3X + HiveHasher 189 / 196 11.1 90.3 0.5X */ test(64 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 268 / 268 7.8 127.6 1.0X - xxHash 64-bit 108 / 109 19.4 51.6 2.5X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + 
------------------------------------------------------------------------------------------------ + Murmur3_x86_32 299 / 311 7.0 142.4 1.0X + xxHash 64-bit 113 / 122 18.5 54.1 2.6X + HiveHasher 620 / 624 3.4 295.5 0.5X */ test(256 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 942 / 945 2.2 449.4 1.0X - xxHash 64-bit 276 / 276 7.6 131.4 3.4X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 1068 / 1070 2.0 509.1 1.0X + xxHash 64-bit 306 / 315 6.9 145.9 3.5X + HiveHasher 2316 / 2369 0.9 1104.3 0.5X */ test(1024 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 1839 / 1843 1.1 876.8 1.0X - xxHash 64-bit 445 / 448 4.7 212.1 4.1X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 2252 / 2274 0.9 1074.1 1.0X + xxHash 64-bit 534 / 580 3.9 254.6 4.2X + HiveHasher 4739 / 4786 0.4 2259.8 0.5X */ test(2048 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 7307 / 7310 0.3 3484.4 1.0X - xxHash 64-bit 1487 / 1488 1.4 709.1 4.9X - */ + 
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 9249 / 9586 0.2 4410.5 1.0X + xxHash 64-bit 2897 / 3241 0.7 1381.6 3.2X + HiveHasher 19392 / 20211 0.1 9246.6 0.5X + */ test(8192 + 31, 42L, 1 << 10, 1 << 11) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 33916c0891866..13ce588462028 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -145,7 +145,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get val encoder = RowEncoder(inputSchema) val seed = scala.util.Random.nextInt() - test(s"murmur3/xxHash64 hash: ${inputSchema.simpleString}") { + test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") { for (_ <- 1 to 10) { val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { @@ -154,6 +154,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Only test the interpreted version has same result with codegen version. 
checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval()) + checkEvaluation(HiveHash(literals), HiveHash(literals).eval()) } } } From c9fe10d4ed8df5ac4bd0f1eb8c9cd19244e27736 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Oct 2016 22:58:43 -0700 Subject: [PATCH 83/96] [SPARK-17658][SPARKR] read.df/write.df API taking path optionally in SparkR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `write.df`/`read.df` API require path which is not actually always necessary in Spark. Currently, it only affects the datasources implementing `CreatableRelationProvider`. Currently, Spark currently does not have internal data sources implementing this but it'd affect other external datasources. In addition we'd be able to use this way in Spark's JDBC datasource after https://github.com/apache/spark/pull/12601 is merged. **Before** - `read.df` ```r > read.df(source = "json") Error in dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", : argument "x" is missing with no default ``` ```r > read.df(path = c(1, 2)) Error in dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", : argument "x" is missing with no default ``` ```r > read.df(c(1, 2)) Error in invokeJava(isStatic = TRUE, className, methodName, ...) : java.lang.ClassCastException: java.lang.Double cannot be cast to java.lang.String at org.apache.spark.sql.execution.datasources.DataSource.hasMetadata(DataSource.scala:300) at ... In if (is.na(object)) { : ... 
``` - `write.df` ```r > write.df(df, source = "json") Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"function", "missing"’ ``` ```r > write.df(df, source = c(1, 2)) Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"SparkDataFrame", "missing"’ ``` ```r > write.df(df, mode = TRUE) Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"SparkDataFrame", "missing"’ ``` **After** - `read.df` ```r > read.df(source = "json") Error in loadDF : analysis error - Unable to infer schema for JSON at . It must be specified manually; ``` ```r > read.df(path = c(1, 2)) Error in f(x, ...) : path should be charactor, null or omitted. ``` ```r > read.df(c(1, 2)) Error in f(x, ...) : path should be charactor, null or omitted. ``` - `write.df` ```r > write.df(df, source = "json") Error in save : illegal argument - 'path' is not specified ``` ```r > write.df(df, source = c(1, 2)) Error in .local(df, path, ...) : source should be charactor, null or omitted. It is 'parquet' by default. ``` ```r > write.df(df, mode = TRUE) Error in .local(df, path, ...) : mode should be charactor or omitted. It is 'error' by default. ``` ## How was this patch tested? Unit tests in `test_sparkSQL.R` Author: hyukjinkwon Closes #15231 from HyukjinKwon/write-default-r. --- R/pkg/R/DataFrame.R | 20 ++++++--- R/pkg/R/SQLContext.R | 19 ++++++--- R/pkg/R/generics.R | 4 +- R/pkg/R/utils.R | 52 +++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 35 +++++++++++++++ R/pkg/inst/tests/testthat/test_utils.R | 10 +++++ 6 files changed, 127 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 40f1f0f4429e0..75861d5de7092 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2608,7 +2608,7 @@ setMethod("except", #' @param ... 
additional argument(s) passed to the method. #' #' @family SparkDataFrame functions -#' @aliases write.df,SparkDataFrame,character-method +#' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df #' @export @@ -2622,21 +2622,31 @@ setMethod("except", #' } #' @note write.df since 1.4.0 setMethod("write.df", - signature(df = "SparkDataFrame", path = "character"), - function(df, path, source = NULL, mode = "error", ...) { + signature(df = "SparkDataFrame"), + function(df, path = NULL, source = NULL, mode = "error", ...) { + if (!is.null(path) && !is.character(path)) { + stop("path should be charactor, NULL or omitted.") + } + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the datasource specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (!is.character(mode)) { + stop("mode should be charactor or omitted. It is 'error' by default.") + } if (is.null(source)) { source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[["path"]] <- path + options[["path"]] <- path } write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) write <- callJMethod(write, "options", options) - write <- callJMethod(write, "save", path) + write <- handledCallJMethod(write, "save") }) #' @rdname write.df diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ce531c3f88863..baa87824beb91 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -771,6 +771,13 @@ dropTempView <- function(viewName) { #' @method read.df default #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) 
{ + if (!is.null(path) && !is.character(path)) { + stop("path should be charactor, NULL or omitted.") + } + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the datasource specified ", + "in 'spark.sql.sources.default' configuration by default.") + } sparkSession <- getSparkSession() options <- varargsToEnv(...) if (!is.null(path)) { @@ -784,16 +791,16 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string } if (!is.null(schema)) { stopifnot(class(schema) == "structType") - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, - schema$jobj, options) + sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, + source, schema$jobj, options) } else { - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "loadDF", sparkSession, source, options) + sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, + source, options) } dataFrame(sdf) } -read.df <- function(x, ...) { +read.df <- function(x = NULL, ...) { dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } @@ -805,7 +812,7 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { read.df(path, source, schema, ...) } -loadDF <- function(x, ...) { +loadDF <- function(x = NULL, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 67a999da9bc26..90a02e2778310 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -633,7 +633,7 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) { +setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) 
{ standardGeneric("write.df") }) @@ -732,7 +732,7 @@ setGeneric("withColumnRenamed", #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) #' @rdname randomSplit #' @export diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 248c57532b6cf..e69666453480c 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -698,6 +698,58 @@ isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } +# Works identically with `callJStatic(...)` but throws a pretty formatted exception. +handledCallJStatic <- function(cls, method, ...) { + result <- tryCatch(callJStatic(cls, method, ...), + error = function(e) { + captureJVMException(e, method) + }) + result +} + +# Works identically with `callJMethod(...)` but throws a pretty formatted exception. +handledCallJMethod <- function(obj, method, ...) { + result <- tryCatch(callJMethod(obj, method, ...), + error = function(e) { + captureJVMException(e, method) + }) + result +} + +captureJVMException <- function(e, method) { + rawmsg <- as.character(e) + if (any(grep("^Error in .*?: ", rawmsg))) { + # If the exception message starts with "Error in ...", this is possibly + # "Error in invokeJava(...)". Here, it replaces the characters to + # `paste("Error in", method, ":")` in order to identify which function + # was called in JVM side. + stacktrace <- strsplit(rawmsg, "Error in .*?: ")[[1]] + rmsg <- paste("Error in", method, ":") + stacktrace <- paste(rmsg[1], stacktrace[2]) + } else { + # Otherwise, do not convert the error message just in case. + stacktrace <- rawmsg + } + + if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { + msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. 
+ first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "illegal argument - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.AnalysisException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.AnalysisException: ", fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "analysis error - ", first), call. = FALSE) + } else { + stop(stacktrace, call. = FALSE) + } +} + # rbind a list of rows with raw (binary) columns # # @param inputData a list of rows, with each row a list diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9d874a0988716..f5ab601f274fe 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2544,6 +2544,41 @@ test_that("Spark version from SparkSession", { expect_equal(ver, version) }) +test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { + df <- read.df(jsonPath, "json") + # This tests if the exception is thrown from JVM not from SparkR side. + # It makes sure that we can omit path argument in write.df API and then it calls + # DataFrameWriter.save() without path. + expect_error(write.df(df, source = "csv"), + "Error in save : illegal argument - 'path' is not specified") + + # Arguments checking in R side. + expect_error(write.df(df, "data.tmp", source = c(1, 2)), + paste("source should be character, NULL or omitted. It is the datasource specified", + "in 'spark.sql.sources.default' configuration by default.")) + expect_error(write.df(df, path = c(3)), + "path should be charactor, NULL or omitted.") + expect_error(write.df(df, mode = TRUE), + "mode should be charactor or omitted. 
It is 'error' by default.") +}) + +test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + # This tests if the exception is thrown from JVM not from SparkR side. + # It makes sure that we can omit path argument in read.df API and then it calls + # DataFrameWriter.load() without path. + expect_error(read.df(source = "json"), + paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + "It must be specified manually")) + expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + + # Arguments checking in R side. + expect_error(read.df(path = c(3)), + "path should be charactor, NULL or omitted.") + expect_error(read.df(jsonPath, source = c(1, 2)), + paste("source should be character, NULL or omitted. It is the datasource specified", + "in 'spark.sql.sources.default' configuration by default.")) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 77f25292f3f29..69ed5549168b1 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -166,6 +166,16 @@ test_that("convertToJSaveMode", { 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint }) +test_that("captureJVMException", { + method <- "getSQLDataType" + expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, + "unknown"), + error = function(e) { + captureJVMException(e, method) + }), + "Error in getSQLDataType : illegal argument - Invalid type unknown") +}) + test_that("hashCode", { expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) From 89516c1c4a167249b0c82f60a62edb45ede3bd2c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 4 Oct 2016 23:48:26 -0700 Subject: [PATCH 84/96] [SPARK-17258][SQL] Parse scientific decimal literals as decimals ## What changes were proposed in this pull request? 
Currently Spark SQL parses regular decimal literals (e.g. `10.00`) as decimals and scientific decimal literals (e.g. `10.0e10`) as doubles. The difference between the two confuses most users. This PR unifies the parsing behavior and also parses scientific decimal literals as decimals. This implications in tests are limited to a single Hive compatibility test. ## How was this patch tested? Updated tests in `ExpressionParserSuite` and `SQLQueryTestSuite`. Author: Herman van Hovell Closes #14828 from hvanhovell/SPARK-17258. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 +----- .../sql/catalyst/parser/AstBuilder.scala | 8 ------- .../parser/ExpressionParserSuite.scala | 24 +++++++++---------- .../resources/sql-tests/inputs/literals.sql | 8 ++++--- .../sql-tests/results/arithmetic.sql.out | 2 +- .../sql-tests/results/literals.sql.out | 24 ++++++++++++------- .../execution/HiveCompatibilitySuite.scala | 4 +++- 7 files changed, 38 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index c336a0c8eab7a..87719d9ee2bc4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -653,7 +653,6 @@ quotedIdentifier number : MINUS? DECIMAL_VALUE #decimalLiteral - | MINUS? SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral | MINUS? INTEGER_VALUE #integerLiteral | MINUS? BIGINT_LITERAL #bigIntLiteral | MINUS? SMALLINT_LITERAL #smallIntLiteral @@ -944,12 +943,8 @@ INTEGER_VALUE ; DECIMAL_VALUE - : DECIMAL_DIGITS {isValidDecimal()}? - ; - -SCIENTIFIC_DECIMAL_VALUE : DIGIT+ EXPONENT - | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + | DECIMAL_DIGITS EXPONENT? {isValidDecimal()}? 
; DOUBLE_LITERAL diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd0c70a49150d..bf3f30279a6fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1282,14 +1282,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - /** - * Create a double literal for a number denoted in scientific notation. - */ - override def visitScientificDecimalLiteral( - ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { - Literal(ctx.getText.toDouble) - } - /** * Create a decimal literal for a regular decimal number. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 3718ac5f1e77b..0fb1138478a9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -352,6 +352,10 @@ class ExpressionParserSuite extends PlanTest { } test("literals") { + def testDecimal(value: String): Unit = { + assertEqual(value, Literal(BigDecimal(value).underlying)) + } + // NULL assertEqual("null", Literal(null)) @@ -362,20 +366,18 @@ class ExpressionParserSuite extends PlanTest { // Integral should have the narrowest possible type assertEqual("787324", Literal(787324)) assertEqual("7873247234798249234", Literal(7873247234798249234L)) - assertEqual("78732472347982492793712334", - Literal(BigDecimal("78732472347982492793712334").underlying())) + testDecimal("78732472347982492793712334") // Decimal - assertEqual("7873247234798249279371.2334", - Literal(BigDecimal("7873247234798249279371.2334").underlying())) + 
testDecimal("7873247234798249279371.2334") // Scientific Decimal - assertEqual("9.0e1", 90d) - assertEqual(".9e+2", 90d) - assertEqual("0.9e+2", 90d) - assertEqual("900e-1", 90d) - assertEqual("900.0E-1", 90d) - assertEqual("9.e+1", 90d) + testDecimal("9.0e1") + testDecimal(".9e+2") + testDecimal("0.9e+2") + testDecimal("900e-1") + testDecimal("900.0E-1") + testDecimal("9.e+1") intercept(".e3") // Tiny Int Literal @@ -395,8 +397,6 @@ class ExpressionParserSuite extends PlanTest { assertEqual("10.0D", Literal(10.0D)) intercept("-1.8E308D", s"does not fit in range") intercept("1.8E308D", s"does not fit in range") - // TODO we need to figure out if we should throw an exception here! - assertEqual("1E309", Literal(Double.PositiveInfinity)) // BigDecimal Literal assertEqual("90912830918230182310293801923652346786BD", diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 40dceb19cfc5b..37b4b7606d12b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -50,14 +50,14 @@ select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1; select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5; -- negative double select .e3; --- inf and -inf +-- very large decimals (overflowing double). select 1E309, -1E309; -- decimal parsing select 0.3, -0.8, .5, -.18, 0.1111, .1111; --- super large scientific notation numbers should still be valid doubles -select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10; +-- super large scientific notation double literals should still be valid doubles +select 123456789012345678901234567890123456789e10d, 123456789012345678901234567890123456789.1e10d; -- string select "Hello Peter!", 'hello lee!'; @@ -103,3 +103,5 @@ select x'2379ACFe'; -- invalid hexadecimal binary literal select X'XuZ'; +-- Hive literal_double test. 
+SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out index 6abe048af477d..ce42c016a7100 100644 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -29,7 +29,7 @@ struct<-5.2:decimal(2,1)> -- !query 3 select +6.8e0 -- !query 3 schema -struct<6.8:double> +struct<6.8:decimal(2,1)> -- !query 3 output 6.8 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index e2d8daef9868f..95d4413148f64 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 42 +-- Number of queries: 43 -- !query 0 @@ -167,17 +167,17 @@ select 1234567890123456789012345678901234567890.0 -- !query 17 select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1 -- !query 17 schema -struct<1.0:double,1.2:double,1.0E10:double,150000.0:double,0.1:double,0.1:double,10000.0:double,90.0:double,90.0:double,90.0:double,90.0:double> +struct<1.0:double,1.2:double,1E+10:decimal(1,-10),1.5E+5:decimal(2,-4),0.1:double,0.1:double,1E+4:decimal(1,-4),9E+1:decimal(1,-1),9E+1:decimal(1,-1),90.0:decimal(3,1),9E+1:decimal(1,-1)> -- !query 17 output -1.0 1.2 1.0E10 150000.0 0.1 0.1 10000.0 90.0 90.0 90.0 90.0 +1.0 1.2 10000000000 150000 0.1 0.1 10000 90 90 90 90 -- !query 18 select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5 -- !query 18 schema -struct<-1.0:double,-1.2:double,-1.0E10:double,-150000.0:double,-0.1:double,-0.1:double,-10000.0:double> +struct<-1.0:double,-1.2:double,-1E+10:decimal(1,-10),-1.5E+5:decimal(2,-4),-0.1:double,-0.1:double,-1E+4:decimal(1,-4)> -- !query 18 
output --1.0 -1.2 -1.0E10 -150000.0 -0.1 -0.1 -10000.0 +-1.0 -1.2 -10000000000 -150000 -0.1 -0.1 -10000 -- !query 19 @@ -197,9 +197,9 @@ select .e3 -- !query 20 select 1E309, -1E309 -- !query 20 schema -struct +struct<1E+309:decimal(1,-309),-1E+309:decimal(1,-309)> -- !query 20 output -Infinity -Infinity +1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 -1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 -- !query 21 @@ -211,7 +211,7 @@ struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0. 
-- !query 22 -select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10 +select 123456789012345678901234567890123456789e10d, 123456789012345678901234567890123456789.1e10d -- !query 22 schema struct<1.2345678901234568E48:double,1.2345678901234568E48:double> -- !query 22 output @@ -408,3 +408,11 @@ contains illegal character for hexBinary: 0XuZ(line 1, pos 7) == SQL == select X'XuZ' -------^^^ + + +-- !query 42 +SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 +-- !query 42 schema +struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> +-- !query 42 output +3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bebcb8f8016b1..f5d10de8cd2bf 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -555,6 +555,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "varchar_2", "varchar_join1", + // This test assumes we parse scientific decimals as doubles (we parse them as decimals) + "literal_double", + // These tests are duplicates of joinXYZ "auto_join0", "auto_join1", @@ -832,7 +835,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "leftsemijoin_mr", "limit_pushdown_negative", "lineage1", - "literal_double", "literal_ints", "literal_string", "load_dyn_part1", From 6a05eb24d043aa93390f353850d56efa6124e063 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Oct 
2016 10:52:43 -0700 Subject: [PATCH 85/96] [SPARK-17328][SQL] Fix NPE with EXPLAIN DESCRIBE TABLE ## What changes were proposed in this pull request? This PR fixes the following NPE scenario in two ways. **Reported Error Scenario** ```scala scala> sql("EXPLAIN DESCRIBE TABLE x").show(truncate = false) INFO SparkSqlParser: Parsing command: EXPLAIN DESCRIBE TABLE x java.lang.NullPointerException ``` - **DESCRIBE**: Extend `DESCRIBE` syntax to accept `TABLE`. - **EXPLAIN**: Prevent NPE in case of the parsing failure of target statement, e.g., `EXPLAIN DESCRIBE TABLES x`. ## How was this patch tested? Pass the Jenkins test with a new test case. Author: Dongjoon Hyun Closes #15357 from dongjoon-hyun/SPARK-17328. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../spark/sql/execution/SparkSqlParser.scala | 4 +- .../resources/sql-tests/inputs/describe.sql | 4 ++ .../sql-tests/results/describe.sql.out | 58 ++++++++++++++----- .../sql/execution/SparkSqlParserSuite.scala | 18 +++++- 5 files changed, 68 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 87719d9ee2bc4..6a94def65f360 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -136,7 +136,7 @@ statement | SHOW CREATE TABLE tableIdentifier #showCreateTable | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase - | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable | REFRESH TABLE tableIdentifier #refreshTable | REFRESH .*? 
#refreshResource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 7f1e23e665eb1..085bb9fc3c6cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -265,7 +265,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } val statement = plan(ctx.statement) - if (isExplainableStatement(statement)) { + if (statement == null) { + null // This is enough since ParseException will raise later. + } else if (isExplainableStatement(statement)) { ExplainCommand(statement, extended = ctx.EXTENDED != null, codegen = ctx.CODEGEN != null) } else { ExplainCommand(OneRowRelation) diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 3f0ae902e0529..84503d0b12a8e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -2,8 +2,12 @@ CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING); ALTER TABLE t ADD PARTITION (c='Us', d=1); +DESCRIBE t; + DESC t; +DESC TABLE t; + -- Ignore these because there exist timestamp results, e.g., `Create Table`. 
-- DESC EXTENDED t; -- DESC FORMATTED t; diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 37bf303f1bfe4..b448d60c7685d 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 10 -- !query 0 @@ -19,7 +19,7 @@ struct<> -- !query 2 -DESC t +DESCRIBE t -- !query 2 schema struct -- !query 2 output @@ -34,7 +34,7 @@ d string -- !query 3 -DESC t PARTITION (c='Us', d=1) +DESC t -- !query 3 schema struct -- !query 3 output @@ -49,30 +49,60 @@ d string -- !query 4 -DESC t PARTITION (c='Us', d=2) +DESC TABLE t -- !query 4 schema -struct<> +struct -- !query 4 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 5 +DESC t PARTITION (c='Us', d=1) +-- !query 5 schema +struct +-- !query 5 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 6 +DESC t PARTITION (c='Us', d=2) +-- !query 6 schema +struct<> +-- !query 6 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 5 +-- !query 7 DESC t PARTITION (c='Us') --- !query 5 schema +-- !query 7 schema struct<> --- !query 5 output +-- !query 7 output org.apache.spark.sql.AnalysisException Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 6 +-- !query 8 DESC t PARTITION (c='Us', d) --- !query 6 schema +-- !query 8 schema struct<> --- !query 6 output +-- !query 8 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -82,9 +112,9 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 7 +-- !query 9 DROP TABLE t --- !query 7 schema +-- !query 9 schema struct<> --- !query 7 output +-- !query 9 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 8161c08b2cb48..6712d32924890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, ShowFunctionsCommand} +import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand, + ShowFunctionsCommand} import org.apache.spark.sql.internal.SQLConf /** @@ -72,4 +73,17 @@ class SparkSqlParserSuite extends PlanTest { DescribeFunctionCommand(FunctionIdentifier("bar", database = Option("f")), isExtended = true)) } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { + assertEqual("describe table t", + DescribeTableCommand( + TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = false)) + assertEqual("describe table extended t", + DescribeTableCommand( + TableIdentifier("t"), Map.empty, 
isExtended = true, isFormatted = false)) + assertEqual("describe table formatted t", + DescribeTableCommand( + TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = true)) + + intercept("explain describe tables x", "Unsupported SQL statement") + } } From 9df54f5325c2942bb77008ff1810e2fb5f6d848b Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 5 Oct 2016 18:28:21 +0000 Subject: [PATCH 86/96] [SPARK-17239][ML][DOC] Update user guide for multiclass logistic regression ## What changes were proposed in this pull request? Updates user guide to reflect that LogisticRegression now supports multiclass. Also adds new examples to show multiclass training. ## How was this patch tested? Ran locally using spark-submit, run-example, and copy/paste from user guide into shells. Generated docs and verified correct output. Author: sethah Closes #15349 from sethah/SPARK-17239. --- docs/ml-classification-regression.md | 65 +++++++++++++++++-- ...gisticRegressionWithElasticNetExample.java | 14 ++++ ...gisticRegressionWithElasticNetExample.java | 55 ++++++++++++++++ .../logistic_regression_with_elastic_net.py | 10 +++ ...ss_logistic_regression_with_elastic_net.py | 48 ++++++++++++++ ...isticRegressionWithElasticNetExample.scala | 13 ++++ ...isticRegressionWithElasticNetExample.scala | 57 ++++++++++++++++ 7 files changed, 255 insertions(+), 7 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java create mode 100644 examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 7c2437eacde3f..bb2e404330cc0 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -34,17 +34,22 @@ discussing specific classes of 
algorithms, such as linear methods, trees, and en ## Logistic regression -Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. -For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). +Logistic regression is a popular method to predict a categorical response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcomes. +In `spark.ml` logistic regression can be used to predict a binary outcome by using binomial logistic regression, or it can be used to predict a multiclass outcome by using multinomial logistic regression. Use the `family` +parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant. - > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + > Multinomial logistic regression can be used for binary classification by setting the `family` param to "multinomial". It will produce two sets of coefficients and two intercepts. > When fitting LogisticRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM. +### Binomial logistic regression + +For more background and more details about the implementation of binomial logistic regression, refer to the documentation of [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + **Example** -The following example shows how to train a logistic regression model -with elastic net regularization. 
`elasticNetParam` corresponds to +The following example shows how to train binomial and multinomial logistic regression +models for binary classification with elastic net regularization. `elasticNetParam` corresponds to $\alpha$ and `regParam` corresponds to $\lambda$.
@@ -92,8 +97,8 @@ provides a summary for a [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). Currently, only binary classification is supported and the summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -This will likely change when multiclass classification is supported. +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +Support for multiclass model summaries will be added in the future. Continuing the earlier example: @@ -107,6 +112,52 @@ Logistic regression model summary is not yet supported in Python.
+### Multinomial logistic regression + +Multiclass classification is supported via multinomial logistic (softmax) regression. In multinomial logistic regression, +the algorithm produces $K$ sets of coefficients, or a matrix of dimension $K \times J$ where $K$ is the number of outcome +classes and $J$ is the number of features. If the algorithm is fit with an intercept term then a length $K$ vector of +intercepts is available. + + > Multinomial coefficients are available as `coefficientMatrix` and intercepts are available as `interceptVector`. + + > `coefficients` and `intercept` methods on a logistic regression model trained with multinomial family are not supported. Use `coefficientMatrix` and `interceptVector` instead. + +The conditional probabilities of the outcome classes $k \in \{0, 1, ..., K-1\}$ are modeled using the softmax function. + +`\[ + P(Y=k|\mathbf{X}, \boldsymbol{\beta}_k, \beta_{0k}) = \frac{e^{\boldsymbol{\beta}_k \cdot \mathbf{X} + \beta_{0k}}}{\sum_{k'=0}^{K-1} e^{\boldsymbol{\beta}_{k'} \cdot \mathbf{X} + \beta_{0k'}}} +\]` + +We minimize the weighted negative log-likelihood, using a multinomial response model, with elastic-net penalty to control for overfitting. + +`\[ +\min_{\beta, \beta_0} -\left[\sum_{i=1}^L w_i \cdot \log P(Y = y_i|\mathbf{x}_i)\right] + \lambda \left[\frac{1}{2}\left(1 - \alpha\right)||\boldsymbol{\beta}||_2^2 + \alpha ||\boldsymbol{\beta}||_1\right] +\]` + +For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multinomial_logistic_regression#As_a_log-linear_model). + +**Example** + +The following example shows how to train a multiclass logistic regression +model with elastic net regularization. + +
+ +
+{% include_example scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala %} +
+ +
+{% include_example java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java %} +
+ +
+{% include_example python/ml/multiclass_logistic_regression_with_elastic_net.py %} +
+ +
+ ## Decision tree classifier diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java index 6101c79fb0c98..b8fb5972ea418 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -48,6 +48,20 @@ public static void main(String[] args) { // Print the coefficients and intercept for logistic regression System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + + // We can also use the multinomial family for binary classification + LogisticRegression mlr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + .setFamily("multinomial"); + + // Fit the model + LogisticRegressionModel mlrModel = mlr.fit(training); + + // Print the coefficients and intercepts for logistic regression with multinomial family + System.out.println("Multinomial coefficients: " + + mlrModel.coefficientMatrix() + "\nMultinomial intercepts: " + mlrModel.interceptVector()); // $example off$ spark.stop(); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java new file mode 100644 index 0000000000000..da410cba2b3f1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +// $example off$ + +public class JavaMulticlassLogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaMulticlassLogisticRegressionWithElasticNetExample") + .getOrCreate(); + + // $example on$ + // Load training data + Dataset training = spark.read().format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for multinomial logistic regression + System.out.println("Coefficients: \n" + + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector()); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py index 33d0689f75cd5..d095fbd373408 100644 --- 
a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -40,6 +40,16 @@ # Print the coefficients and intercept for logistic regression print("Coefficients: " + str(lrModel.coefficients)) print("Intercept: " + str(lrModel.intercept)) + + # We can also use the multinomial family for binary classification + mlr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8, family="multinomial") + + # Fit the model + mlrModel = mlr.fit(training) + + # Print the coefficients and intercepts for logistic regression with multinomial family + print("Multinomial coefficients: " + str(mlrModel.coefficientMatrix)) + print("Multinomial intercepts: " + str(mlrModel.interceptVector)) # $example off$ spark.stop() diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py new file mode 100644 index 0000000000000..bb9cd82d6ba27 --- /dev/null +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("MulticlassLogisticRegressionWithElasticNet") \ + .getOrCreate() + + # $example on$ + # Load training data + training = spark \ + .read \ + .format("libsvm") \ + .load("data/mllib/sample_multiclass_classification_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for multinomial logistic regression + print("Coefficients: \n" + str(lrModel.coefficientMatrix)) + print("Intercept: " + str(lrModel.interceptVector)) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala index 616263b8e9f48..18471049087d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -45,6 +45,19 @@ object LogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for logistic regression println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + + // We can also use the multinomial family for binary classification + val mlr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + .setFamily("multinomial") + + val mlrModel = mlr.fit(training) + + // Print the coefficients and intercepts for logistic regression with multinomial family + println(s"Multinomial coefficients: ${mlrModel.coefficientMatrix}") + println(s"Multinomial intercepts: ${mlrModel.interceptVector}") // $example off$ 
spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala new file mode 100644 index 0000000000000..42f0ace7a353d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +// $example off$ +import org.apache.spark.sql.SparkSession + +object MulticlassLogisticRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("MulticlassLogisticRegressionWithElasticNetExample") + .getOrCreate() + + // $example on$ + // Load training data + val training = spark + .read + .format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for multinomial logistic regression + println(s"Coefficients: \n${lrModel.coefficientMatrix}") + println(s"Intercepts: ${lrModel.interceptVector}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From 221b418b1c9db7b04c600b6300d18b034a4f444e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Oct 2016 14:54:55 -0700 Subject: [PATCH 87/96] [SPARK-17778][TESTS] Mock SparkContext to reduce memory usage of BlockManagerSuite ## What changes were proposed in this pull request? Mock SparkContext to reduce memory usage of BlockManagerSuite ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15350 from zsxwing/SPARK-17778. 
--- .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1652fcdb964da..705c355234425 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -107,7 +107,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - sc = new SparkContext("local", "test", conf) + // Mock SparkContext to reduce the memory usage of tests. It's fine since the only reason we + // need to create a SparkContext is to initialize LiveListenerBus. + sc = mock(classOf[SparkContext]) + when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus(sc))), conf, true) From 5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 5 Oct 2016 16:05:30 -0700 Subject: [PATCH 88/96] [SPARK-17758][SQL] Last returns wrong result in case of empty partition ## What changes were proposed in this pull request? The result of the `Last` function can be wrong when the last partition processed is empty. It can return `null` instead of the expected value. For example, this can happen when we process partitions in the following order: ``` - Partition 1 [Row1, Row2] - Partition 2 [Row3] - Partition 3 [] ``` In this case the `Last` function will currently return a null, instead of the value of `Row3`. This PR fixes this by adding a `valueSet` flag to the `Last` function. ## How was this patch tested? We only used end to end tests for `DeclarativeAggregateFunction`s. 
I have added an evaluator for these functions so we can test them in catalyst. I have added a `LastTestSuite` to test the `Last` aggregate function. Author: Herman van Hovell Closes #15348 from hvanhovell/SPARK-17758. --- .../catalyst/expressions/aggregate/Last.scala | 27 ++--- .../DeclarativeAggregateEvaluator.scala | 61 ++++++++++ .../expressions/aggregate/LastTestSuite.scala | 109 ++++++++++++++++++ 3 files changed, 184 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index af8840305805f..8579f7292d3ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -55,34 +55,35 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat private lazy val last = AttributeReference("last", child.dataType)() - override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: valueSet :: Nil override lazy val initialValues: Seq[Literal] = Seq( - /* last = */ Literal.create(null, child.dataType) + /* last = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child) + /* last = */ If(IsNull(child), last, child), + /* valueSet = */ Or(valueSet, 
IsNotNull(child)) ) } else { Seq( - /* last = */ child + /* last = */ child, + /* valueSet = */ Literal.create(true, BooleanType) ) } } override lazy val mergeExpressions: Seq[Expression] = { - if (ignoreNulls) { - Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) - } else { - Seq( - /* last = */ last.right - ) - } + // Prefer the right hand expression if it has been set. + Seq( + /* last = */ If(valueSet.right, last.right, last.left), + /* valueSet = */ Or(valueSet.right, valueSet.left) + ) } override lazy val evaluateExpression: AttributeReference = last diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala new file mode 100644 index 0000000000000..614f24db0aafb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection + +/** + * Evaluator for a [[DeclarativeAggregate]]. + */ +case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { + + lazy val initializer = GenerateSafeProjection.generate(function.initialValues) + + lazy val updater = GenerateSafeProjection.generate( + function.updateExpressions, + function.aggBufferAttributes ++ input) + + lazy val merger = GenerateSafeProjection.generate( + function.mergeExpressions, + function.aggBufferAttributes ++ function.inputAggBufferAttributes) + + lazy val evaluator = GenerateSafeProjection.generate( + function.evaluateExpression :: Nil, + function.aggBufferAttributes) + + def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() + + def update(values: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = values.foldLeft(initialize()) { (buffer, input) => + updater(joiner(buffer, input)) + } + buffer.copy() + } + + def merge(buffers: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = buffers.foldLeft(initialize()) { (left, right) => + merger(joiner(left, right)) + } + buffer.copy() + } + + def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala new file mode 100644 index 0000000000000..ba36bc074e154 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.IntegerType + +class LastTestSuite extends SparkFunSuite { + val input = AttributeReference("input", IntegerType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input)) + val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(null, false)) + } + + test("update") { + val result = evaluator.update( + InternalRow(1), + InternalRow(9), + InternalRow(-1)) + assert(result === InternalRow(-1, true)) + } + + test("update - ignore nulls") { + val result1 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(9), + InternalRow(null)) + assert(result1 === InternalRow(9, true)) + + val result2 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(null)) + assert(result2 === InternalRow(null, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(null, 
false)) + + // Single merge + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.merge(p1) === p1) + + // Multiple merges. + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + assert(evaluator.merge(p1, p2) === p2) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === p2) + assert(evaluator.merge(p2, p1, p0) === p1) + } + + test("merge - ignore nulls") { + // Multi merges + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + assert(evaluatorIgnoreNulls.merge(p1, p2) === p1) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(null)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.eval(p1) === InternalRow(-99)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(10)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(-99)) + } + + test("eval - ignore nulls") { + // Update - Merge - Eval + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + val m1 = evaluatorIgnoreNulls.merge(p1, p2) + assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1)) + } +} From 9293734d35eb3d6e4fd4ebb86f54dd5d3a35e6db Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Oct 2016 16:45:45 -0700 Subject: [PATCH 89/96] [SPARK-17346][SQL] Add Kafka source for Structured Streaming ## What changes were proposed in this pull request? 
This PR adds a new project `external/kafka-0-10-sql` for Structured Streaming Kafka source. It's based on the design doc: https://docs.google.com/document/d/19t2rWe51x7tq2e5AOfrsM9qb8_m7BRuv9fel9i0PqR8/edit?usp=sharing tdas did most of the work, and parts of it were inspired by koeninger's work. ### Introduction The Kafka source is a structured streaming data source to poll data from Kafka. The schema of reading data is as follows: Column | Type ---- | ---- key | binary value | binary topic | string partition | int offset | long timestamp | long timestampType | int The source can deal with deleting topics. However, the user should make sure there is no Spark job processing the data when deleting a topic. ### Configuration The user can use `DataStreamReader.option` to set the following configurations. Kafka Source's options | value | default | meaning ------ | ------- | ------ | ----- startingOffset | ["earliest", "latest"] | "latest" | The start point when a query is started, either "earliest" which is from the earliest offset, or "latest" which is just from the latest offset. Note: This only applies when a new Streaming query is started, and that resuming will always pick up from where the query left off. failOnDataLoss | [true, false] | true | Whether to fail the query when it's possible that data is lost (e.g., topics are deleted, or offsets are out of range). This may be a false alarm. You can disable it when it doesn't work as you expected. subscribe | A comma-separated list of topics | (none) | The topic list to subscribe. Only one of "subscribe" and "subscribePattern" options can be specified for Kafka source. subscribePattern | Java regex string | (none) | The pattern used to subscribe the topic. Only one of "subscribe" and "subscribePattern" options can be specified for Kafka source. 
kafkaConsumer.pollTimeoutMs | long | 512 | The timeout in milliseconds to poll data from Kafka in executors fetchOffset.numRetries | int | 3 | Number of times to retry before giving up fetching the latest Kafka offsets. fetchOffset.retryIntervalMs | long | 10 | Milliseconds to wait before retrying to fetch Kafka offsets Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g., `stream.option("kafka.bootstrap.servers", "host:port")` ### Usage * Subscribe to 1 topic ```Scala spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host:port") .option("subscribe", "topic1") .load() ``` * Subscribe to multiple topics ```Scala spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host:port") .option("subscribe", "topic1,topic2") .load() ``` * Subscribe to a pattern ```Scala spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host:port") .option("subscribePattern", "topic.*") .load() ``` ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Author: Tathagata Das Author: Shixiong Zhu Author: cody koeninger Closes #15102 from zsxwing/kafka-source. 
--- .../spark/util/UninterruptibleThread.scala | 7 - dev/run-tests.py | 2 +- dev/sparktestsupport/modules.py | 12 + .../structured-streaming-kafka-integration.md | 239 ++++++++++ .../structured-streaming-programming-guide.md | 7 +- external/kafka-0-10-sql/pom.xml | 82 ++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../sql/kafka010/CachedKafkaConsumer.scala | 152 +++++++ .../spark/sql/kafka010/KafkaSource.scala | 399 ++++++++++++++++ .../sql/kafka010/KafkaSourceOffset.scala | 54 +++ .../sql/kafka010/KafkaSourceProvider.scala | 282 ++++++++++++ .../spark/sql/kafka010/KafkaSourceRDD.scala | 148 ++++++ .../spark/sql/kafka010/package-info.java | 21 + .../src/test/resources/log4j.properties | 28 ++ .../sql/kafka010/KafkaSourceOffsetSuite.scala | 39 ++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 424 ++++++++++++++++++ .../spark/sql/kafka010/KafkaTestUtils.scala | 339 ++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 6 +- .../execution/streaming/StreamExecution.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 40 +- 21 files changed, 2268 insertions(+), 23 deletions(-) create mode 100644 docs/structured-streaming-kafka-integration.md create mode 100644 external/kafka-0-10-sql/pom.xml create mode 100644 external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala create mode 100644 
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java create mode 100644 external/kafka-0-10-sql/src/test/resources/log4j.properties create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 4dcf95177aa78..f0b68f0cb7e29 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -89,13 +89,6 @@ private[spark] class UninterruptibleThread(name: String) extends Thread(name) { } } - /** - * Tests whether `interrupt()` has been called. - */ - override def isInterrupted: Boolean = { - super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread } - } - /** * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be * interrupted until it enters into the interruptible status. 
diff --git a/dev/run-tests.py b/dev/run-tests.py index ae4b5306fc5cf..5d661f5f1a1c5 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', + ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 050cdf043757f..5f14683d9a52f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -158,6 +158,18 @@ def __hash__(self): ) +sql_kafka = Module( + name="sql-kafka-0-10", + dependencies=[sql], + source_file_regexes=[ + "external/kafka-0-10-sql", + ], + sbt_test_goals=[ + "sql-kafka-0-10/test", + ] +) + + sketch = Module( name="sketch", dependencies=[tags], diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md new file mode 100644 index 0000000000000..668489addf82c --- /dev/null +++ b/docs/structured-streaming-kafka-integration.md @@ -0,0 +1,239 @@ +--- +layout: global +title: Structured Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) +--- + +Structured Streaming integration for Kafka 0.10 to poll data from Kafka. + +### Linking +For Scala/Java applications using SBT/Maven project definitions, link your application with the following artifact: + + groupId = org.apache.spark + artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +For Python applications, you need to add this above library and its dependencies when deploying your +application. See the [Deploying](#deploying) subsection below. + +### Creating a Kafka Source Stream + +
+
+ + // Subscribe to 1 topic + val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + // Subscribe to multiple topics + val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() + ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + // Subscribe to a pattern + val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() + ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + +
+
+ + // Subscribe to 1 topic + Dataset ds1 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + + // Subscribe to multiple topics + Dataset ds2 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() + ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + + // Subscribe to a pattern + Dataset ds3 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() + ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + +
+
+ + # Subscribe to 1 topic + ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + + # Subscribe to multiple topics + ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() + ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + + # Subscribe to a pattern + ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() + ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + +
+
+ +Each row in the source has the following schema: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ColumnType
keybinary
valuebinary
topicstring
partitionint
offsetlong
timestamplong
timestampTypeint
+ +The following options must be set for the Kafka source. + + + + + + + + + + + + + + + + + + +
Optionvaluemeaning
subscribeA comma-separated list of topicsThe topic list to subscribe. Only one of "subscribe" and "subscribePattern" options can be + specified for Kafka source.
subscribePatternJava regex stringThe pattern used to subscribe the topic. Only one of "subscribe" and "subscribePattern" + options can be specified for Kafka source.
kafka.bootstrap.serversA comma-separated list of host:portThe Kafka "bootstrap.servers" configuration.
+ +The following configurations are optional: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Optionvaluedefaultmeaning
startingOffset["earliest", "latest"]"latest"The start point when a query is started, either "earliest" which is from the earliest offset, + or "latest" which is just from the latest offset. Note: This only applies when a new Streaming + query is started, and that resuming will always pick up from where the query left off.
failOnDataLoss[true, false]trueWhether to fail the query when it's possible that data is lost (e.g., topics are deleted, or + offsets are out of range). This may be a false alarm. You can disable it when it doesn't work + as you expected.
kafkaConsumer.pollTimeoutMslong512The timeout in milliseconds to poll data from Kafka in executors.
fetchOffset.numRetriesint3Number of times to retry before giving up fetching the latest Kafka offsets.
fetchOffset.retryIntervalMslong10Milliseconds to wait before retrying to fetch Kafka offsets.
+ +Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g., +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafkaParams, see +[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). + +Note that the following Kafka params cannot be set and the Kafka source will throw an exception: +- **group.id**: Kafka source will create a unique group id for each query automatically. +- **auto.offset.reset**: Set the source option `startingOffset` to `earliest` or `latest` to specify + where to start instead. Structured Streaming manages which offsets are consumed internally, rather + than relying on the Kafka consumer to do it. This will ensure that no data is missed when new + topics/partitions are dynamically subscribed. Note that `startingOffset` only applies when a new + Streaming query is started, and that resuming will always pick up from where the query left off. +- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use + DataFrame operations to explicitly deserialize the keys. +- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. + Use DataFrame operations to explicitly deserialize the values. +- **enable.auto.commit**: Kafka source doesn't commit any offset. +- **interceptor.classes**: Kafka source always reads keys and values as byte arrays. It's not safe to + use ConsumerInterceptor as it may break the query. + +### Deploying + +As with any Spark application, `spark-submit` is used to launch your application. `spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` +and its dependencies can be directly added to `spark-submit` using `--packages`, such as, + + ./bin/spark-submit --packages org.apache.spark:spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... 
+ +See [Application Submission Guide](submitting-applications.html) for more details about submitting +applications with external dependencies. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2e6df94823d38..173fd6e8c73b9 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -418,10 +418,15 @@ Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/ [Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. In Spark 2.0, there are a few built-in sources. +[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. + +#### Data Sources +In Spark 2.0, there are a few built-in sources. - **File source** - Reads files written in a directory as a stream of data. Supported file formats are text, csv, json, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which in most file systems, can be achieved by file move operations. + - **Kafka source** - Poll data from Kafka. It's compatible with Kafka broker versions 0.10.0 or higher. See the [Kafka Integration Guide](structured-streaming-kafka-integration.html) for more details. 
+ - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees. Here are some examples. diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml new file mode 100644 index 0000000000000..b96445a11f858 --- /dev/null +++ b/external/kafka-0-10-sql/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sql-kafka-0-10_2.11 + + sql-kafka-0-10 + + jar + Kafka 0.10 Source for Structured Streaming + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.kafka + kafka-clients + 0.10.0.1 + + + org.apache.kafka + kafka_${scala.binary.version} + 0.10.0.1 + test + + + net.sf.jopt-simple + jopt-simple + 3.2 + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..2f9e9fc0396d5 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.kafka010.KafkaSourceProvider diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala 
new file mode 100644 index 0000000000000..3b5a96534f9b6 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging + + +/** + * Consumer of single topicpartition, intended for cached reuse. + * Underlying consumer is not threadsafe, so neither is this, + * but processing the same topicpartition and group id in multiple threads is usually bad anyway. 
+ */ +private[kafka010] case class CachedKafkaConsumer private( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object]) extends Logging { + + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + private val consumer = { + val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + val tps = new ju.ArrayList[TopicPartition]() + tps.add(topicPartition) + c.assign(tps) + c + } + + /** Iterator to the already fetch data */ + private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + private var nextOffsetInFetchedData = -2L + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + if (offset != nextOffsetInFetchedData) { + logInfo(s"Initial fetch for $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + + if (!fetchedData.hasNext()) { poll(pollTimeoutMs) } + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + var record = fetchedData.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + record = fetchedData.next() + assert(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset") + } + + nextOffsetInFetchedData = offset + 1 + record + } + + private def close(): Unit = consumer.close() + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $groupId 
$topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(pollTimeoutMs: Long): Unit = { + val p = consumer.poll(pollTimeoutMs) + val r = p.records(topicPartition) + logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") + fetchedData = r.iterator + } +} + +private[kafka010] object CachedKafkaConsumer extends Logging { + + private case class CacheKey(groupId: String, topicPartition: TopicPartition) + + private lazy val cache = { + val conf = SparkEnv.get.conf + val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) + new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { + if (this.size > capacity) { + logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case e: SparkException => + logError(s"Error closing earliest Kafka consumer for ${entry.getKey}", e) + } + true + } else { + false + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. 
+ */ + def getOrCreate( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val topicPartition = new TopicPartition(topic, partition) + val key = CacheKey(groupId, topicPartition) + + // TaskContext.attemptNumber is 0 for the first attempt, so any value >= 1 is a reattempt: + // close and drop the cached consumer, then start with a new (uncached) one + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { + Option(cache.remove(key)).foreach(_.close()) + new CachedKafkaConsumer(topicPartition, kafkaParams) + } else { + if (!cache.containsKey(key)) { + cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + } + cache.get(key) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala new file mode 100644 index 0000000000000..1be70db87497e --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer} +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[Source]] that uses Kafka's own [[KafkaConsumer]] API to reads data from Kafka. The design + * for this source is as follows. + * + * - The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * - The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read + * by this source. These strategies directly correspond to the different consumption options + * in . This class is designed to return a configured [[KafkaConsumer]] that is used by the + * [[KafkaSource]] to query for the offsets. See the docs on + * [[org.apache.spark.sql.kafka010.KafkaSource.ConsumerStrategy]] for more details. + * + * - The [[KafkaSource]] written to do the following. + * + * - As soon as the source is created, the pre-configured KafkaConsumer returned by the + * [[ConsumerStrategy]] is used to query the initial offsets that this source should + * start reading from. 
This used to create the first batch. + * + * - `getOffset()` uses the KafkaConsumer to query the latest available offsets, which are + * returned as a [[KafkaSourceOffset]]. + * + * - `getBatch()` returns a DF that reads from the 'start offset' until the 'end offset' in + * for each partition. The end offset is excluded to be consistent with the semantics of + * [[KafkaSourceOffset]] and `KafkaConsumer.position()`. + * + * - The DF returned is based on [[KafkaSourceRDD]] which is constructed such that the + * data from Kafka topic + partition is consistently read by the same executors across + * batches, and cached KafkaConsumers in the executors can be reused efficiently. See the + * docs on [[KafkaSourceRDD]] for more details. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using KafkaSource maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. + */ +private[kafka010] case class KafkaSource( + sqlContext: SQLContext, + consumerStrategy: ConsumerStrategy, + executorKafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + failOnDataLoss: Boolean) + extends Source with Logging { + + private val sc = sqlContext.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + private val maxOffsetFetchAttempts = + sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt + + private val offsetFetchAttemptIntervalMs = + sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "10").toLong + + /** + * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the + * offsets and never commits them. 
+ */ + private val consumer = consumerStrategy.createConsumer() + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = { + val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = false)) + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + override def schema: StructType = KafkaSource.kafkaSchema + + /** Returns the maximum available offset for this source. */ + override def getOffset: Option[Offset] = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + val offset = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = true)) + logDebug(s"GetOffset: ${offset.partitionToOffsets.toSeq.map(_.toString).sorted}") + Some(offset) + } + + /** + * Returns the data that is between the offsets + * [`start.get.partitionToOffsets`, `end.partitionToOffsets`), i.e. end.partitionToOffsets is + * exclusive. 
+ */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + logInfo(s"GetBatch called with start = $start, end = $end") + val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + val fromPartitionOffsets = start match { + case Some(prevBatchEndOffset) => + KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) + case None => + initialPartitionOffsets + } + + // Find the new partitions, and get their earliest offsets + val newPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet) + val newPartitionOffsets = if (newPartitions.nonEmpty) { + fetchNewPartitionEarliestOffsets(newPartitions.toSeq) + } else { + Map.empty[TopicPartition, Long] + } + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = untilPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. 
+ newPartitionOffsets.contains(tp) || fromPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList(sc) + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val offsetRanges = topicPartitions.map { tp => + val fromOffset = fromPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse(tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = untilPartitionOffsets(tp) + val preferredLoc = if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + Some(sortedExecutors(floorMod(tp.hashCode, numExecutors))) + } else None + KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, preferredLoc) + }.filter { range => + if (range.untilOffset < range.fromOffset) { + reportDataLoss(s"Partition ${range.topicPartition}'s offset was changed from " + + s"${range.fromOffset} to ${range.untilOffset}, some data may have been missed") + false + } else { + true + } + }.toArray + + // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + val rdd = new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => + Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) + } + + logInfo("GetBatch generating RDD of offset range: " + + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) + sqlContext.createDataFrame(rdd, schema) + } + + /** Stop this source and free any resources it has allocated. 
*/ + override def stop(): Unit = synchronized { + consumer.close() + } + + override def toString(): String = s"KafkaSource[$consumerStrategy]" + + /** + * Fetch the offset of a partition, either seek to the latest offsets or use the current offsets + * in the consumer. + */ + private def fetchPartitionOffsets( + seekToEnd: Boolean): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"Partitioned assigned to consumer: $partitions") + + // Get the current or latest offset of each partition + if (seekToEnd) { + consumer.seekToEnd(partitions) + logDebug("Seeked to the end") + } + val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got offsets for partition : $partitionOffsets") + partitionOffsets + } + + /** + * Fetch the earliest offsets for newly discovered partitions. The return result may not contain + * some partitions if they are deleted. + */ + private def fetchNewPartitionEarliestOffsets( + newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + logDebug(s"\tPartitioned assigned to consumer: $partitions") + + // Get the earliest offset of each partition + consumer.seekToBeginning(partitions) + val partitionToOffsets = newPartitions.filter { p => + // When deleting topics happen at the same time, some partitions may not be in `partitions`. 
+ // So we need to ignore them + partitions.contains(p) + }.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got offsets for new partitions: $partitionToOffsets") + partitionToOffsets + } + + /** + * Helper function that does multiple retries on a body of code that returns offsets. + * Retries are needed to handle transient failures. E.g., race conditions between getting + * assignment and getting position while topics/partitions are deleted can cause NPEs. + * + * This method also makes sure `body` won't be interrupted to work around a potential issue in + * `KafkaConsumer.poll`. (KAFKA-1894) + */ + private def withRetriesWithoutInterrupt( + body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + synchronized { + var result: Option[Map[TopicPartition, Long]] = None + var attempt = 1 + var lastException: Throwable = null + while (result.isEmpty && attempt <= maxOffsetFetchAttempts + && !Thread.currentThread().isInterrupted) { + Thread.currentThread match { + case ut: UninterruptibleThread => + // "KafkaConsumer.poll" may hang forever if the thread is interrupted (E.g., the query + // is stopped)(KAFKA-1894). Hence, we just make sure we don't interrupt it. + // + // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may + // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the + // issue.
+ ut.runUninterruptibly { + try { + result = Some(body) + } catch { + case NonFatal(e) => + lastException = e + logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e) + attempt += 1 + Thread.sleep(offsetFetchAttemptIntervalMs) + } + } + case _ => + throw new IllegalStateException( + "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread") + } + } + if (Thread.interrupted()) { + throw new InterruptedException() + } + if (result.isEmpty) { + assert(attempt > maxOffsetFetchAttempts) + assert(lastException != null) + throw lastException + } + result.get + } + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + + ". Set the source option 'failOnDataLoss' to 'false' if you want to ignore these checks.") + } else { + logWarning(message) + } + } +} + + +/** Companion object for the [[KafkaSource]]. 
*/ +private[kafka010] object KafkaSource { + + def kafkaSchema: StructType = StructType(Seq( + StructField("key", BinaryType), + StructField("value", BinaryType), + StructField("topic", StringType), + StructField("partition", IntegerType), + StructField("offset", LongType), + StructField("timestamp", LongType), + StructField("timestampType", IntegerType) + )) + + sealed trait ConsumerStrategy { + def createConsumer(): Consumer[Array[Byte], Array[Byte]] + } + + case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.subscribe(topics.asJava) + consumer + } + + override def toString: String = s"Subscribe[${topics.mkString(", ")}]" + } + + case class SubscribePatternStrategy( + topicPattern: String, kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.subscribe( + ju.regex.Pattern.compile(topicPattern), + new NoOpConsumerRebalanceListener()) + consumer + } + + override def toString: String = s"SubscribePattern[$topicPattern]" + } + + private def getSortedExecutorList(sc: SparkContext): Array[String] = { + val bm = sc.env.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + private def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { a.executorId > b.executorId } else { a.host > b.host } + } + + private def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala 
b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala new file mode 100644 index 0000000000000..b5ade982515f0 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.execution.streaming.Offset + +/** + * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and + * their offsets. 
+ */ +private[kafka010] +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { + override def toString(): String = { + partitionToOffsets.toSeq.sortBy(_._1.toString).mkString("[", ", ", "]") + } +} + +/** Companion object of the [[KafkaSourceOffset]] */ +private[kafka010] object KafkaSourceOffset { + + def getPartitionOffsets(offset: Offset): Map[TopicPartition, Long] = { + offset match { + case o: KafkaSourceOffset => o.partitionToOffsets + case _ => + throw new IllegalArgumentException( + s"Invalid conversion from offset of ${offset.getClass} to KafkaSourceOffset") + } + } + + /** + * Returns [[KafkaSourceOffset]] from a variable sequence of (topic, partitionId, offset) + * tuples. + */ + def apply(offsetTuples: (String, Int, Long)*): KafkaSourceOffset = { + KafkaSourceOffset(offsetTuples.map { case(t, p, o) => (new TopicPartition(t, p), o) }.toMap) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala new file mode 100644 index 0000000000000..1b0a2fe955d03 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.UUID + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.ByteArrayDeserializer + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.StructType + +/** + * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * IllegalArgumentException when the Kafka Dataset is created, so that it can catch + * missing options even before the query is started. + */ +private[kafka010] class KafkaSourceProvider extends StreamSourceProvider + with DataSourceRegister with Logging { + + import KafkaSourceProvider._ + + /** + * Returns the name and schema of the source. In addition, it also verifies whether the options + * are correct and sufficient to create the [[KafkaSource]] when the query is started. 
+ */ + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one") + validateOptions(parameters) + ("kafka", KafkaSource.kafkaSchema) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + validateOptions(parameters) + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase.startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val deserClassName = classOf[ByteArrayDeserializer].getName + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. 
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val autoOffsetResetValue = caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { + case Some(value) => value.trim() // same values as those supported by auto.offset.reset + case None => "latest" + } + + val kafkaParamsForStrategy = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // So that consumers in Kafka source do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver") + + // So that consumers can start from earliest or latest + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetResetValue) + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + val kafkaParamsForExecutors = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. 
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("subscribe", value) => + SubscribeStrategy( + value.split(",").map(_.trim()).filter(_.nonEmpty), + kafkaParamsForStrategy) + case ("subscribepattern", value) => + SubscribePatternStrategy( + value.trim(), + kafkaParamsForStrategy) + case _ => + // Should never reach here as we are already matching on + // matched strategy names + throw new IllegalArgumentException("Unknown option") + } + + val failOnDataLoss = + caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean + + new KafkaSource( + sqlContext, + strategy, + kafkaParamsForExecutors, + parameters, + metadataPath, + failOnDataLoss) + } + + private def validateOptions(parameters: Map[String, String]): Unit = { + + // Validate source options + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val specifiedStrategies = + caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq + if (specifiedStrategies.isEmpty) { + throw new IllegalArgumentException( + "One of the following options must be specified for Kafka source: " + + STRATEGY_OPTION_KEYS.mkString(", ") + ". See the docs for more details.") + } else if (specifiedStrategies.size > 1) { + throw new IllegalArgumentException( + "Only one of the following options can be specified for Kafka source: " + + STRATEGY_OPTION_KEYS.mkString(", ") + ". 
See the docs for more details.") + } + + val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("subscribe", value) => + val topics = value.split(",").map(_.trim).filter(_.nonEmpty) + if (topics.isEmpty) { + throw new IllegalArgumentException( + "No topics to subscribe to as specified value for option " + + s"'subscribe' is '$value'") + } + case ("subscribepattern", value) => + val pattern = caseInsensitiveParams("subscribepattern").trim() + if (pattern.isEmpty) { + throw new IllegalArgumentException( + "Pattern to subscribe is empty as specified value for option " + + s"'subscribePattern' is '$value'") + } + case _ => + // Should never reach here as we are already matching on + // matched strategy names + throw new IllegalArgumentException("Unknown option") + } + + caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { + case Some(pos) if !STARTING_OFFSET_OPTION_VALUES.contains(pos.trim.toLowerCase) => + throw new IllegalArgumentException( + s"Illegal value '$pos' for option '$STARTING_OFFSET_OPTION_KEY', " + + s"acceptable values are: ${STARTING_OFFSET_OPTION_VALUES.mkString(", ")}") + case _ => + } + + // Validate user-specified Kafka options + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + + s"user-specified consumer groups is not used to track offsets.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { + throw new IllegalArgumentException( + s""" + |Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported. + |Instead set the source option '$STARTING_OFFSET_OPTION_KEY' to 'earliest' or 'latest' to + |specify where to start. Structured Streaming manages which offsets are consumed + |internally, rather than relying on the kafkaConsumer to do it. 
This will ensure that no + |data is missed when when new topics/partitions are dynamically subscribed. Note that + |'$STARTING_OFFSET_OPTION_KEY' only applies when a new Streaming query is started, and + |that resuming will always pick up from where the query left off. See the docs for more + |details. + """.stripMargin) + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations " + + "to explicitly deserialize the keys.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are deserialized as byte arrays with ByteArrayDeserializer. 
Use DataFrame " + + "operations to explicitly deserialize the values.") + } + + val otherUnsupportedConfigs = Seq( + ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, // committing correctly requires new APIs in Source + ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG) // interceptors can modify payload, so not safe + + otherUnsupportedConfigs.foreach { c => + if (caseInsensitiveParams.contains(s"kafka.$c")) { + throw new IllegalArgumentException(s"Kafka option '$c' is not supported") + } + } + + if (!caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) { + throw new IllegalArgumentException( + s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified for " + + s"configuring Kafka consumer") + } + } + + override def shortName(): String = "kafka" + + /** Class to conveniently update Kafka config params, while logging the changes */ + private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.get(key).getOrElse("")}") + this + } + + def setIfUnset(key: String, value: Object): ConfigUpdater = { + if (!map.containsKey(key)) { + map.put(key, value) + logInfo(s"$module: Set $key to $value") + } + this + } + + def build(): ju.Map[String, Object] = map + } +} + +private[kafka010] object KafkaSourceProvider { + private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern") + private val STARTING_OFFSET_OPTION_KEY = "startingoffset" + private val STARTING_OFFSET_OPTION_VALUES = Set("earliest", "latest") + private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala new file mode 100644 index 
0000000000000..496af7e39abab --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** Offset range that one partition of the KafkaSourceRDD has to read */ +private[kafka010] case class KafkaSourceRDDOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + def topic: String = topicPartition.topic + def partition: Int = topicPartition.partition + def size: Long = untilOffset - fromOffset +} + + +/** Partition of the KafkaSourceRDD */ +private[kafka010] case class KafkaSourceRDDPartition( + index: Int, offsetRange: KafkaSourceRDDOffsetRange) extends Partition + + +/** + * An RDD that reads data from Kafka 
based on offset ranges across multiple partitions. + * Additionally, it allows preferred locations to be set for each topic + partition, so that + * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition + * and cached KafkaConsumers (see [[CachedKafkaConsumer]]) can be used to read data efficiently. + * + * @param sc the [[SparkContext]] + * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors + * @param offsetRanges Offset ranges that define the Kafka data belonging to this RDD + */ +private[kafka010] class KafkaSourceRDD( + sc: SparkContext, + executorKafkaParams: ju.Map[String, Object], + offsetRanges: Seq[KafkaSourceRDDOffsetRange], + pollTimeoutMs: Long) + extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) { + + override def persist(newLevel: StorageLevel): this.type = { + logError("Kafka ConsumerRecord is not serializable. " + + "Use .map to extract fields before calling .persist or .window") + super.persist(newLevel) + } + + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray + } + + override def count(): Long = offsetRanges.map(_.size).sum + + override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val nonEmptyPartitions = + this.partitions.map(_.asInstanceOf[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0) + + if (num < 1 || nonEmptyPartitions.isEmpty) { + return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain
> 0) { + val taken = Math.min(remain, part.offsetRange.size) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ) + res.foreach(buf ++= _) + buf.toArray + } + + override def compute( + thePart: Partition, + context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val range = thePart.asInstanceOf[KafkaSourceRDDPartition].offsetRange + assert( + range.fromOffset <= range.untilOffset, + s"Beginning offset ${range.fromOffset} is after the ending offset ${range.untilOffset} " + + s"for topic ${range.topic} partition ${range.partition}. " + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged") + if (range.fromOffset == range.untilOffset) { + logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + + s"skipping ${range.topic} ${range.partition}") + Iterator.empty + + } else { + + val consumer = CachedKafkaConsumer.getOrCreate( + range.topic, range.partition, executorKafkaParams) + var requestOffset = range.fromOffset + + logDebug(s"Creating iterator for $range") + + new Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { + override def hasNext(): Boolean = requestOffset < range.untilOffset + override def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(hasNext(), "Can't call next() once untilOffset has been reached") + val r = consumer.get(requestOffset, pollTimeoutMs) + requestOffset += 1 + r + } + } + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java new file mode 100644 index 0000000000000..596f775c56dbc --- /dev/null +++ 
b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Structured Streaming Data Source for Kafka 0.10 + */ +package org.apache.spark.sql.kafka010; diff --git a/external/kafka-0-10-sql/src/test/resources/log4j.properties b/external/kafka-0-10-sql/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..75e3b53a093f6 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN + diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala new file mode 100644 index 0000000000000..7056a41b1751e --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.sql.streaming.OffsetSuite + +class KafkaSourceOffsetSuite extends OffsetSuite { + + compare( + one = KafkaSourceOffset(("t", 0, 1L)), + two = KafkaSourceOffset(("t", 0, 2L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 0L)), + two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L), ("T", 0, 0L)), + two = KafkaSourceOffset(("t", 0, 2L), ("T", 0, 1L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L)), + two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala new file mode 100644 index 0000000000000..64bf503058027 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import org.apache.kafka.clients.producer.RecordMetadata +import org.scalatest.BeforeAndAfter +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext + + +abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + protected def makeSureGetOffsetCalled = AssertOnQuery { q => + // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, + // we don't know which data should be fetched when `startingOffset` is latest. + q.processAllAvailable() + true + } + + /** + * Add data to Kafka. + * + * `topicAction` can be used to run actions for each topic before inserting data. 
+ */ + case class AddKafkaData(topics: Set[String], data: Int*) + (implicit ensureDataInMultiplePartition: Boolean = false, + concurrent: Boolean = false, + message: String = "", + topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { + + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + if (query.get.isActive) { + // Make sure no Spark job is running when deleting a topic + query.get.processAllAvailable() + } + + val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap + val newTopics = topics.diff(existingTopics.keySet) + for (newTopic <- newTopics) { + topicAction(newTopic, None) + } + for (existingTopicPartitions <- existingTopics) { + topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) + } + + // Read all topics again in case some topics are deleted. + val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active kafka source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } + if (sources.isEmpty) { + throw new Exception( + "Could not find Kafka source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the Kafka source in the StreamExecution logical plan as there" + + "are multiple Kafka sources:\n\t" + sources.mkString("\n\t")) + } + val kafkaSource = sources.head + val topic = topics.toSeq(Random.nextInt(topics.size)) + val sentMetadata = testUtils.sendMessages(topic, data.map { _.toString }.toArray) + + def metadataToStr(m: (String, RecordMetadata)): String = { + s"Sent ${m._1} to partition ${m._2.partition()}, offset ${m._2.offset()}" + } + // Verify that the test data gets inserted into multiple partitions + if (ensureDataInMultiplePartition) { + require( + 
sentMetadata.groupBy(_._2.partition).size > 1, + s"Added data does not test multiple partitions: ${sentMetadata.map(metadataToStr)}") + } + + val offset = KafkaSourceOffset(testUtils.getLatestOffsets(topics)) + logInfo(s"Added data, expected offset $offset") + (kafkaSource, offset) + } + + override def toString: String = + s"AddKafkaData(topics = $topics, data = $data, message = $message)" + } +} + + +class KafkaSourceSuite extends KafkaSourceTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + test("cannot stop Kafka stream") { + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 5) + testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"topic-.*") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + StopStream + ) + } + + test("subscribing topic by name from latest offsets") { + val topic = newTopic() + testFromLatestOffsets(topic, "subscribe" -> topic) + } + + test("subscribing topic by name from earliest offsets") { + val topic = newTopic() + testFromEarliestOffsets(topic, "subscribe" -> topic) + } + + test("subscribing topic by pattern from latest offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromLatestOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern from earliest offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromEarliestOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix 
+ "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("bad source options") { + def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { + val ex = intercept[IllegalArgumentException] { + val reader = spark + .readStream + .format("kafka") + options.foreach { case (k, v) => reader.option(k, v) } + reader.load() + } + expectedMsgs.foreach { m => + assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + } + } + + // No strategy specified + testBadOptions()("options must be specified", "subscribe", "subscribePattern") + + // Multiple strategies specified + testBadOptions("subscribe" -> "t", "subscribePattern" -> "t.*")( + "only one", "options can be specified") + + testBadOptions("subscribe" -> "")("no topics to subscribe") + testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") + } + + test("unsupported kafka configs") { + def testUnsupportedConfig(key: String, value: String = "someValue"): Unit = { + val ex = intercept[IllegalArgumentException] { + val reader = spark + .readStream + .format("kafka") + .option("subscribe", "topic") + .option("kafka.bootstrap.servers", "somehost") + .option(s"$key", value) + 
reader.load() + } + assert(ex.getMessage.toLowerCase.contains("not supported")) + } + + testUnsupportedConfig("kafka.group.id") + testUnsupportedConfig("kafka.auto.offset.reset") + testUnsupportedConfig("kafka.enable.auto.commit") + testUnsupportedConfig("kafka.interceptor.classes") + testUnsupportedConfig("kafka.key.deserializer") + testUnsupportedConfig("kafka.value.deserializer") + + testUnsupportedConfig("kafka.auto.offset.reset", "none") + testUnsupportedConfig("kafka.auto.offset.reset", "someValue") + testUnsupportedConfig("kafka.auto.offset.reset", "earliest") + testUnsupportedConfig("kafka.auto.offset.reset", "latest") + } + + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + + private def testFromLatestOffsets(topic: String, options: (String, String)*): Unit = { + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("startingOffset", s"latest") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + StopStream, + StartStream(), + CheckAnswer(2, 3, 4), // Should get the data back on recovery + StopStream, + AddKafkaData(Set(topic), 4, 5, 6), // Add data when stream is stopped + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7), // Should get the added data + AddKafkaData(Set(topic), 7, 8), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), + AssertOnQuery("Add partitions") { query: StreamExecution => + testUtils.addPartitions(topic, 10) + true + }, + AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), + 
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + ) + } + + private def testFromEarliestOffsets(topic: String, options: (String, String)*): Unit = { + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark.readStream + reader + .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) + .option("startingOffset", s"earliest") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + AddKafkaData(Set(topic), 4, 5, 6), // Add data when stream is stopped + CheckAnswer(2, 3, 4, 5, 6, 7), + StopStream, + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7), + StopStream, + AddKafkaData(Set(topic), 7, 8), + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), + AssertOnQuery("Add partitions") { query: StreamExecution => + testUtils.addPartitions(topic, 10) + true + }, + AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + ) + } +} + + +class KafkaSourceStressSuite extends KafkaSourceTest with BeforeAndAfter { + + import testImplicits._ + + val topicId = new AtomicInteger(1) + + @volatile var topics: Seq[String] = (1 to 5).map(_ => newStressTopic) + + def newStressTopic: String = s"stress${topicId.getAndIncrement()}" + + private def nextInt(start: Int, end: Int): Int = { + start + Random.nextInt(start + end - 1) + } + + after { + for (topic <- testUtils.getAllTopicsAndPartitionSize().toMap.keys) { + testUtils.deleteTopic(topic) + } + } + + test("stress test with multiple topics and partitions") { + topics.foreach { topic => + 
testUtils.createTopic(topic, partitions = nextInt(1, 6)) + testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) + } + + // Create Kafka source that reads from latest offset + val kafka = + spark.readStream + .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "stress.*") + .option("failOnDataLoss", "false") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + + runStressTest( + mapped, + Seq(makeSureGetOffsetCalled), + (d, running) => { + Random.nextInt(5) match { + case 0 => // Add a new topic + topics = topics ++ Seq(newStressTopic) + AddKafkaData(topics.toSet, d: _*)(message = s"Add topic $newStressTopic", + topicAction = (topic, partition) => { + if (partition.isEmpty) { + testUtils.createTopic(topic, partitions = nextInt(1, 6)) + } + }) + case 1 if running => + // Only delete a topic when the query is running. Otherwise, we may lose data and + // cannot check the correctness. 
+ val deletedTopic = topics(Random.nextInt(topics.size)) + if (deletedTopic != topics.head) { + topics = topics.filterNot(_ == deletedTopic) + } + AddKafkaData(topics.toSet, d: _*)(message = s"Delete topic $deletedTopic", + topicAction = (topic, partition) => { + // Never remove the first topic to make sure we have at least one topic + if (topic == deletedTopic && deletedTopic != topics.head) { + testUtils.deleteTopic(deletedTopic) + } + }) + case 2 => // Add new partitions + AddKafkaData(topics.toSet, d: _*)(message = "Add partitiosn", + topicAction = (topic, partition) => { + testUtils.addPartitions(topic, partition.get + nextInt(1, 6)) + }) + case _ => // Just add new data + AddKafkaData(topics.toSet, d: _*) + } + }, + iterations = 50) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala new file mode 100644 index 0000000000000..3eb8a737ba4c8 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.language.postfixOps +import scala.util.Random + +import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.common.TopicAndPartition +import kafka.server.{KafkaConfig, KafkaServer, OffsetCheckpoint} +import kafka.utils.ZkUtils +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. 
+ */ +class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 60000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkUtils: ZkUtils = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 0 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkUtils = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkUtils).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkUtils = ZkUtils(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, false) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) + 
server = new KafkaServer(brokerConf) + server.startup() + brokerPort = server.boundPort() + (server, brokerPort) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1) + // wait until metadata is propagated + (0 until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + def getAllTopicsAndPartitionSize(): Seq[(String, Int)] = { + zkUtils.getPartitionsForTopics(zkUtils.getAllTopics()).mapValues(_.size).toSeq + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + createTopic(topic, 1) + } + + /** Delete a Kafka topic and wait until it is propagated to the whole cluster */ + def deleteTopic(topic: String): Unit = { + val partitions = zkUtils.getPartitionsForTopics(Seq(topic))(topic).size + AdminUtils.deleteTopic(zkUtils, topic) + verifyTopicDeletion(zkUtils, topic, partitions, List(this.server)) + } + + /** Add new partitions to a Kafka topic */ + def addPartitions(topic: String, partitions: Int): Unit = { + AdminUtils.addPartitions(zkUtils, topic, partitions) + // wait until metadata is propagated + (0 
until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Seq[(String, RecordMetadata)] = { + producer = new KafkaProducer[String, String](producerConfiguration) + val offsets = try { + messages.map { m => + val metadata = + producer.send(new ProducerRecord[String, String](topic, m)).get(10, TimeUnit.SECONDS) + logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}") + (m, metadata) + } + } finally { + if (producer != null) { + producer.close() + producer = null + } + } + offsets + } + + def getLatestOffsets(topics: Set[String]): Map[TopicPartition, Long] = { + val kc = new KafkaConsumer[String, String](consumerConfiguration) + logInfo("Created consumer to get latest offsets") + kc.subscribe(topics.asJavaCollection) + kc.poll(0) + val partitions = kc.assignment() + kc.pause(partitions) + kc.seekToEnd(partitions) + val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap + kc.close() + logInfo("Closed consumer to get latest offsets") + offsets + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("advertised.host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + 
props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props.put("delete.topic.enable", "true") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("value.serializer", classOf[StringSerializer].getName) + props.put("key.serializer", classOf[StringSerializer].getName) + // wait for all in-sync replicas to ack sends + props.put("acks", "all") + props + } + + private def consumerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("group.id", "group-KafkaTestUtils-" + Random.nextInt) + props.put("value.deserializer", classOf[StringDeserializer].getName) + props.put("key.deserializer", classOf[StringDeserializer].getName) + props.put("enable.auto.commit", "false") + props + } + + private def verifyTopicDeletion( + zkUtils: ZkUtils, + topic: String, + numPartitions: Int, + servers: Seq[KafkaServer]) { + import ZkUtils._ + val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + def isDeleted(): Boolean = { + // wait until admin path for delete topic is deleted, signaling completion of topic deletion + val deletePath = !zkUtils.pathExists(getDeleteTopicPath(topic)) + val topicPath = !zkUtils.pathExists(getTopicPath(topic)) + // ensure that the topic-partition has been deleted from all brokers' replica managers + val replicaManager = servers.forall(server => topicAndPartitions.forall(tp => + server.replicaManager.getPartition(tp.topic, tp.partition) == None)) + // ensure that logs from all replicas are deleted if delete topic is marked successful + val logManager = servers.forall(server => topicAndPartitions.forall(tp => + server.getLogManager().getLog(tp).isEmpty)) + // ensure that topic is removed from all cleaner offsets + val cleaner = servers.forall(server => topicAndPartitions.forall { tp => + val checkpoints = 
server.getLogManager().logDirs.map { logDir => + new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + } + checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) + }) + deletePath && topicPath && replicaManager && logManager && cleaner + } + eventually(timeout(10.seconds)) { + assert(isDeleted, s"$topic not deleted after timeout") + } + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + zkUtils.getLeaderForPartition(topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + eventually(timeout(10.seconds)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/pom.xml b/pom.xml index 8408f4b1fa5ed..37976b0359ad4 100644 --- a/pom.xml +++ b/pom.xml @@ -111,6 +111,7 @@ external/kafka-0-8-assembly external/kafka-0-10 external/kafka-0-10-assembly + external/kafka-0-10-sql diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 8e47e7f13d367..88d5dc9b02dd9 100644 --- a/project/SparkBuild.scala +++ 
b/project/SparkBuild.scala @@ -39,8 +39,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq( @@ -353,7 +353,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags + unsafe, tags, sqlKafka010 ).contains(x) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 9825f19b86a55..b3a0d6ad0bd4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -116,7 +116,7 @@ class StreamExecution( * [[HDFSMetadataLog]]. See SPARK-14131 for more details. */ val microBatchThread = - new UninterruptibleThread(s"stream execution thread for $name") { + new StreamExecutionThread(s"stream execution thread for $name") { override def run(): Unit = { // To fix call site like "run at :0", we bridge the call site from the caller // thread to this micro batch thread @@ -530,3 +530,9 @@ object StreamExecution { def nextId: Long = _nextId.getAndIncrement() } + +/** + * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread + * and will use `classOf[StreamExecutionThread]` to check. 
+ */ +abstract class StreamExecutionThread(name: String) extends UninterruptibleThread(name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index aa6515bc7a909..09140a1d6e76b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -50,11 +50,11 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * * {{{ * val inputData = MemoryStream[Int] - val mapped = inputData.toDS().map(_ + 1) - - testStream(mapped)( - AddData(inputData, 1, 2, 3), - CheckAnswer(2, 3, 4)) + * val mapped = inputData.toDS().map(_ + 1) + * + * testStream(mapped)( + * AddData(inputData, 1, 2, 3), + * CheckAnswer(2, 3, 4)) * }}} * * Note that while we do sleep to allow the other thread to progress without spinning, @@ -477,21 +477,41 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. + * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result + * @param addData an add data action that adds the given numbers to the stream, encoding them + * as needed + * @param iterations the iteration number + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + runStressTest(ds, Seq.empty, (data, running) => addData(data), iterations) + } + /** * Creates a stress test that randomly starts/stops/adds data/checks the result. * - * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. 
- * @param addData and add data action that adds the given numbers to the stream, encoding them + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result + * @param prepareActions actions need to run before starting the stress test. + * @param addData an add data action that adds the given numbers to the stream, encoding them * as needed + * @param iterations the iteration number */ def runStressTest( ds: Dataset[Int], - addData: Seq[Int] => StreamAction, - iterations: Int = 100): Unit = { + prepareActions: Seq[StreamAction], + addData: (Seq[Int], Boolean) => StreamAction, + iterations: Int): Unit = { implicit val intEncoder = ExpressionEncoder[Int]() var dataPos = 0 var running = true val actions = new ArrayBuffer[StreamAction]() + actions ++= prepareActions def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } @@ -499,7 +519,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val numItems = Random.nextInt(10) val data = dataPos until (dataPos + numItems) dataPos += numItems - actions += addData(data) + actions += addData(data, running) } (1 to iterations).foreach { i => From b678e465afa417780b54db0fbbaa311621311f15 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Oct 2016 18:11:31 -0700 Subject: [PATCH 90/96] [SPARK-17346][SQL][TEST-MAVEN] Generate the sql test jar to fix the maven build ## What changes were proposed in this pull request? Generate the sql test jar to fix the maven build ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15368 from zsxwing/sql-test-jar. 
--- external/kafka-0-10-sql/pom.xml | 14 ++++++++++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 1 + sql/core/pom.xml | 27 +++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index b96445a11f858..ebff5fd07a9b9 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -41,6 +41,20 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 64bf503058027..6c03070398fca 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -151,6 +151,7 @@ class KafkaSourceSuite extends KafkaSourceTest { val mapped = kafka.map(kv => kv._2.toInt + 1) testStream(mapped)( + makeSureGetOffsetCalled, StopStream ) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 84de1d4a6e2d1..7da77158ff07e 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -132,6 +132,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + + + org.codehaus.mojo build-helper-maven-plugin From 7aeb20be7e999523784aca7be1a7c9c99dec125e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 5 Oct 2016 23:03:09 -0700 Subject: [PATCH 91/96] [MINOR][ML] Avoid 2D array flatten in NB training. ## What changes were proposed in this pull request? 
Avoid 2D array flatten in ```NaiveBayes``` training, since flatten method might be expensive (It will create another array and copy data there). ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15359 from yanboliang/nb-theta. --- .../org/apache/spark/ml/classification/NaiveBayes.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 6775745167b08..e565a6fd3ece2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum - val piArray = Array.fill[Double](numLabels)(0.0) - val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0) + val piArray = new Array[Double](numLabels) + val thetaArray = new Array[Double](numLabels * numFeatures) val lambda = $(smoothing) val piLogDenom = math.log(numDocuments + numLabels * lambda) @@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") ( } var j = 0 while (j < numFeatures) { - thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom j += 1 } i += 1 } val pi = Vectors.dense(piArray) - val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true) + val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) new NaiveBayesModel(uid, pi, theta) } From 5e9f32dd87e58e909a579eaa310e67d31c3b6573 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 6 Oct 2016 09:58:58 +0100 Subject: [PATCH 92/96] [BUILD] Closing some stale PRs ## What changes were proposed in this pull request? 
This PR proposes to close some stale PRs and ones suggested to be closed by committer(s) or obviously inappropriate PRs (e.g. branch to branch). Closes #13458 Closes #15278 Closes #15294 Closes #15339 Closes #15283 ## How was this patch tested? N/A Author: hyukjinkwon Closes #15356 from HyukjinKwon/closing-prs. From 92b7e5728025b1bb6ed3aab5f1753c946a73568c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 6 Oct 2016 09:42:30 -0700 Subject: [PATCH 93/96] [SPARK-17750][SQL] Fix CREATE VIEW with INTERVAL arithmetic. ## What changes were proposed in this pull request? Currently, Spark raises `RuntimeException` when creating a view that applies INTERVAL arithmetic to a timestamp, as in the example below. The root cause is that the arithmetic expression `TimeAdd` was transformed into a `timeadd` function call in the VIEW definition. This PR fixes the SQL definition of `TimeAdd` and `TimeSub` expressions. ```scala scala> sql("CREATE TABLE dates (ts TIMESTAMP)") scala> sql("CREATE VIEW view1 AS SELECT ts + INTERVAL 1 DAY FROM dates") java.lang.RuntimeException: Failed to analyze the canonicalized SQL: ... ``` ## How was this patch tested? Passes Jenkins with a new test case. Author: Dongjoon Hyun Closes #15318 from dongjoon-hyun/SPARK-17750.
--- .../expressions/datetimeExpressions.scala | 2 ++ .../resources/sqlgen/interval_arithmetic.sql | 8 ++++++++ .../catalyst/ExpressionSQLBuilderSuite.scala | 18 +++++++++++++++++- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 16 ++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 04c17bdaf2989..7ab68a13e09cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -682,6 +682,7 @@ case class TimeAdd(start: Expression, interval: Expression) override def right: Expression = interval override def toString: String = s"$left + $right" + override def sql: String = s"${left.sql} + ${right.sql}" override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType @@ -762,6 +763,7 @@ case class TimeSub(start: Expression, interval: Expression) override def right: Expression = interval override def toString: String = s"$left - $right" + override def sql: String = s"${left.sql} - ${right.sql}" override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType diff --git a/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql b/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql new file mode 100644 index 0000000000000..31d00348769f5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select ts + interval 1 day, ts + interval 2 days, + ts - interval 1 day, ts - interval 2 days, + ts + interval '1' day, ts + interval '2' days, + ts - interval '1' day, ts - interval '2' days +from dates +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CAST(ts + interval 1 days AS TIMESTAMP)`, `gen_attr_2` AS `CAST(ts + interval 2 days AS TIMESTAMP)`, `gen_attr_3` AS `CAST(ts - interval 1 days AS TIMESTAMP)`, `gen_attr_4` AS `CAST(ts - interval 2 days AS TIMESTAMP)`, `gen_attr_5` AS `CAST(ts + interval 1 days AS TIMESTAMP)`, `gen_attr_6` AS `CAST(ts + interval 2 days AS TIMESTAMP)`, `gen_attr_7` AS `CAST(ts - interval 1 days AS TIMESTAMP)`, `gen_attr_8` AS `CAST(ts - interval 2 days AS TIMESTAMP)` FROM (SELECT CAST(`gen_attr_1` + interval 1 days AS TIMESTAMP) AS `gen_attr_0`, CAST(`gen_attr_1` + interval 2 days AS TIMESTAMP) AS `gen_attr_2`, CAST(`gen_attr_1` - interval 1 days AS TIMESTAMP) AS `gen_attr_3`, CAST(`gen_attr_1` - interval 2 days AS TIMESTAMP) AS `gen_attr_4`, CAST(`gen_attr_1` + interval 1 days AS TIMESTAMP) AS `gen_attr_5`, CAST(`gen_attr_1` + interval 2 days AS TIMESTAMP) AS `gen_attr_6`, CAST(`gen_attr_1` - interval 1 days AS TIMESTAMP) AS `gen_attr_7`, CAST(`gen_attr_1` - interval 2 days AS TIMESTAMP) AS `gen_attr_8` FROM (SELECT `ts` AS `gen_attr_1` FROM `default`.`dates`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index ce5efe853ca4f..149ce1e195111 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp import org.apache.spark.sql.catalyst.dsl.expressions._ -import 
org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, TimeAdd, + TimeSub, WindowSpecDefinition} +import org.apache.spark.unsafe.types.CalendarInterval class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { @@ -119,4 +121,18 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" ) } + + test("interval arithmetic") { + val interval = Literal(new CalendarInterval(0, CalendarInterval.MICROS_PER_DAY)) + + checkSQL( + TimeAdd('a, interval), + "`a` + interval 1 days" + ) + + checkSQL( + TimeSub('a, interval), + "`a` - interval 1 days" + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 7fa5c29dc5b8f..9ac1e86fc82cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1145,4 +1145,20 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { """.stripMargin, "inline_tables") } + + test("SPARK-17750 - interval arithmetic") { + withTable("dates") { + sql("create table dates (ts timestamp)") + checkSQL( + """ + |select ts + interval 1 day, ts + interval 2 days, + | ts - interval 1 day, ts - interval 2 days, + | ts + interval '1' day, ts + interval '2' days, + | ts - interval '1' day, ts - interval '2' days + |from dates + """.stripMargin, + "interval_arithmetic" + ) + } + } } From 79accf45ace5549caa0cbab02f94fc87bedb5587 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Oct 2016 10:33:45 -0700 Subject: [PATCH 94/96] [SPARK-17798][SQL] Remove redundant Experimental annotations in sql.streaming ## What changes were proposed in this pull request? 
I was looking through API annotations to catch mislabeled APIs, and realized DataStreamReader and DataStreamWriter classes are already annotated as Experimental, and as a result there is no need to annotate each method within them. ## How was this patch tested? N/A Author: Reynold Xin Closes #15373 from rxin/SPARK-17798. --- .../sql/streaming/DataStreamReader.scala | 28 ------------------ .../sql/streaming/DataStreamWriter.scala | 29 ------------------- .../streaming/StreamingQueryListener.scala | 4 +-- 3 files changed, 1 insertion(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index d437c16a25b01..864a9cd3eb89d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -35,89 +35,73 @@ import org.apache.spark.sql.types.StructType @Experimental final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** - * :: Experimental :: * Specifies the input data source format. * * @since 2.0.0 */ - @Experimental def format(source: String): DataStreamReader = { this.source = source this } /** - * :: Experimental :: * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema * automatically from data. By specifying the schema here, the underlying data source can * skip the schema inference step, and thus speed up data loading. * * @since 2.0.0 */ - @Experimental def schema(schema: StructType): DataStreamReader = { this.userSpecifiedSchema = Option(schema) this } /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: String): DataStreamReader = { this.extraOptions += (key -> value) this } /** - * :: Experimental :: * Adds an input option for the underlying data source. 
* * @since 2.0.0 */ - @Experimental def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Long): DataStreamReader = option(key, value.toString) /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Double): DataStreamReader = option(key, value.toString) /** - * :: Experimental :: * (Scala-specific) Adds input options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: scala.collection.Map[String, String]): DataStreamReader = { this.extraOptions ++= options this } /** - * :: Experimental :: * Adds input options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: java.util.Map[String, String]): DataStreamReader = { this.options(options.asScala) this @@ -125,13 +109,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** - * :: Experimental :: * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path * (e.g. external key-value stores). * * @since 2.0.0 */ - @Experimental def load(): DataFrame = { val dataSource = DataSource( @@ -143,18 +125,15 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * :: Experimental :: * Loads input in as a [[DataFrame]], for data streams that read from some path. * * @since 2.0.0 */ - @Experimental def load(path: String): DataFrame = { option("path", path).load() } /** - * :: Experimental :: * Loads a JSON file stream (one object per line) and returns the result as a [[DataFrame]]. * * This function goes through the input once to determine the input schema. 
If you know the @@ -198,11 +177,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def json(path: String): DataFrame = format("json").load(path) /** - * :: Experimental :: * Loads a CSV file stream and returns the result as a [[DataFrame]]. * * This function will go through the input once to determine the input schema if `inferSchema` @@ -262,11 +239,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def csv(path: String): DataFrame = format("csv").load(path) /** - * :: Experimental :: * Loads a Parquet file stream, returning the result as a [[DataFrame]]. * * You can set the following Parquet-specific option(s) for reading Parquet files: @@ -281,13 +256,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def parquet(path: String): DataFrame = { format("parquet").load(path) } /** - * :: Experimental :: * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * @@ -308,7 +281,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def text(path: String): DataFrame = format("text").load(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f70c7d08a691c..b959444b49298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -37,7 +37,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() /** - * :: Experimental :: * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. 
* - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be * written to the sink @@ -46,15 +45,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { this.outputMode = outputMode this } - /** - * :: Experimental :: * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. * - `append`: only the new rows in the streaming DataFrame/Dataset will be written to * the sink @@ -63,7 +59,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def outputMode(outputMode: String): DataStreamWriter[T] = { this.outputMode = outputMode.toLowerCase match { case "append" => @@ -78,7 +73,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * :: Experimental :: * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run * the query as fast as possible. * @@ -100,7 +94,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def trigger(trigger: Trigger): DataStreamWriter[T] = { this.trigger = trigger this @@ -108,25 +101,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** - * :: Experimental :: * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. * This name must be unique among all the currently active queries in the associated SQLContext. * * @since 2.0.0 */ - @Experimental def queryName(queryName: String): DataStreamWriter[T] = { this.extraOptions += ("queryName" -> queryName) this } /** - * :: Experimental :: * Specifies the underlying output data source. Built-in options include "parquet" for now. 
* * @since 2.0.0 */ - @Experimental def format(source: String): DataStreamWriter[T] = { this.source = source this @@ -156,90 +145,74 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: String): DataStreamWriter[T] = { this.extraOptions += (key -> value) this } /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) /** - * :: Experimental :: * (Scala-specific) Adds output options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { this.extraOptions ++= options this } /** - * :: Experimental :: * Adds output options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { this.options(options.asScala) this } /** - * :: Experimental :: * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with * the stream. 
* * @since 2.0.0 */ - @Experimental def start(path: String): StreamingQuery = { option("path", path).start() } /** - * :: Experimental :: * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with * the stream. * * @since 2.0.0 */ - @Experimental def start(): StreamingQuery = { if (source == "memory") { assertNotPartitioned("memory") @@ -297,7 +270,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * :: Experimental :: * Starts the execution of the streaming query, which will continually send results to the given * [[ForeachWriter]] as as new data arrives. The [[ForeachWriter]] can be used to send the data * generated by the [[DataFrame]]/[[Dataset]] to an external system. @@ -343,7 +315,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { this.source = "foreach" this.foreachWriter = if (writer != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index db606abb8ce43..8a8855d85a4c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -35,7 +35,7 @@ abstract class StreamingQueryListener { /** * Called when a query is started. * @note This is called synchronously with - * [[org.apache.spark.sql.DataStreamWriter `DataStreamWriter.start()`]], + * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], * that is, `onQueryStart` will be called on all listeners before * `DataStreamWriter.start()` returns the corresponding [[StreamingQuery]]. Please * don't block this method as it will block your query. 
@@ -101,8 +101,6 @@ object StreamingQueryListener { * @param queryInfo Information about the status of the query. * @param exception The exception message of the [[StreamingQuery]] if the query was terminated * with an exception. Otherwise, it will be `None`. - * @param stackTrace The stack trace of the exception if the query was terminated with an - * exception. It will be empty if there was no error. * @since 2.0.0 */ @Experimental From 9a48e60e6319d85f2c3be3a3c608dab135e18a73 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 6 Oct 2016 12:51:12 -0700 Subject: [PATCH 95/96] [SPARK-17780][SQL] Report Throwable to user in StreamExecution ## What changes were proposed in this pull request? When using an incompatible source for structured streaming, it may throw NoClassDefFoundError. It's better to just catch Throwable and report it to the user since the streaming thread is dying. ## How was this patch tested? `test("NoClassDefFoundError from an incompatible source")` Author: Shixiong Zhu Closes #15352 from zsxwing/SPARK-17780. 
--- .../execution/streaming/StreamExecution.scala | 7 ++++- .../spark/sql/streaming/StreamSuite.scala | 31 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 3 +- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index b3a0d6ad0bd4c..333239f875bd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -207,13 +207,18 @@ class StreamExecution( }) } catch { case _: InterruptedException if state == TERMINATED => // interrupted by stop() - case NonFatal(e) => + case e: Throwable => streamDeathCause = new StreamingQueryException( this, s"Query $name terminated with exception: ${e.getMessage}", e, Some(committedOffsets.toCompositeOffset(sources))) logError(s"Query $name terminated with error", e) + // Rethrow the fatal errors to allow the user using `Thread.UncaughtExceptionHandler` to + // handle them + if (!NonFatal(e)) { + throw e + } } finally { state = TERMINATED sparkSession.streams.notifyQueryTermination(StreamExecution.this) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 1caafb9d74440..cdbad901dba8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.streaming +import scala.reflect.ClassTag +import scala.util.control.ControlThrowable + import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.ManualClock @@ -236,6 +238,33 @@ class StreamSuite extends StreamTest { } } + testQuietly("fatal errors from a source should be sent to the user") { + for (e <- Seq( + new VirtualMachineError {}, + new ThreadDeath, + new LinkageError, + new ControlThrowable {} + )) { + val source = new Source { + override def getOffset: Option[Offset] = { + throw e + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + throw e + } + + override def schema: StructType = StructType(Array(StructField("value", IntegerType))) + + override def stop(): Unit = {} + } + val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source)) + testStream(df)( + ExpectFailure()(ClassTag(e.getClass)) + ) + } + } + test("output mode API in Scala") { val o1 = OutputMode.Append assert(o1 === InternalOutputModes.Append) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 09140a1d6e76b..fa13d385cce75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -167,7 +167,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Signals that a failure is expected and should not kill the test. 
*/ case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] - override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]" + override def toString(): String = s"ExpectFailure[${causeClass.getName}]" } /** Assert that a body is true */ @@ -322,7 +322,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { new UncaughtExceptionHandler { override def uncaughtException(t: Thread, e: Throwable): Unit = { streamDeathCause = e - testThread.interrupt() } }) From 49d11d49983fbe270f4df4fb1e34b5fbe854c5ec Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Thu, 6 Oct 2016 14:28:49 -0700 Subject: [PATCH 96/96] [SPARK-17803][TESTS] Upgrade docker-client dependency [SPARK-17803: Docker integration tests don't run with "Docker for Mac"](https://issues.apache.org/jira/browse/SPARK-17803) ## What changes were proposed in this pull request? This PR upgrades the [docker-client](https://mvnrepository.com/artifact/com.spotify/docker-client) dependency from [3.6.6](https://mvnrepository.com/artifact/com.spotify/docker-client/3.6.6) to [5.0.2](https://mvnrepository.com/artifact/com.spotify/docker-client/5.0.2) to enable _Docker for Mac_ users to run the `docker-integration-tests` out of the box. The very latest docker-client version is [6.0.0](https://mvnrepository.com/artifact/com.spotify/docker-client/6.0.0) but that has one additional dependency and no usage yet. ## How was this patch tested? The code change was tested on Mac OS X Yosemite with both _Docker Toolbox_ as well as _Docker for Mac_ and on Linux Ubuntu 14.04. ``` $ build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -Phive -Phive-thriftserver -DskipTests clean package $ build/mvn -Pdocker-integration-tests -Pscala-2.11 -pl :spark-docker-integration-tests_2.11 clean compile test ``` Author: Christian Kadner Closes #15378 from ckadner/SPARK-17803_Docker_for_Mac. 
--- .../org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala | 1 + pom.xml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index c36f4d5f95482..609696bc8a2c7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import scala.util.control.NonFatal import com.spotify.docker.client._ +import com.spotify.docker.client.exceptions.ImageNotFoundException import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually diff --git a/pom.xml b/pom.xml index 37976b0359ad4..7d13c51b2a596 100644 --- a/pom.xml +++ b/pom.xml @@ -744,7 +744,7 @@ com.spotify docker-client - 3.6.6 + 5.0.2 test