From cb3cfb4d4d12709eb7b16a9a38619720a5524efe Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 14 Mar 2019 13:41:34 -0700 Subject: [PATCH 01/17] SPARK-25299: add CI infrastructure and SortShuffleWriterBenchmark (#498) * add initial bypass merge sort shuffle writer benchmarks * dd unsafe shuffle writer benchmarks * changes in bypassmergesort benchmarks * cleanup * add circle script * add this branch for testing * fix circle attempt 1 * checkout code * add some caches? * why is it not pull caches... * save as artifact instead of publishing * mkdir * typo * try uploading artifacts again * try print per iteration to avoid circle erroring out on idle * blah (#495) * make a PR comment * actually delete files * run benchmarks on test build branch * oops forgot to enable upload * add sort shuffle writer benchmarks * add stdev * cleanup sort a bit * fix stdev text * fix sort shuffle * initial code for read side * format * use times and sample stdev * add assert for at least one iteration * cleanup shuffle write to use fewer mocks and single base interface * shuffle read works with transport client... needs lots of cleaning * test running in cicle * scalastyle * dont publish results yet * cleanup writer code * get only git message * fix command to get PR number * add SortshuffleWriterBenchmark * writer code * cleanup * fix benchmark script * use ArgumentMatchers * also in shufflewriterbenchmarkbase * scalastyle * add apache license * fix some scale stuff * fix up tests * only copy benchmarks we care about * increase size for reader again * delete two writers and reader for PR * SPARK-25299: Add shuffle reader benchmarks (#506) * Revert "SPARK-25299: Add shuffle reader benchmarks (#506)" This reverts commit 9d46fae9a36c6229a888bb647b1b63f51d83b407. * add -e to bash script * blah * enable upload as a PR comment and prevent running benchmarks on this branch * Revert "enable upload as a PR comment and prevent running benchmarks on this branch" This reverts commit 13703fa476f11955631dd1f2e73be2bf69bbd253. * try machine execution * try uploading benchmarks (#498) * only upload results when merging into the feature branch * lock down machine image * don't write input data to disk * run benchmark test * stop creating file cleanup threads for every block manager * use alphanumeric again * use a new random everytime * close the writers -__________- * delete branch and publish results as comment * close in finally --- .../sort/ShuffleWriterBenchmarkBase.scala | 158 ++++++++++++++++ .../sort/SortShuffleWriterBenchmark.scala | 172 ++++++++++++++++++ dev/run-spark-25299-benchmarks.sh | 88 +++++++++ 3 files changed, 418 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala create mode 100755 dev/run-spark-25299-benchmarks.sh diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala new file mode 100644 index 000000000000..8e6a69fb7080 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{BufferedInputStream, Closeable, File, FileInputStream, FileOutputStream} +import java.util.UUID + +import org.apache.commons.io.FileUtils +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.serializer.{KryoSerializer, Serializer, SerializerManager} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.{BlockManager, DiskBlockManager, TempShuffleBlockId} +import org.apache.spark.util.Utils + +abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { + + protected val DEFAULT_DATA_STRING_SIZE = 5 + + // This is only used in the writer constructors, so it's ok to mock + @Mock(answer = RETURNS_SMART_NULLS) protected var dependency: + ShuffleDependency[String, String, String] = _ + // This is only used in the stop() function, so we can safely mock this without affecting perf + @Mock(answer = RETURNS_SMART_NULLS) protected var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEnv: RpcEnv = _ + @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + + protected val defaultConf: SparkConf = new SparkConf(loadDefaults = false) + protected val serializer: Serializer = new KryoSerializer(defaultConf) + protected val partitioner: HashPartitioner = new HashPartitioner(10) + protected val serializerManager: SerializerManager = + new SerializerManager(serializer, defaultConf) + protected val shuffleMetrics: TaskMetrics = new TaskMetrics + + protected val tempFilesCreated: ArrayBuffer[File] = new ArrayBuffer[File] + protected val filenameToFile: mutable.Map[String, File] = new mutable.HashMap[String, File] + + class TestDiskBlockManager(tempDir: File) extends DiskBlockManager(defaultConf, false) { + override def getFile(filename: String): File = { + if (filenameToFile.contains(filename)) { + filenameToFile(filename) + } else { + val outputFile = File.createTempFile("shuffle", null, tempDir) + filenameToFile(filename) = outputFile + outputFile + } + } + + override def createTempShuffleBlock(): (TempShuffleBlockId, File) = { + var blockId = new TempShuffleBlockId(UUID.randomUUID()) + val file = getFile(blockId) + tempFilesCreated += file + (blockId, file) + } + } + + class TestBlockManager(tempDir: File, memoryManager: MemoryManager) extends BlockManager("0", + rpcEnv, + null, + serializerManager, + defaultConf, + memoryManager, + null, + null, + null, + null, + 1) { + override val diskBlockManager = new TestDiskBlockManager(tempDir) + override val remoteBlockTempFileManager = null + } + + protected var tempDir: File = _ + + protected var blockManager: BlockManager = _ + protected var blockResolver: IndexShuffleBlockResolver = _ + + protected var memoryManager: TestMemoryManager = _ + protected var taskMemoryManager: TaskMemoryManager = _ + + MockitoAnnotations.initMocks(this) + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) + when(dependency.shuffleId).thenReturn(0) + when(taskContext.taskMetrics()).thenReturn(shuffleMetrics) + when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef) + + def setup(): Unit = { + memoryManager = new TestMemoryManager(defaultConf) + memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES) + taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + tempDir = Utils.createTempDir() + blockManager = new TestBlockManager(tempDir, memoryManager) + blockResolver = new IndexShuffleBlockResolver( + defaultConf, + blockManager) + } + + def addBenchmarkCase(benchmark: Benchmark, name: String)(func: Benchmark.Timer => Unit): Unit = { + benchmark.addTimerCase(name) { timer => + setup() + func(timer) + teardown() + } + } + + def teardown(): Unit = { + FileUtils.deleteDirectory(tempDir) + tempFilesCreated.clear() + filenameToFile.clear() + } + + protected class DataIterator (size: Int) + extends Iterator[Product2[String, String]] { + val random = new Random(123) + var count = 0 + override def hasNext: Boolean = { + count < size + } + + override def next(): Product2[String, String] = { + count+=1 + val string = random.alphanumeric.take(5).mkString + (string, string) + } + } + + + def createDataIterator(size: Int): DataIterator = { + new DataIterator(size) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala new file mode 100644 index 000000000000..317cd23279ed --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.Mockito.when + +import org.apache.spark.{Aggregator, SparkEnv} +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.util.Utils + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { + + private val shuffleHandle: BaseShuffleHandle[String, String, String] = + new BaseShuffleHandle( + shuffleId = 0, + numMaps = 1, + dependency = dependency) + + private val MIN_NUM_ITERS = 10 + private val DATA_SIZE_SMALL = 1000 + private val DATA_SIZE_LARGE = + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/4/DEFAULT_DATA_STRING_SIZE + + def getWriter(aggregator: Option[Aggregator[String, String, String]], + sorter: Option[Ordering[String]]): SortShuffleWriter[String, String, String] = { + // we need this since SortShuffleWriter uses SparkEnv to get lots of its private vars + SparkEnv.set(new SparkEnv( + "0", + null, + serializer, + null, + serializerManager, + null, + null, + null, + blockManager, + null, + null, + null, + null, + defaultConf + )) + + if (aggregator.isEmpty && sorter.isEmpty) { + when(dependency.mapSideCombine).thenReturn(false) + } else { + when(dependency.mapSideCombine).thenReturn(false) + when(dependency.aggregator).thenReturn(aggregator) + when(dependency.keyOrdering).thenReturn(sorter) + } + + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + val shuffleWriter = new SortShuffleWriter[String, String, String]( + blockResolver, + shuffleHandle, + 0, + taskContext + ) + shuffleWriter + } + + def writeBenchmarkWithSmallDataset(): Unit = { + val size = DATA_SIZE_SMALL + val benchmark = new Benchmark("SortShuffleWriter without spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + addBenchmarkCase(benchmark, "small dataset without spills") { timer => + val shuffleWriter = getWriter(Option.empty, Option.empty) + val dataIterator = createDataIterator(size) + try { + timer.startTiming() + shuffleWriter.write(dataIterator) + timer.stopTiming() + assert(tempFilesCreated.isEmpty) + } finally { + shuffleWriter.stop(true) + } + } + benchmark.run() + } + + def writeBenchmarkWithSpill(): Unit = { + val size = DATA_SIZE_LARGE + + val benchmark = new Benchmark("SortShuffleWriter with spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase(benchmark, "no map side combine") { timer => + val shuffleWriter = getWriter(Option.empty, Option.empty) + val dataIterator = createDataIterator(size) + try { + timer.startTiming() + shuffleWriter.write(dataIterator) + timer.stopTiming() + assert(tempFilesCreated.length == 7) + } finally { + shuffleWriter.stop(true) + } + } + + def createCombiner(i: String): String = i + def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j + def mergeCombiners(i: String, j: String): String = + if (Ordering.String.compare(i, j) > 0) i else j + val aggregator = + new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners) + addBenchmarkCase(benchmark, "with map side aggregation") { timer => + val shuffleWriter = getWriter(Some(aggregator), Option.empty) + val dataIterator = createDataIterator(size) + try { + timer.startTiming() + shuffleWriter.write(dataIterator) + timer.stopTiming() + assert(tempFilesCreated.length == 7) + } finally { + shuffleWriter.stop(true) + } + } + + val sorter = Ordering.String + addBenchmarkCase(benchmark, "with map side sort") { timer => + val shuffleWriter = getWriter(Option.empty, Some(sorter)) + val dataIterator = createDataIterator(size) + try { + timer.startTiming() + shuffleWriter.write(dataIterator) + timer.stopTiming() + assert(tempFilesCreated.length == 7) + } finally { + shuffleWriter.stop(true) + } + } + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("SortShuffleWriter writer") { + writeBenchmarkWithSmallDataset() + writeBenchmarkWithSpill() + } + } +} diff --git a/dev/run-spark-25299-benchmarks.sh b/dev/run-spark-25299-benchmarks.sh new file mode 100755 index 000000000000..2a0fe2088f21 --- /dev/null +++ b/dev/run-spark-25299-benchmarks.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Script to create a binary distribution for easy deploys of Spark. +# The distribution directory defaults to dist/ but can be overridden below. +# The distribution contains fat (assembly) jars that include the Scala library, +# so it is completely self contained. +# It does not contain source or *.class files. + +set -oue pipefail + + +function usage { + echo "Usage: $(basename $0) [-h] [-u]" + echo "" + echo "Runs the perf tests and optionally uploads the results as a comment to a PR" + echo "" + echo " -h help" + echo " -u Upload the perf results as a comment" + # Exit as error for nesting scripts + exit 1 +} + +UPLOAD=false +while getopts "hu" opt; do + case $opt in + h) + usage + exit 0;; + u) + UPLOAD=true;; + esac +done + +echo "Running SPARK-25299 benchmarks" + +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.SortShuffleWriterBenchmark" + +SPARK_DIR=`pwd` + +mkdir -p /tmp/artifacts +cp $SPARK_DIR/sql/core/benchmarks/SortShuffleWriterBenchmark-results.txt /tmp/artifacts/ + +if [ "$UPLOAD" = false ]; then + exit 0 +fi + +IFS= +RESULTS="" +for benchmark_file in /tmp/artifacts/*.txt; do + RESULTS+=$(cat $benchmark_file) + RESULTS+=$'\n\n' +done + +echo $RESULTS +# Get last git message, filter out empty lines, get the last number of the first line. This is the PR number +PULL_REQUEST_NUM=$(git log -1 --pretty=%B | awk NF | awk '{print $NF}' | head -1 | sed 's/(//g' | sed 's/)//g' | sed 's/#//g') + + +USERNAME=svc-spark-25299 +PASSWORD=$SVC_SPARK_25299_PASSWORD +message='{"body": "```' +message+=$'\n' +message+=$RESULTS +message+=$'\n' +json_message=$(echo $message | awk '{printf "%s\\n", $0}') +json_message+='```", "event":"COMMENT"}' +echo "$json_message" > benchmark_results.json + +echo "Sending benchmark requests to PR $PULL_REQUEST_NUM" +curl -XPOST https://${USERNAME}:${PASSWORD}@api.github.com/repos/palantir/spark/pulls/${PULL_REQUEST_NUM}/reviews -d @benchmark_results.json +rm benchmark_results.json From ef563da97f41aef083e1cc16918d8166480481f6 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Mon, 18 Mar 2019 12:06:51 -0700 Subject: [PATCH 02/17] SPARK-25299: Add rest of shuffle writer benchmarks (#507) --- ...ypassMergeSortShuffleWriterBenchmark.scala | 93 +++++++++++++++++++ .../sort/ShuffleWriterBenchmarkBase.scala | 24 ++++- .../sort/SortShuffleWriterBenchmark.scala | 69 ++++---------- .../sort/UnsafeShuffleWriterBenchmark.scala | 91 ++++++++++++++++++ dev/run-spark-25299-benchmarks.sh | 4 + 5 files changed, 228 insertions(+), 53 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala new file mode 100644 index 000000000000..7f67affe5636 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.util.Utils + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { + + private val shuffleHandle: BypassMergeSortShuffleHandle[String, String] = + new BypassMergeSortShuffleHandle[String, String]( + shuffleId = 0, + numMaps = 1, + dependency) + + private val MIN_NUM_ITERS = 10 + private val DATA_SIZE_SMALL = 1000 + private val DATA_SIZE_LARGE = + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/4/DEFAULT_DATA_STRING_SIZE + + def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.file.transferTo", String.valueOf(transferTo)) + conf.set("spark.shuffle.file.buffer", "32k") + + val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( + blockManager, + blockResolver, + shuffleHandle, + 0, + conf, + taskContext.taskMetrics().shuffleWriteMetrics + ) + + shuffleWriter + } + + def writeBenchmarkWithLargeDataset(): Unit = { + val size = DATA_SIZE_LARGE + val benchmark = new Benchmark( + "BypassMergeSortShuffleWrite with spill", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + + addBenchmarkCase(benchmark, "without transferTo", size, () => getWriter(false)) + addBenchmarkCase(benchmark, "with transferTo", size, () => getWriter(true)) + benchmark.run() + } + + def writeBenchmarkWithSmallDataset(): Unit = { + val size = DATA_SIZE_SMALL + val benchmark = new Benchmark("BypassMergeSortShuffleWrite without spill", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + addBenchmarkCase(benchmark, "small dataset without disk spill", size, () => getWriter(false)) + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("BypassMergeSortShuffleWriter write") { + writeBenchmarkWithSmallDataset() + writeBenchmarkWithLargeDataset() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index 8e6a69fb7080..8cca6c331ff2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{BufferedInputStream, Closeable, File, FileInputStream, FileOutputStream} +import java.io.File import java.util.UUID import org.apache.commons.io.FileUtils @@ -35,7 +35,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.storage.{BlockManager, DiskBlockManager, TempShuffleBlockId} import org.apache.spark.util.Utils @@ -121,10 +121,26 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { blockManager) } - def addBenchmarkCase(benchmark: Benchmark, name: String)(func: Benchmark.Timer => Unit): Unit = { + def addBenchmarkCase( + benchmark: Benchmark, + name: String, + size: Int, + writerSupplier: () => ShuffleWriter[String, String], + numSpillFiles: Option[Int] = Option.empty): Unit = { benchmark.addTimerCase(name) { timer => setup() - func(timer) + val writer = writerSupplier() + val dataIterator = createDataIterator(size) + try { + timer.startTiming() + writer.write(dataIterator) + timer.stopTiming() + if (numSpillFiles.isDefined) { + assert(tempFilesCreated.length == numSpillFiles.get) + } + } finally { + writer.stop(true) + } teardown() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 317cd23279ed..62cc13fa107f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -92,41 +92,26 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { size, minNumIters = MIN_NUM_ITERS, output = output) - addBenchmarkCase(benchmark, "small dataset without spills") { timer => - val shuffleWriter = getWriter(Option.empty, Option.empty) - val dataIterator = createDataIterator(size) - try { - timer.startTiming() - shuffleWriter.write(dataIterator) - timer.stopTiming() - assert(tempFilesCreated.isEmpty) - } finally { - shuffleWriter.stop(true) - } - } + addBenchmarkCase(benchmark, + "small dataset without spills", + size, + () => getWriter(Option.empty, Option.empty), + Some(0)) benchmark.run() } def writeBenchmarkWithSpill(): Unit = { val size = DATA_SIZE_LARGE - val benchmark = new Benchmark("SortShuffleWriter with spills", size, minNumIters = MIN_NUM_ITERS, output = output, outputPerIteration = true) - addBenchmarkCase(benchmark, "no map side combine") { timer => - val shuffleWriter = getWriter(Option.empty, Option.empty) - val dataIterator = createDataIterator(size) - try { - timer.startTiming() - shuffleWriter.write(dataIterator) - timer.stopTiming() - assert(tempFilesCreated.length == 7) - } finally { - shuffleWriter.stop(true) - } - } + addBenchmarkCase(benchmark, + "no map side combine", + size, + () => getWriter(Option.empty, Option.empty), + Some(7)) def createCombiner(i: String): String = i def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j @@ -134,32 +119,18 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { if (Ordering.String.compare(i, j) > 0) i else j val aggregator = new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners) - addBenchmarkCase(benchmark, "with map side aggregation") { timer => - val shuffleWriter = getWriter(Some(aggregator), Option.empty) - val dataIterator = createDataIterator(size) - try { - timer.startTiming() - shuffleWriter.write(dataIterator) - timer.stopTiming() - assert(tempFilesCreated.length == 7) - } finally { - shuffleWriter.stop(true) - } - } + addBenchmarkCase(benchmark, + "with map side aggregation", + size, + () => getWriter(Some(aggregator), Option.empty), + Some(7)) val sorter = Ordering.String - addBenchmarkCase(benchmark, "with map side sort") { timer => - val shuffleWriter = getWriter(Option.empty, Some(sorter)) - val dataIterator = createDataIterator(size) - try { - timer.startTiming() - shuffleWriter.write(dataIterator) - timer.stopTiming() - assert(tempFilesCreated.length == 7) - } finally { - shuffleWriter.stop(true) - } - } + addBenchmarkCase(benchmark, + "with map side sort", + size, + () => getWriter(Option.empty, Some(sorter)), + Some(7)) benchmark.run() } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala new file mode 100644 index 000000000000..ac62b496406f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.util.Utils + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { + + private val shuffleHandle: SerializedShuffleHandle[String, String] = + new SerializedShuffleHandle[String, String](0, 0, this.dependency) + + private val MIN_NUM_ITERS = 10 + private val DATA_SIZE_SMALL = 1000 + private val DATA_SIZE_LARGE = + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/2/DEFAULT_DATA_STRING_SIZE + + def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.file.transferTo", String.valueOf(transferTo)) + + new UnsafeShuffleWriter[String, String]( + blockManager, + blockResolver, + taskMemoryManager, + shuffleHandle, + 0, + taskContext, + conf, + taskContext.taskMetrics().shuffleWriteMetrics + ) + } + + def writeBenchmarkWithSmallDataset(): Unit = { + val size = DATA_SIZE_SMALL + val benchmark = new Benchmark("UnsafeShuffleWriter without spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + addBenchmarkCase(benchmark, + "small dataset without spills", + size, + () => getWriter(false), + Some(1)) // The single temp file is for the temp index file + benchmark.run() + } + + def writeBenchmarkWithSpill(): Unit = { + val size = DATA_SIZE_LARGE + val benchmark = new Benchmark("UnsafeShuffleWriter with spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase(benchmark, "without transferTo", size, () => getWriter(false), Some(7)) + addBenchmarkCase(benchmark, "with transferTo", size, () => getWriter(true), Some(7)) + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("UnsafeShuffleWriter write") { + writeBenchmarkWithSmallDataset() + writeBenchmarkWithSpill() + } + } +} diff --git a/dev/run-spark-25299-benchmarks.sh b/dev/run-spark-25299-benchmarks.sh index 2a0fe2088f21..d11060a50d11 100755 --- a/dev/run-spark-25299-benchmarks.sh +++ b/dev/run-spark-25299-benchmarks.sh @@ -50,12 +50,16 @@ done echo "Running SPARK-25299 benchmarks" +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriterBenchmark" SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.SortShuffleWriterBenchmark" +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.UnsafeShuffleWriterBenchmark" SPARK_DIR=`pwd` mkdir -p /tmp/artifacts +cp $SPARK_DIR/sql/core/benchmarks/BypassMergeSortShuffleWriterBenchmark-results.txt /tmp/artifacts/ cp $SPARK_DIR/sql/core/benchmarks/SortShuffleWriterBenchmark-results.txt /tmp/artifacts/ +cp $SPARK_DIR/sql/core/benchmarks/UnsafeShuffleWriterBenchmark-results.txt /tmp/artifacts/ if [ "$UPLOAD" = false ]; then exit 0 From e37f5ac7c790fc40e7e40c752b887c25591a96c9 Mon Sep 17 00:00:00 2001 From: mccheah Date: Wed, 20 Mar 2019 15:14:04 -0700 Subject: [PATCH 03/17] [SPARK-25299] Introduce the new shuffle writer API (#5) (#520) Introduces the new Shuffle Writer API. Ported from https://github.com/bloomberg/apache-spark-on-k8s/pull/5. --- .../spark/api/shuffle/ShuffleDataIO.java | 31 ++++++++++++++ .../shuffle/ShuffleExecutorComponents.java | 33 +++++++++++++++ .../api/shuffle/ShuffleMapOutputWriter.java | 37 ++++++++++++++++ .../api/shuffle/ShufflePartitionWriter.java | 42 +++++++++++++++++++ .../api/shuffle/ShuffleWriteSupport.java | 37 ++++++++++++++++ 5 files changed, 180 insertions(+) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java new file mode 100644 index 000000000000..4cb40f6dd00b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for launching Shuffle related components + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleDataIO { + ShuffleExecutorComponents executor(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java new file mode 100644 index 000000000000..1edf044225cc --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for building shuffle support for Executors + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleExecutorComponents { + void intitializeExecutor(String appId, String execId); + + ShuffleWriteSupport writes(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java new file mode 100644 index 000000000000..5119e34803a8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for creating and managing shuffle partition writers + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleMapOutputWriter { + ShufflePartitionWriter getNextPartitionWriter() throws IOException; + + void commitAllPartitions() throws IOException; + + void abort(Throwable error) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java new file mode 100644 index 000000000000..c043a6b3a499 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.apache.http.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for giving streams / channels for shuffle writes + * + * @since 3.0.0 + */ +@Experimental +public interface ShufflePartitionWriter { + OutputStream openStream() throws IOException; + + long getLength(); + + default WritableByteChannel openChannel() throws IOException { + return Channels.newChannel(openStream()); + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java new file mode 100644 index 000000000000..5ba5564bb46d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.http.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for deploying a shuffle map output writer + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleWriteSupport { + ShuffleMapOutputWriter createMapOutputWriter( + String appId, + int shuffleId, + int mapId, + int numPartitions) throws IOException; +} From 90297fc040d21320b504a7034cf93fb83d83c973 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 27 Mar 2019 17:57:48 -0700 Subject: [PATCH 04/17] SPARK-25299: Add Reader Benchmarks (#508) --- .../BlockStoreShuffleReaderBenchmark.scala | 439 ++++++++++++++++++ .../sort/ShuffleWriterBenchmarkBase.scala | 2 +- dev/run-spark-25299-benchmarks.sh | 2 + 3 files changed, 442 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala new file mode 100644 index 000000000000..2690f1a515fc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import java.io.{File, FileOutputStream} + +import com.google.common.io.CountingOutputStream +import org.apache.commons.io.FileUtils +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import scala.util.Random + +import org.apache.spark.{Aggregator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.metrics.source.Source +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.util.TransportConf +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockId} +import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { + + // this is only used to retrieve the aggregator/sorters/serializers, + // so it shouldn't affect the performance significantly + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: + ShuffleDependency[String, String, String] = _ + // only used to retrieve info about the maps at the beginning, doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _ + // this is only used when initializing the BlockManager, so doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) private var blockManagerMaster: BlockManagerMaster = _ + // this is only used when initiating the BlockManager, for comms between master and executor + @Mock(answer = RETURNS_SMART_NULLS) private var rpcEnv: RpcEnv = _ + @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + + private var tempDir: File = _ + + private val NUM_MAPS = 5 + private val DEFAULT_DATA_STRING_SIZE = 5 + private val TEST_DATA_SIZE = 10000000 + private val SMALLER_DATA_SIZE = 2000000 + private val MIN_NUM_ITERS = 10 + + private val executorId = "0" + private val localPort = 17000 + private val remotePort = 17002 + + private val defaultConf = new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.spill.compress", "false") + .set("spark.authenticate", "false") + .set("spark.app.id", "test-app") + private val serializer = new KryoSerializer(defaultConf) + private val serializerManager = new SerializerManager(serializer, defaultConf) + private val execBlockManagerId = BlockManagerId(executorId, "localhost", localPort) + private val remoteBlockManagerId = BlockManagerId(executorId, "localhost", remotePort) + private val transportConf = SparkTransportConf.fromSparkConf(defaultConf, "shuffle") + private val securityManager = new org.apache.spark.SecurityManager(defaultConf) + protected val memoryManager = new TestMemoryManager(defaultConf) + + class TestBlockManager(transferService: BlockTransferService, + blockManagerMaster: BlockManagerMaster, + dataFile: File, + fileLength: Long, + offset: Long) extends BlockManager( + executorId, + rpcEnv, + blockManagerMaster, + serializerManager, + defaultConf, + memoryManager, + null, + null, + transferService, + null, + 1) { + blockManagerId = execBlockManagerId + + override def getBlockData(blockId: BlockId): ManagedBuffer = { + new FileSegmentManagedBuffer( + transportConf, + dataFile, + offset, + fileLength + ) + } + } + + private var blockManager : BlockManager = _ + private var externalBlockManager: BlockManager = _ + + def getTestBlockManager( + port: Int, + dataFile: File, + dataFileLength: Long, + offset: Long): TestBlockManager = { + val shuffleClient = new NettyBlockTransferService( + defaultConf, + securityManager, + "localhost", + "localhost", + port, + 1 + ) + new TestBlockManager(shuffleClient, + blockManagerMaster, + dataFile, + dataFileLength, + offset) + } + + def initializeServers(dataFile: File, dataFileLength: Long, readOffset: Long = 0): Unit = { + MockitoAnnotations.initMocks(this) + when(blockManagerMaster.registerBlockManager( + any[BlockManagerId], any[Long], any[Long], any[RpcEndpointRef])).thenReturn(null) + when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef) + blockManager = getTestBlockManager(localPort, dataFile, dataFileLength, readOffset) + blockManager.initialize(defaultConf.getAppId) + externalBlockManager = getTestBlockManager(remotePort, dataFile, dataFileLength, readOffset) + externalBlockManager.initialize(defaultConf.getAppId) + } + + def stopServers(): Unit = { + blockManager.stop() + externalBlockManager.stop() + } + + def setupReader( + dataFile: File, + dataFileLength: Long, + fetchLocal: Boolean, + aggregator: Option[Aggregator[String, String, String]] = None, + sorter: Option[Ordering[String]] = None): BlockStoreShuffleReader[String, String] = { + SparkEnv.set(new SparkEnv( + "0", + null, + serializer, + null, + serializerManager, + mapOutputTracker, + null, + null, + blockManager, + null, + null, + null, + null, + defaultConf + )) + + val shuffleHandle = new BaseShuffleHandle( + shuffleId = 0, + numMaps = NUM_MAPS, + dependency = dependency) + + val taskContext = new TestTaskContext + TaskContext.setTaskContext(taskContext) + + var dataBlockId = execBlockManagerId + if (!fetchLocal) { + dataBlockId = remoteBlockManagerId + } + + when(mapOutputTracker.getMapSizesByExecutorId(0, 0, 1)) + .thenReturn { + val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => + val shuffleBlockId = ShuffleBlockId(0, mapId, 0) + (shuffleBlockId, dataFileLength) + } + Seq((dataBlockId, shuffleBlockIdsAndSizes)).toIterator + } + + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(aggregator) + when(dependency.keyOrdering).thenReturn(sorter) + + new BlockStoreShuffleReader[String, String]( + shuffleHandle, + 0, + 1, + taskContext, + taskContext.taskMetrics().createTempShuffleReadMetrics(), + serializerManager, + blockManager, + mapOutputTracker + ) + } + + def generateDataOnDisk(size: Int, file: File, recordOffset: Int): (Long, Long) = { + // scalastyle:off println + println("Generating test data with num records: " + size) + + val dataOutput = new ManualCloseFileOutputStream(file) + val random = new Random(123) + val serializerInstance = serializer.newInstance() + + var countingOutput = new CountingOutputStream(dataOutput) + var serializedOutput = serializerInstance.serializeStream(countingOutput) + var readOffset = 0L + try { + (1 to size).foreach { i => { + if (i % 1000000 == 0) { + println("Wrote " + i + " test data points") + } + if (i == recordOffset) { + serializedOutput.close() + readOffset = countingOutput.getCount + countingOutput = new CountingOutputStream(dataOutput) + serializedOutput = serializerInstance.serializeStream(countingOutput) + } + val x = random.alphanumeric.take(DEFAULT_DATA_STRING_SIZE).mkString + serializedOutput.writeKey(x) + serializedOutput.writeValue(x) + }} + } finally { + serializedOutput.close() + dataOutput.manualClose() + } + (countingOutput.getCount, readOffset) + // scalastyle:off println + } + + class TestDataFile(file: File, length: Long, offset: Long) { + def getFile(): File = file + def getLength(): Long = length + def getOffset(): Long = offset + } + + def runWithTestDataFile(size: Int, readOffset: Int = 0)(func: TestDataFile => Unit): Unit = { + val tempDataFile = File.createTempFile("test-data", "", tempDir) + val dataFileLengthAndOffset = generateDataOnDisk(size, tempDataFile, readOffset) + initializeServers(tempDataFile, dataFileLengthAndOffset._1, dataFileLengthAndOffset._2) + func(new TestDataFile(tempDataFile, dataFileLengthAndOffset._1, dataFileLengthAndOffset._2)) + tempDataFile.delete() + stopServers() + } + + def addBenchmarkCase( + benchmark: Benchmark, + name: String, + shuffleReaderSupplier: => BlockStoreShuffleReader[String, String], + assertSize: Option[Int] = None): Unit = { + benchmark.addTimerCase(name) { timer => + val reader = shuffleReaderSupplier + timer.startTiming() + val numRead = reader.read().length + timer.stopTiming() + assertSize.foreach(size => assert(numRead == size)) + } + } + + def runLargeDatasetTests(): Unit = { + runWithTestDataFile(TEST_DATA_SIZE) { testDataFile => + val baseBenchmark = + new Benchmark("no aggregation or sorting", + TEST_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase( + baseBenchmark, + "local fetch", + setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = true), + assertSize = Option.apply(TEST_DATA_SIZE * NUM_MAPS)) + addBenchmarkCase( + baseBenchmark, + "remote rpc fetch", + setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = false), + assertSize = Option.apply(TEST_DATA_SIZE * NUM_MAPS)) + baseBenchmark.run() + } + } + + def runSmallDatasetTests(): Unit = { + runWithTestDataFile(SMALLER_DATA_SIZE) { testDataFile => + def createCombiner(i: String): String = i + def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j + def mergeCombiners(i: String, j: String): String = + if (Ordering.String.compare(i, j) > 0) i else j + val aggregator = + new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners) + val aggregationBenchmark = + new Benchmark("with aggregation", + SMALLER_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase( + aggregationBenchmark, + "local fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = true, + aggregator = Some(aggregator))) + addBenchmarkCase( + aggregationBenchmark, + "remote rpc fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = false, + aggregator = Some(aggregator))) + aggregationBenchmark.run() + + + val sortingBenchmark = + new Benchmark("with sorting", + SMALLER_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase( + sortingBenchmark, + "local fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = true, + sorter = Some(Ordering.String)), + assertSize = Option.apply(SMALLER_DATA_SIZE * NUM_MAPS)) + addBenchmarkCase( + sortingBenchmark, + "remote rpc fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = false, + sorter = Some(Ordering.String)), + assertSize = Option.apply(SMALLER_DATA_SIZE * NUM_MAPS)) + sortingBenchmark.run() + } + } + + def runSeekTests(): Unit = { + runWithTestDataFile(SMALLER_DATA_SIZE, readOffset = SMALLER_DATA_SIZE) { testDataFile => + val seekBenchmark = + new Benchmark("with seek", + SMALLER_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output) + + addBenchmarkCase( + seekBenchmark, + "seek to last record", + setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = false), + Option.apply(NUM_MAPS)) + seekBenchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + tempDir = Utils.createTempDir(null, "shuffle") + + runBenchmark("BlockStoreShuffleReader reader") { + runLargeDatasetTests() + runSmallDatasetTests() + runSeekTests() + } + + FileUtils.deleteDirectory(tempDir) + } + + // We cannot mock the TaskContext because it taskMetrics() gets called at every next() + // call on the reader, and Mockito will try to log all calls to taskMetrics(), thus OOM-ing + // the test + class TestTaskContext extends TaskContext { + private val metrics: TaskMetrics = new TaskMetrics + private val testMemManager = new TestMemoryManager(defaultConf) + private val taskMemManager = new TaskMemoryManager(testMemManager, 0) + testMemManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES) + override def isCompleted(): Boolean = false + override def isInterrupted(): Boolean = false + override def addTaskCompletionListener(listener: TaskCompletionListener): + TaskContext = { null } + override def addTaskFailureListener(listener: TaskFailureListener): TaskContext = { null } + override def stageId(): Int = 0 + override def stageAttemptNumber(): Int = 0 + override def partitionId(): Int = 0 + override def attemptNumber(): Int = 0 + override def taskAttemptId(): Long = 0 + override def getLocalProperty(key: String): String = "" + override def taskMetrics(): TaskMetrics = metrics + override def getMetricsSources(sourceName: String): Seq[Source] = Seq.empty + override private[spark] def killTaskIfInterrupted(): Unit = {} + override private[spark] def getKillReason() = None + override private[spark] def taskMemoryManager() = taskMemManager + override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {} + override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {} + override private[spark] def markInterrupted(reason: String): Unit = {} + override private[spark] def markTaskFailed(error: Throwable): Unit = {} + override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {} + override private[spark] def fetchFailed = None + override private[spark] def getLocalProperties = { null } + } + + class ManualCloseFileOutputStream(file: File) extends FileOutputStream(file, true) { + override def close(): Unit = { + flush() + } + + def manualClose(): Unit = { + flush() + super.close() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index 8cca6c331ff2..eceb207219c2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -161,7 +161,7 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { override def next(): Product2[String, String] = { count+=1 - val string = random.alphanumeric.take(5).mkString + val string = random.alphanumeric.take(DEFAULT_DATA_STRING_SIZE).mkString (string, string) } } diff --git a/dev/run-spark-25299-benchmarks.sh b/dev/run-spark-25299-benchmarks.sh index d11060a50d11..2d60f9d5a06e 100755 --- a/dev/run-spark-25299-benchmarks.sh +++ b/dev/run-spark-25299-benchmarks.sh @@ -50,6 +50,7 @@ done echo "Running SPARK-25299 benchmarks" +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BlockStoreShuffleReaderBenchmark" SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriterBenchmark" SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.SortShuffleWriterBenchmark" SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.UnsafeShuffleWriterBenchmark" @@ -57,6 +58,7 @@ SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark. SPARK_DIR=`pwd` mkdir -p /tmp/artifacts +cp $SPARK_DIR/sql/core/benchmarks/BlockStoreShuffleReaderBenchmark-results.txt /tmp/artifacts/ cp $SPARK_DIR/sql/core/benchmarks/BypassMergeSortShuffleWriterBenchmark-results.txt /tmp/artifacts/ cp $SPARK_DIR/sql/core/benchmarks/SortShuffleWriterBenchmark-results.txt /tmp/artifacts/ cp $SPARK_DIR/sql/core/benchmarks/UnsafeShuffleWriterBenchmark-results.txt /tmp/artifacts/ From 4dd59550224d95e773d27d34390163b1dbe59ba2 Mon Sep 17 00:00:00 2001 From: mccheah Date: Wed, 3 Apr 2019 16:34:46 -0700 Subject: [PATCH 05/17] [SPARK-25299] Local shuffle implementation of the shuffle writer API (#524) Implements the shuffle writer API by writing shuffle files to local disk and using the index block resolver to commit data and write index files. The logic in `BypassMergeSortShuffleWriter` has been refactored to use the base implementation of the plugin instead. APIs have been slightly renamed to clarify semantics after considering nuances in how these are to be implemented by other developers. Follow-up commits are to come for `SortShuffleWriter` and `UnsafeShuffleWriter`. Ported from https://github.com/bloomberg/apache-spark-on-k8s/pull/6, credits to @ifilonenko. --- .../shuffle/ShuffleExecutorComponents.java | 2 +- .../api/shuffle/ShufflePartitionWriter.java | 44 +++- .../api/shuffle/ShuffleWriteSupport.java | 1 - .../sort/BypassMergeSortShuffleWriter.java | 146 ++++++----- .../shuffle/sort/io/DefaultShuffleDataIO.java | 36 +++ .../io/DefaultShuffleExecutorComponents.java | 51 ++++ .../io/DefaultShuffleMapOutputWriter.java | 243 ++++++++++++++++++ .../sort/io/DefaultShuffleWriteSupport.java | 47 ++++ .../spark/internal/config/package.scala | 7 + .../shuffle/sort/SortShuffleManager.scala | 21 +- .../scala/org/apache/spark/util/Utils.scala | 30 ++- .../scala/org/apache/spark/ShuffleSuite.scala | 14 +- ...ypassMergeSortShuffleWriterBenchmark.scala | 6 +- .../BypassMergeSortShuffleWriterSuite.scala | 68 ++++- .../sort/SortShuffleWriterBenchmark.scala | 2 +- .../DefaultShuffleMapOutputWriterSuite.scala | 216 ++++++++++++++++ 16 files changed, 838 insertions(+), 96 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java index 1edf044225cc..4fc20bad9938 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -27,7 +27,7 @@ */ @Experimental public interface ShuffleExecutorComponents { - void intitializeExecutor(String appId, String execId); + void initializeExecutor(String appId, String execId); ShuffleWriteSupport writes(); } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java index c043a6b3a499..6a53803e5d11 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java @@ -17,6 +17,7 @@ package org.apache.spark.api.shuffle; +import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; import java.nio.channels.Channels; @@ -26,17 +27,48 @@ /** * :: Experimental :: - * An interface for giving streams / channels for shuffle writes + * An interface for giving streams / channels for shuffle writes. * * @since 3.0.0 */ @Experimental -public interface ShufflePartitionWriter { - OutputStream openStream() throws IOException; +public interface ShufflePartitionWriter extends Closeable { - long getLength(); + /** + * Returns an underlying {@link OutputStream} that can write bytes to the underlying data store. + *

+ * Note that this stream itself is not closed by the caller; close the stream in the + * implementation of this interface's {@link #close()}. + */ + OutputStream toStream() throws IOException; - default WritableByteChannel openChannel() throws IOException { - return Channels.newChannel(openStream()); + /** + * Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data + * store. + *

+ * Note that this channel itself is not closed by the caller; close the channel in the + * implementation of this interface's {@link #close()}. + */ + default WritableByteChannel toChannel() throws IOException { + return Channels.newChannel(toStream()); } + + /** + * Get the number of bytes written by this writer's stream returned by {@link #toStream()} or + * the channel returned by {@link #toChannel()}. + */ + long getNumBytesWritten(); + + /** + * Close all resources created by this ShufflePartitionWriter, via calls to {@link #toStream()} + * or {@link #toChannel()}. + *

+ * This must always close any stream returned by {@link #toStream()}. + *

+ * Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel} + * that does not itself need to be closed up front; only the underlying output stream given by + * {@link #toStream()} must be closed. + */ + @Override + void close() throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java index 5ba5564bb46d..6c69d5db9fd0 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java @@ -30,7 +30,6 @@ @Experimental public interface ShuffleWriteSupport { ShuffleMapOutputWriter createMapOutputWriter( - String appId, int shuffleId, int mapId, int numPartitions) throws IOException; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 32b446785a9f..aef133fe7d46 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -19,8 +19,10 @@ import java.io.File; import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; import scala.None$; @@ -34,6 +36,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; import org.apache.spark.internal.config.package$; import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; @@ -82,6 +87,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int shuffleId; private final int mapId; private final Serializer serializer; + private final ShuffleWriteSupport shuffleWriteSupport; private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ @@ -103,7 +109,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleHandle handle, int mapId, SparkConf conf, - ShuffleWriteMetricsReporter writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleWriteSupport shuffleWriteSupport) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -116,57 +123,61 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; + this.shuffleWriteSupport = shuffleWriteSupport; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); - return; - } - final SerializerInstance serInstance = serializer.newInstance(); - final long openStartTime = System.nanoTime(); - partitionWriters = new DiskBlockObjectWriter[numPartitions]; - partitionWriterSegments = new FileSegment[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - final Tuple2 tempShuffleBlockIdPlusFile = - blockManager.diskBlockManager().createTempShuffleBlock(); - final File file = tempShuffleBlockIdPlusFile._2(); - final BlockId blockId = tempShuffleBlockIdPlusFile._1(); - partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - - while (records.hasNext()) { - final Product2 record = records.next(); - final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); - } + ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport + .createMapOutputWriter(shuffleId, mapId, numPartitions); + try { + if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - for (int i = 0; i < numPartitions; i++) { - try (DiskBlockObjectWriter writer = partitionWriters[i]) { - partitionWriterSegments[i] = writer.commitAndGet(); + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - } - File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - File tmp = Utils.tempFileWith(output); - try { - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + for (int i = 0; i < numPartitions; i++) { + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } } + + partitionLengths = writePartitionedData(mapOutputWriter); + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } catch (Exception e) { + try { + mapOutputWriter.abort(e); + } catch (Exception e2) { + logger.error("Failed to abort the writer after failing to write map output.", e2); + } + throw e; } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting @@ -179,37 +190,54 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] writePartitionedFile(File outputFile) throws IOException { + private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator return lengths; } - - final FileOutputStream out = new FileOutputStream(outputFile, true); final long writeStartTime = System.nanoTime(); - boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); - if (file.exists()) { - final FileInputStream in = new FileInputStream(file); - boolean copyThrewException = true; - try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + boolean copyThrewException = true; + ShufflePartitionWriter writer = null; + try { + writer = mapOutputWriter.getNextPartitionWriter(); + if (!file.exists()) { copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); + } else { + if (transferToEnabled) { + WritableByteChannel outputChannel = writer.toChannel(); + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } else { + OutputStream tempOutputStream = writer.toStream(); + FileInputStream in = new FileInputStream(file); + try { + Utils.copyStream(in, tempOutputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } + } finally { + Closeables.close(writer, copyThrewException); } + + lengths[i] = writer.getNumBytesWritten(); } - threwException = false; } finally { - Closeables.close(out, threwException); writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java new file mode 100644 index 000000000000..906600c0f15f --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleDataIO; + +public class DefaultShuffleDataIO implements ShuffleDataIO { + + private final SparkConf sparkConf; + + public DefaultShuffleDataIO(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public ShuffleExecutorComponents executor() { + return new DefaultShuffleExecutorComponents(sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java new file mode 100644 index 000000000000..76e87a674025 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManager; + +public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + + public DefaultShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public void initializeExecutor(String appId, String execId) { + blockManager = SparkEnv.get().blockManager(); + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + } + + @Override + public ShuffleWriteSupport writes() { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return new DefaultShuffleWriteSupport(sparkConf, blockResolver); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java new file mode 100644 index 000000000000..0f7e5ed66bb7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.util.Utils; + +public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private static final Logger log = + LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class); + + private final int shuffleId; + private final int mapId; + private final ShuffleWriteMetricsReporter metrics; + private final IndexShuffleBlockResolver blockResolver; + private final long[] partitionLengths; + private final int bufferSize; + private int currPartitionId = 0; + private long currChannelPosition; + + private final File outputFile; + private File outputTempFile; + private FileOutputStream outputFileStream; + private FileChannel outputFileChannel; + private TimeTrackingOutputStream ts; + private BufferedOutputStream outputBufferedFileStream; + + public DefaultShuffleMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions, + ShuffleWriteMetricsReporter metrics, + IndexShuffleBlockResolver blockResolver, + SparkConf sparkConf) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.metrics = metrics; + this.blockResolver = blockResolver; + this.bufferSize = + (int) (long) sparkConf.get( + package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.partitionLengths = new long[numPartitions]; + this.outputFile = blockResolver.getDataFile(shuffleId, mapId); + this.outputTempFile = null; + } + + @Override + public ShufflePartitionWriter getNextPartitionWriter() throws IOException { + if (outputTempFile == null) { + outputTempFile = Utils.tempFileWith(outputFile); + } + if (outputFileChannel != null) { + currChannelPosition = outputFileChannel.position(); + } else { + currChannelPosition = 0L; + } + return new DefaultShufflePartitionWriter(currPartitionId++); + } + + @Override + public void commitAllPartitions() throws IOException { + cleanUp(); + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, outputTempFile); + } + + @Override + public void abort(Throwable error) { + try { + cleanUp(); + } catch (Exception e) { + log.error("Unable to close appropriate underlying file stream", e); + } + if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { + log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); + } + } + + private void cleanUp() throws IOException { + if (outputBufferedFileStream != null) { + outputBufferedFileStream.close(); + } + + if (outputFileChannel != null) { + outputFileChannel.close(); + } + + if (outputFileStream != null) { + outputFileStream.close(); + } + } + + private void initStream() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + ts = new TimeTrackingOutputStream(metrics, outputFileStream); + } + if (outputBufferedFileStream == null) { + outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize); + } + } + + private void initChannel() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputFileChannel == null) { + outputFileChannel = outputFileStream.getChannel(); + } + } + + private class DefaultShufflePartitionWriter implements ShufflePartitionWriter { + + private final int partitionId; + private PartitionWriterStream stream = null; + + private DefaultShufflePartitionWriter(int partitionId) { + this.partitionId = partitionId; + } + + @Override + public OutputStream toStream() throws IOException { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + stream = new PartitionWriterStream(); + return stream; + } + + @Override + public FileChannel toChannel() throws IOException { + if (stream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. Should not be using both channels" + + " and streams to write."); + } + initChannel(); + return outputFileChannel; + } + + @Override + public long getNumBytesWritten() { + if (outputFileChannel != null && stream == null) { + try { + long newPosition = outputFileChannel.position(); + return newPosition - currChannelPosition; + } catch (Exception e) { + log.error("The partition which failed is: {}", partitionId, e); + throw new IllegalStateException("Failed to calculate position of file channel", e); + } + } else if (stream != null) { + return stream.getCount(); + } else { + // Assume an empty partition if stream and channel are never created + return 0; + } + } + + @Override + public void close() throws IOException { + if (stream != null) { + stream.close(); + } + partitionLengths[partitionId] = getNumBytesWritten(); + } + } + + private class PartitionWriterStream extends OutputStream { + private int count = 0; + private boolean isClosed = false; + + public int getCount() { + return count; + } + + @Override + public void write(int b) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(b); + count++; + } + + @Override + public void write(byte[] buf, int pos, int length) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(buf, pos, length); + count += length; + } + + @Override + public void close() throws IOException { + flush(); + isClosed = true; + } + + @Override + public void flush() throws IOException { + if (!isClosed) { + outputBufferedFileStream.flush(); + } + } + + private void verifyNotClosed() { + if (isClosed) { + throw new IllegalStateException("Attempting to write to a closed block output stream."); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java new file mode 100644 index 000000000000..f8fadd0ecfa6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; + +public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { + + private final SparkConf sparkConf; + private final IndexShuffleBlockResolver blockResolver; + + public DefaultShuffleWriteSupport( + SparkConf sparkConf, + IndexShuffleBlockResolver blockResolver) { + this.sparkConf = sparkConf; + this.blockResolver = blockResolver; + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions) { + return new DefaultShuffleMapOutputWriter( + shuffleId, mapId, numPartitions, + TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0bd46bef35d2..6e996b4c1936 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -23,6 +23,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.metrics.GarbageCollectionMetrics import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -766,6 +767,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_IO_PLUGIN_CLASS = + ConfigBuilder("spark.shuffle.io.plugin.class") + .doc("Name of the class to use for shuffle IO.") + .stringConf + .createWithDefault(classOf[DefaultShuffleDataIO].getName) + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b59fa8e8a3cc..5da7b5cb35e6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -20,8 +20,10 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import org.apache.spark.util.Utils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -68,6 +70,8 @@ import org.apache.spark.shuffle._ */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + import SortShuffleManager._ + if (!conf.getBoolean("spark.shuffle.spill", true)) { logWarning( "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + @@ -79,6 +83,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -148,7 +154,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager bypassMergeSortHandle, mapId, env.conf, - metrics) + metrics, + shuffleExecutorComponents.writes()) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } @@ -205,6 +212,16 @@ private[spark] object SortShuffleManager extends Logging { true } } + + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") + val executorComponents = maybeIO.head.executor() + executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId) + executorComponents + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7c8648d61bfb..057af76d72d5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.{Channels, FileChannel} +import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.security.SecureRandom @@ -393,10 +393,14 @@ private[spark] object Utils extends Logging { def copyFileStreamNIO( input: FileChannel, - output: FileChannel, + output: WritableByteChannel, startPosition: Long, bytesToCopy: Long): Unit = { - val initialPos = output.position() + val outputInitialState = output match { + case outputFileChannel: FileChannel => + Some((outputFileChannel.position(), outputFileChannel)) + case _ => None + } var count = 0L // In case transferTo method transferred less data than we have required. while (count < bytesToCopy) { @@ -411,15 +415,17 @@ private[spark] object Utils extends Logging { // kernel version 2.6.32, this issue can be seen in // https://bugs.openjdk.java.net/browse/JDK-7052359 // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - val finalPos = output.position() - val expectedPos = initialPos + bytesToCopy - assert(finalPos == expectedPos, - s""" - |Current position $finalPos do not equal to expected position $expectedPos - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + outputInitialState.foreach { case (initialPos, outputFileChannel) => + val finalPos = outputFileChannel.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } } /** diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 8b1084a8edc7..90c790cefcca 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} -import org.apache.spark.util.{MutablePair, Utils} +import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -368,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) val writer1 = manager.getWriter[Int, Int]( shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) - val data1 = (1 to 10).map { x => x -> x} + val data1 = (1 to 10).map { x => x -> x } // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently @@ -383,13 +383,17 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int])( + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( iter: Iterator[(Int, Int)]): Option[MapStatus] = { + TaskContext.setTaskContext(taskContext) val files = writer.write(iter) - writer.stop(true) + val status = writer.stop(true) + TaskContext.unset + status } val interleaver = new InterleaveIterators( - data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) val (mapOutput1, mapOutput2) = interleaver.run() // check that we can read the map output and it has the right data diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 7f67affe5636..7eb867fc29fd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.io.{DefaultShuffleWriteSupport} /** * Benchmark to measure performance for aggregate primitives. @@ -46,6 +46,7 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) + val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") @@ -55,7 +56,8 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase shuffleHandle, 0, conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleWriteSupport ) shuffleWriter diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 72a1a4fb500f..6683858830cd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -18,23 +18,27 @@ package org.apache.spark.shuffle.sort import java.io.File -import java.util.UUID +import java.util.{Properties, UUID} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.ArgumentMatchers.{any, anyInt, anyString} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach +import scala.util.Random import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,7 +53,9 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ + private var writeSupport: ShuffleWriteSupport = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) + .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ @@ -118,9 +124,27 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) } }) + + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + stageAttemptNumber = 0, + partitionId = 0, + taskAttemptId = Random.nextInt(10000), + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + localProperties = new Properties, + metricsSystem = null, + taskMetrics = taskMetrics)) + + writeSupport = new DefaultShuffleWriteSupport(conf, blockResolver) } override def afterEach(): Unit = { + TaskContext.unset() try { Utils.deleteRecursively(tempDir) blockIdToFileMap.clear() @@ -137,7 +161,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -153,6 +178,33 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } test("write with some empty partitions") { + val transferConf = conf.clone.set("spark.file.transferTo", "false") + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + blockResolver, + shuffleHandle, + 0, // MapId + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport + ) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + // TODO(ifilonenko): MAKE THIS PASS + test("write with some empty partitions with transferTo") { def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( @@ -161,7 +213,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) writer.write(records) writer.stop( /* success = */ true) @@ -196,7 +249,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) intercept[SparkException] { @@ -218,7 +272,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { @@ -232,5 +287,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } - } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 62cc13fa107f..ce1abde421fc 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,7 +22,7 @@ import org.mockito.Mockito.when import org.apache.spark.{Aggregator, SparkEnv} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO /** * Benchmark to measure performance for aggregate primitives. diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala new file mode 100644 index 000000000000..22d52924a7c7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} +import java.math.BigInteger +import java.nio.ByteBuffer + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} +import org.mockito.Mock +import org.mockito.Mockito.{doAnswer, doNothing, when} +import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.util.Utils + +class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _ + + private val NUM_PARTITIONS = 4 + private val D_LEN = 10 + private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map { + p => (1 to D_LEN).map(_ + p).toArray }.toArray + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: DefaultShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir(null, "test") + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + + doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong) + + doAnswer(new Answer[Void] { + def answer(invocationOnMock: InvocationOnMock): Void = { + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete + tmp.renameTo(mergedOutputFile) + } + null + } + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) + mapOutputWriter = new DefaultShuffleMapOutputWriter( + 0, 0, NUM_PARTITIONS, shuffleWriteMetrics, blockResolver, conf) + } + + private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { + var startOffset = 0L + val result = new Array[Array[Int]](NUM_PARTITIONS) + (0 until NUM_PARTITIONS).foreach { p => + val partitionSize = partitionSizesInMergedFile(p).toInt + lazy val inner = new Array[Int](partitionSize) + lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize) + if (partitionSize > 0) { + val in = new FileInputStream(mergedOutputFile) + in.getChannel.position(startOffset) + val lin = new LimitedInputStream(in, partitionSize) + var nonEmpty = true + var count = 0 + while (nonEmpty) { + try { + val readBit = lin.read() + if (fromByte) { + innerBytebuffer.put(readBit.toByte) + } else { + inner(count) = readBit + } + count += 1 + } catch { + case _: Exception => + nonEmpty = false + } + } + in.close() + } + if (fromByte) { + result(p) = innerBytebuffer.array().sliding(4, 4).map { b => + new BigInteger(b).intValue() + }.toArray + } else { + result(p) = inner + } + startOffset += partitionSize + } + result + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val stream = writer.toStream() + data(p).foreach { i => stream.write(i)} + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + assert(writer.getNumBytesWritten() == D_LEN) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(false)) + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val channel = writer.toChannel() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + assert(channel.isOpen) + channel.write(byteBuffer) + // Bytes require * 4 + assert(writer.getNumBytesWritten == D_LEN * 4) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } + + test("copyStreams with an outputstream") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val stream = writer.toStream() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val in = new ByteArrayInputStream(byteBuffer.array()) + Utils.copyStream(in, stream, false, false) + in.close() + stream.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } + + test("copyStreamsWithNIO with a channel") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val channel = writer.toChannel() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val out = new FileOutputStream(tempFile) + out.write(byteBuffer.array()) + out.close() + val in = new FileInputStream(tempFile) + Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4) + in.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } +} From 83bcfcffb22b427693fc73c04973f0709136e864 Mon Sep 17 00:00:00 2001 From: mccheah Date: Thu, 4 Apr 2019 19:55:47 -0700 Subject: [PATCH 06/17] Set the task context in writer benchmarks (#529) --- .../spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala | 1 + .../spark/shuffle/sort/SortShuffleWriterBenchmark.scala | 4 ++-- .../spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index eceb207219c2..26b92e5203b5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -111,6 +111,7 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef) def setup(): Unit = { + TaskContext.setTaskContext(taskContext) memoryManager = new TestMemoryManager(defaultConf) memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index ce1abde421fc..110ff1b03d51 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -19,10 +19,9 @@ package org.apache.spark.shuffle.sort import org.mockito.Mockito.when -import org.apache.spark.{Aggregator, SparkEnv} +import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO /** * Benchmark to measure performance for aggregate primitives. @@ -76,6 +75,7 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { } when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + TaskContext.setTaskContext(taskContext) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index ac62b496406f..15a08111f6d5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -16,9 +16,8 @@ */ package org.apache.spark.shuffle.sort -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark -import org.apache.spark.util.Utils /** * Benchmark to measure performance for aggregate primitives. @@ -44,6 +43,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) + TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( blockManager, blockResolver, From 04d30a097beb81bc2037ebdd84d723e92cd738d2 Mon Sep 17 00:00:00 2001 From: mccheah Date: Mon, 15 Apr 2019 11:35:03 -0700 Subject: [PATCH 07/17] [SPARK-25299] Use the shuffle writer plugin for the SortShuffleWriter. (#532) * [SPARK-25299] Use the shuffle writer plugin for the SortShuffleWriter. * Remove unused * Handle empty partitions properly. * Adjust formatting * Don't close streams twice. Because compressed output streams don't like it. * Clarify comment --- .../shuffle/sort/SortShuffleManager.scala | 3 +- .../shuffle/sort/SortShuffleWriter.scala | 23 ++-- .../spark/storage/DiskBlockObjectWriter.scala | 4 +- .../util/collection/ExternalSorter.scala | 130 +++++++++++++++++- .../spark/util/collection/PairsWriter.scala | 23 ++++ .../ShufflePartitionPairsWriter.scala | 105 ++++++++++++++ .../WritablePartitionedPairCollection.scala | 4 +- .../sort/SortShuffleWriterBenchmark.scala | 6 +- 8 files changed, 270 insertions(+), 28 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 5da7b5cb35e6..42a249564cd0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics, shuffleExecutorComponents.writes()) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + new SortShuffleWriter( + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes()) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 16058de8bf3f..62316f384b64 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -18,18 +18,18 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( shuffleBlockResolver: IndexShuffleBlockResolver, handle: BaseShuffleHandle[K, V, C], mapId: Int, - context: TaskContext) + context: TaskContext, + writeSupport: ShuffleWriteSupport) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). - val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) - val tmp = Utils.tempFileWith(output) - try { - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) - } finally { - if (tmp.exists() && !tmp.delete()) { - logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") - } - } + val mapOutputWriter = writeSupport.createMapOutputWriter( + dep.shuffleId, mapId, dep.partitioner.numPartitions) + val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 17390f9c60e7..f9f4e3594e4f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter /** * A class for writing JVM objects directly to a file on disk. This class allows data to be appended @@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter( writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream - with Logging { + with Logging + with PairsWriter { /** * Guards against close calls, e.g. from a wrapping stream. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4806c1396725..8ccc1dfc9b3f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -674,11 +675,9 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Write all the data added into this ExternalSorter into a file in the disk store. This is - * called by the SortShuffleWriter. - * - * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project. + * We should figure out an alternative way to test that so that we can remove this otherwise + * unused code path. */ def writePartitionedFile( blockId: BlockId, @@ -722,6 +721,123 @@ private[spark] class ExternalSorter[K, V, C]( lengths } + private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = { + var partitionWriter: ShufflePartitionWriter = null + try { + partitionWriter = mapOutputWriter.getNextPartitionWriter + } finally { + if (partitionWriter != null) { + partitionWriter.close() + } + } + } + + /** + * Write all the data added into this ExternalSorter into a map output writer that pushes bytes + * to some arbitrary backing store. This is called by the SortShuffleWriter. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedMapOutput( + shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { + // Track location of each range in the map output + val lengths = new Array[Long](numPartitions) + var nextPartitionId = 0 + if (spills.isEmpty) { + // Case where we only have in-memory data + val collection = if (aggregator.isDefined) map else buffer + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) + while (it.hasNext()) { + val partitionId = it.nextPartition() + // The contract for the plugin is that we will ask for a writer for every partition + // even if it's empty. However, the external sorter will return non-contiguous + // partition ids. So this loop "backfills" the empty partitions that form the gaps. + + // The algorithm as a whole is correct because the partition ids are returned by the + // iterator in ascending order. + for (emptyPartition <- nextPartitionId until partitionId) { + writeEmptyPartition(mapOutputWriter) + } + var partitionWriter: ShufflePartitionWriter = null + var partitionPairsWriter: ShufflePartitionPairsWriter = null + try { + partitionWriter = mapOutputWriter.getNextPartitionWriter + val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) + partitionPairsWriter = new ShufflePartitionPairsWriter( + partitionWriter, + serializerManager, + serInstance, + blockId, + context.taskMetrics().shuffleWriteMetrics) + while (it.hasNext && it.nextPartition() == partitionId) { + it.writeNext(partitionPairsWriter) + } + } finally { + if (partitionPairsWriter != null) { + partitionPairsWriter.close() + } + if (partitionWriter != null) { + partitionWriter.close() + } + } + if (partitionWriter != null) { + lengths(partitionId) = partitionWriter.getNumBytesWritten + } + nextPartitionId = partitionId + 1 + } + } else { + // We must perform merge-sort; get an iterator by partition and write everything directly. + for ((id, elements) <- this.partitionedIterator) { + // The contract for the plugin is that we will ask for a writer for every partition + // even if it's empty. However, the external sorter will return non-contiguous + // partition ids. So this loop "backfills" the empty partitions that form the gaps. + + // The algorithm as a whole is correct because the partition ids are returned by the + // iterator in ascending order. + for (emptyPartition <- nextPartitionId until id) { + writeEmptyPartition(mapOutputWriter) + } + val blockId = ShuffleBlockId(shuffleId, mapId, id) + var partitionWriter: ShufflePartitionWriter = null + var partitionPairsWriter: ShufflePartitionPairsWriter = null + try { + partitionWriter = mapOutputWriter.getNextPartitionWriter + partitionPairsWriter = new ShufflePartitionPairsWriter( + partitionWriter, + serializerManager, + serInstance, + blockId, + context.taskMetrics().shuffleWriteMetrics) + if (elements.hasNext) { + for (elem <- elements) { + partitionPairsWriter.write(elem._1, elem._2) + } + } + } finally { + if (partitionPairsWriter!= null) { + partitionPairsWriter.close() + } + } + if (partitionWriter != null) { + lengths(id) = partitionWriter.getNumBytesWritten + } + nextPartitionId = id + 1 + } + } + + // The iterator may have stopped short of opening a writer for every partition. So fill in the + // remaining empty partitions. + for (emptyPartition <- nextPartitionId until numPartitions) { + writeEmptyPartition(mapOutputWriter) + } + + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) + + lengths + } + def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() @@ -785,7 +901,7 @@ private[spark] class ExternalSorter[K, V, C]( val inMemoryIterator = new WritablePartitionedIterator { private[this] var cur = if (upstream.hasNext) upstream.next() else null - def writeNext(writer: DiskBlockObjectWriter): Unit = { + def writeNext(writer: PairsWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (upstream.hasNext) upstream.next() else null } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala new file mode 100644 index 000000000000..9d7c209f242e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +private[spark] trait PairsWriter { + + def write(key: Any, value: Any): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala new file mode 100644 index 000000000000..6f19a2323efd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.{Closeable, FilterOutputStream, OutputStream} + +import org.apache.spark.api.shuffle.ShufflePartitionWriter +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.storage.BlockId + +/** + * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an + * arbitrary partition writer instead of writing to local disk through the block manager. + */ +private[spark] class ShufflePartitionPairsWriter( + partitionWriter: ShufflePartitionWriter, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + blockId: BlockId, + writeMetrics: ShuffleWriteMetricsReporter) + extends PairsWriter with Closeable { + + private var isOpen = false + private var partitionStream: OutputStream = _ + private var wrappedStream: OutputStream = _ + private var objOut: SerializationStream = _ + private var numRecordsWritten = 0 + private var curNumBytesWritten = 0L + + override def write(key: Any, value: Any): Unit = { + if (!isOpen) { + open() + isOpen = true + } + objOut.writeKey(key) + objOut.writeValue(value) + writeMetrics.incRecordsWritten(1) + } + + private def open(): Unit = { + // The contract is that the partition writer is expected to close its own streams, but + // the compressor will only flush the stream when it is specifically closed. So we want to + // close objOut to flush the compressed bytes to the partition writer stream, but we don't want + // to close the partition output stream in the process. + partitionStream = new CloseShieldOutputStream(partitionWriter.toStream) + wrappedStream = serializerManager.wrapStream(blockId, partitionStream) + objOut = serializerInstance.serializeStream(wrappedStream) + } + + override def close(): Unit = { + if (isOpen) { + // Closing objOut should propagate close to all inner layers + // We can't close wrappedStream explicitly because closing objOut and closing wrappedStream + // causes problems when closing compressed output streams twice. + objOut.close() + objOut = null + wrappedStream = null + partitionStream = null + partitionWriter.close() + isOpen = false + updateBytesWritten() + } + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + private def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + private def updateBytesWritten(): Unit = { + val numBytesWritten = partitionWriter.getNumBytesWritten + val bytesWrittenDiff = numBytesWritten - curNumBytesWritten + writeMetrics.incBytesWritten(bytesWrittenDiff) + curNumBytesWritten = numBytesWritten + } + + private class CloseShieldOutputStream(delegate: OutputStream) + extends FilterOutputStream(delegate) { + + override def close(): Unit = flush() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 5232c2bd8d6f..337b0673b403 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: DiskBlockObjectWriter): Unit = { + def writeNext(writer: PairsWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -96,7 +96,7 @@ private[spark] object WritablePartitionedPairCollection { * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: DiskBlockObjectWriter): Unit + def writeNext(writer: PairsWriter): Unit def hasNext(): Boolean diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 110ff1b03d51..32257b0cc4b5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,6 +22,7 @@ import org.mockito.Mockito.when import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport /** * Benchmark to measure performance for aggregate primitives. @@ -76,13 +77,14 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) + val writeSupport = new DefaultShuffleWriteSupport(defaultConf, blockResolver) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, shuffleHandle, 0, - taskContext - ) + taskContext, + writeSupport) shuffleWriter } From b0ca59950c19a36e969892e955175d11619811e9 Mon Sep 17 00:00:00 2001 From: mccheah Date: Tue, 16 Apr 2019 19:13:31 -0700 Subject: [PATCH 08/17] [SPARK-25299] Make UnsafeShuffleWriter use the new API (#536) Ported from https://github.com/bloomberg/apache-spark-on-k8s/pull/9. Credits to @ifilonenko! --- .../shuffle/sort/UnsafeShuffleWriter.java | 236 +++++++++--------- .../io/DefaultShuffleMapOutputWriter.java | 18 +- .../shuffle/sort/SortShuffleManager.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 43 +++- ...ypassMergeSortShuffleWriterBenchmark.scala | 4 +- .../sort/UnsafeShuffleWriterBenchmark.scala | 5 +- 6 files changed, 162 insertions(+), 147 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9d05f03613ce..b5ca6c8a50ce 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -20,6 +20,7 @@ import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import java.util.Iterator; import scala.Option; @@ -31,18 +32,19 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; -import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.commons.io.output.CloseShieldOutputStream; -import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -53,7 +55,6 @@ import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; @@ -65,7 +66,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @VisibleForTesting - static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; @@ -74,6 +74,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; + private final ShuffleWriteSupport shuffleWriteSupport; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -81,7 +82,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; private final int initialSortBufferSize; private final int inputBufferSizeInBytes; - private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -103,18 +103,6 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream */ private boolean stopping = false; - private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { - - CloseAndFlushShieldOutputStream(OutputStream outputStream) { - super(outputStream); - } - - @Override - public void flush() { - // do nothing - } - } - public UnsafeShuffleWriter( BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, @@ -123,7 +111,8 @@ public UnsafeShuffleWriter( int mapId, TaskContext taskContext, SparkConf sparkConf, - ShuffleWriteMetricsReporter writeMetrics) throws IOException { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleWriteSupport shuffleWriteSupport) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -140,6 +129,7 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; + this.shuffleWriteSupport = shuffleWriteSupport; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -147,8 +137,6 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); this.inputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; - this.outputBufferSizeInBytes = - (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -230,24 +218,27 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; + final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport + .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); final long[] partitionLengths; - final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - final File tmp = Utils.tempFileWith(output); try { try { - partitionLengths = mergeSpills(spills, tmp); + partitionLengths = mergeSpills(spills, mapWriter); } finally { for (SpillInfo spill : spills) { - if (spill.file.exists() && ! spill.file.delete()) { + if (spill.file.exists() && !spill.file.delete()) { logger.error("Error while deleting spill file {}", spill.file.getPath()); } } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + mapWriter.commitAllPartitions(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception innerE) { + logger.error("Failed to abort the Map Output Writer", innerE); } + throw e; } mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @@ -281,7 +272,8 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private long[] mergeSpills(SpillInfo[] spills, + ShuffleMapOutputWriter mapWriter) throws IOException { final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = @@ -289,17 +281,24 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + final int numPartitions = partitioner.numPartitions(); + long[] partitionLengths = new long[numPartitions]; try { if (spills.length == 0) { - new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; - } else if (spills.length == 1) { - // Here, we don't need to perform any metrics updates because the bytes written to this - // output file would have already been counted as shuffle bytes written. - Files.move(spills[0].file, outputFile); - return spills[0].partitionLengths; + // The contract we are working under states that we will open a partition writer for + // each partition, regardless of number of spills + for (int i = 0; i < numPartitions; i++) { + ShufflePartitionWriter writer = null; + try { + writer = mapWriter.getNextPartitionWriter(); + } finally { + if (writer != null) { + writer.close(); + } + } + } + return partitionLengths; } else { - final long[] partitionLengths; // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -316,14 +315,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // that doesn't need to interpret the spilled bytes. if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -331,13 +330,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - writeMetrics.incBytesWritten(outputFile.length()); return partitionLengths; } } catch (IOException e) { - if (outputFile.exists() && !outputFile.delete()) { - logger.error("Unable to delete output file {}", outputFile.getPath()); - } throw e; } } @@ -345,73 +340,79 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti /** * Merges spill files using Java FileStreams. This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], - * File)}, and it's mostly used in cases where the IO compression codec does not support - * concatenation of compressed data, when encryption is enabled, or when users have - * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec + * does not support concatenation of compressed data, when encryption is enabled, or when + * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs. * This code path might also be faster in cases where individual partition size in a spill * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small * disk ios which is inefficient. In those case, Using large buffers for input and output * files helps reducing the number of disk ios, making the file merging faster. * * @param spills the spills to merge. - * @param outputFile the file to write the merged data to. + * @param mapWriter the map output writer to use for output. * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. */ private long[] mergeSpillsWithFileStream( SpillInfo[] spills, - File outputFile, + ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException { - assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; - final OutputStream bos = new BufferedOutputStream( - new FileOutputStream(outputFile), - outputBufferSizeInBytes); - // Use a counting output stream to avoid having to close the underlying file and ask - // the file system for its size after each partition is written. - final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new NioBufferedFileInputStream( - spills[i].file, - inputBufferSizeInBytes); + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() and flush() calls, so that we can close - // the higher level streams to make sure all data is really flushed and internal state is - // cleaned. - OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( - new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - try { - partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + boolean copyThrewExecption = true; + ShufflePartitionWriter writer = null; + try { + writer = mapWriter.getNextPartitionWriter(); + OutputStream partitionOutput = null; + try { + // Shield the underlying output stream from close() calls, so that we can close the + // higher level streams to make sure all data is really flushed and internal state + // is cleaned + partitionOutput = new CloseShieldOutputStream(writer.toStream()); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = null; + try { + partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream( + partitionInputStream); + } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); + } } - ByteStreams.copy(partitionInputStream, partitionOutput); - } finally { - partitionInputStream.close(); + copyThrewExecption = false; } + } finally { + Closeables.close(partitionOutput, copyThrewExecption); } + } finally { + Closeables.close(writer, copyThrewExecption); } - partitionOutput.flush(); - partitionOutput.close(); - partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); + long numBytesWritten = writer.getNumBytesWritten(); + partitionLengths[partition] = numBytesWritten; + writeMetrics.incBytesWritten(numBytesWritten); } threwException = false; } finally { @@ -420,7 +421,6 @@ private long[] mergeSpillsWithFileStream( for (InputStream stream : spillInputStreams) { Closeables.close(stream, threwException); } - Closeables.close(mergedFileOutputStream, threwException); } return partitionLengths; } @@ -430,54 +430,49 @@ private long[] mergeSpillsWithFileStream( * This is only safe when the IO compression codec and serializer support concatenation of * serialized streams. * + * @param spills the spills to merge. + * @param mapWriter the map output writer to use for output. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { - assert (spills.length >= 2); + private long[] mergeSpillsWithTransferTo( + SpillInfo[] spills, + ShuffleMapOutputWriter mapWriter) throws IOException { final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; - FileChannel mergedFileOutputChannel = null; boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } - // This file needs to opened in append mode in order to work around a Linux kernel bug that - // affects transferTo; see SPARK-3948 for more details. - mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); - - long bytesWrittenToMergedFile = 0; for (int partition = 0; partition < numPartitions; partition++) { - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - final FileChannel spillInputChannel = spillInputChannels[i]; - final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - mergedFileOutputChannel, - spillInputChannelPositions[i], - partitionLengthInSpill); - spillInputChannelPositions[i] += partitionLengthInSpill; - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); - bytesWrittenToMergedFile += partitionLengthInSpill; - partitionLengths[partition] += partitionLengthInSpill; + boolean copyThrewExecption = true; + ShufflePartitionWriter writer = null; + try { + writer = mapWriter.getNextPartitionWriter(); + WritableByteChannel channel = writer.toChannel(); + for (int i = 0; i < spills.length; i++) { + long partitionLengthInSpill = 0L; + partitionLengthInSpill += spills[i].partitionLengths[partition]; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + Utils.copyFileStreamNIO( + spillInputChannel, + channel, + spillInputChannelPositions[i], + partitionLengthInSpill); + copyThrewExecption = false; + spillInputChannelPositions[i] += partitionLengthInSpill; + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + } finally { + Closeables.close(writer, copyThrewExecption); } - } - // Check the position after transferTo loop to see if it is in the right position and raise an - // exception if it is incorrect. The position will not be increased to the expected length - // after calling transferTo in kernel version 2.6.32. This issue is described at - // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. - if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { - throw new IOException( - "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + - "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + - " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + - "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + - "to disable this NIO feature." - ); + long numBytes = writer.getNumBytesWritten(); + partitionLengths[partition] = numBytes; + writeMetrics.incBytesWritten(numBytes); } threwException = false; } finally { @@ -487,7 +482,6 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th assert(spillInputChannelPositions[i] == spills[i].file.length()); Closeables.close(spillInputChannels[i], threwException); } - Closeables.close(mergedFileOutputChannel, threwException); } return partitionLengths; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index 0f7e5ed66bb7..c84158e1891d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -92,7 +92,8 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException { @Override public void commitAllPartitions() throws IOException { cleanUp(); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, outputTempFile); + File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); } @Override @@ -111,11 +112,9 @@ private void cleanUp() throws IOException { if (outputBufferedFileStream != null) { outputBufferedFileStream.close(); } - if (outputFileChannel != null) { outputFileChannel.close(); } - if (outputFileStream != null) { outputFileStream.close(); } @@ -191,8 +190,9 @@ public long getNumBytesWritten() { } @Override - public void close() throws IOException { + public void close() { if (stream != null) { + // Closing is a no-op. stream.close(); } partitionLengths[partitionId] = getNumBytesWritten(); @@ -222,18 +222,10 @@ public void write(byte[] buf, int pos, int length) throws IOException { } @Override - public void close() throws IOException { - flush(); + public void close() { isClosed = true; } - @Override - public void flush() throws IOException { - if (!isClosed) { - outputBufferedFileStream.flush(); - } - } - private void verifyNotClosed() { if (isClosed) { throw new IllegalStateException("Attempting to write to a closed block output stream."); diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 42a249564cd0..849050556c56 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -146,7 +146,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager mapId, context, env.conf, - metrics) + metrics, + shuffleExecutorComponents.writes()) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 9bf707f783d4..012dc5d21bce 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -19,8 +19,10 @@ import java.io.*; import java.nio.ByteBuffer; +import java.nio.file.Files; import java.util.*; +import org.mockito.stubbing.Answer; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -39,6 +41,7 @@ import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; @@ -53,6 +56,7 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -65,6 +69,7 @@ public class UnsafeShuffleWriterSuite { + static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int NUM_PARTITITONS = 4; TestMemoryManager memoryManager; TaskMemoryManager taskMemoryManager; @@ -85,6 +90,7 @@ public class UnsafeShuffleWriterSuite { @After public void tearDown() { + TaskContext$.MODULE$.unset(); Utils.deleteRecursively(tempDir); final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); if (leakedMemory != 0) { @@ -132,14 +138,28 @@ public void setUp() throws IOException { }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - doAnswer(invocationOnMock -> { + + Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = (File) invocationOnMock.getArguments()[3]; - mergedOutputFile.delete(); - tmp.renameTo(mergedOutputFile); + if (!mergedOutputFile.delete()) { + throw new RuntimeException("Failed to delete old merged output file."); + } + if (tmp != null) { + Files.move(tmp.toPath(), mergedOutputFile.toPath()); + } else if (!mergedOutputFile.createNewFile()) { + throw new RuntimeException("Failed to create empty merged output file."); + } return null; - }).when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + }; + + doAnswer(renameTempAnswer) + .when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + + doAnswer(renameTempAnswer) + .when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); @@ -151,6 +171,9 @@ public void setUp() throws IOException { when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + + TaskContext$.MODULE$.setTaskContext(taskContext); } private UnsafeShuffleWriter createWriter( @@ -164,7 +187,8 @@ private UnsafeShuffleWriter createWriter( 0, // map id taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics() + taskContext.taskMetrics().shuffleWriteMetrics(), + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver) ); } @@ -444,10 +468,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro } private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); + memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { + for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2<>(i, i)); } writer.write(dataToWrite.iterator()); @@ -525,7 +549,8 @@ public void testPeakMemoryUsed() throws Exception { 0, // map id taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics()); + taskContext.taskMetrics().shuffleWriteMetrics(), + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver)); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 7eb867fc29fd..69fe03e75606 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.{DefaultShuffleWriteSupport} +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport /** * Benchmark to measure performance for aggregate primitives. @@ -46,9 +46,9 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) - val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") + val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index 15a08111f6d5..20bf3eac95d8 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport /** * Benchmark to measure performance for aggregate primitives. @@ -42,6 +43,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) + val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( @@ -52,7 +54,8 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { 0, taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleWriteSupport ) } From b9a2bc9aef788b67889fb9be3f59141ad416fbb1 Mon Sep 17 00:00:00 2001 From: mccheah Date: Fri, 19 Apr 2019 13:17:02 -0700 Subject: [PATCH 09/17] [SPARK-25299] Shuffle locations api (#517) Implements the shuffle locations API as part of SPARK-25299. This adds an additional field to all `MapStatus` objects: a `MapShuffleLocations` that indicates where a task's map output is stored. This module is optional and implementations of the pluggable shuffle writers and readers can ignore it accordingly. This API is designed with the use case in mind of future plugin implementations desiring to have the driver store metadata about where shuffle blocks are stored. There are a few caveats to this design: - We originally wanted to remove the `BlockManagerId` from `MapStatus` entirely and replace it with this object. However, doing this proves to be very difficult, as many places use the block manager ID for other kinds of shuffle data bookkeeping. As a result, we concede to storing the block manager ID redundantly here. However, the overhead should be minimal: because we cache block manager ids and default map shuffle locations, the two fields in `MapStatus` should point to the same object on the heap. Thus we add `O(M)` storage overhead on the driver, where for each map status we're storing an additional pointer to the same on-heap object. We will run benchmarks against the TPC-DS workload to see if there are significant performance repercussions for this implementation. - `KryoSerializer` expects `CompressedMapStatus` and `HighlyCompressedMapStatus` to be serialized via reflection, so originally all fields of these classes needed to be registered with Kryo. However, the `MapShuffleLocations` is now pluggable. We think however that previously Kryo was defaulting to Java serialization anyways, so we now just explicitly tell Kryo to use `ExternalizableSerializer` to deal with these objects. There's a small hack in the serialization protocol that attempts to avoid serializing the same `BlockManagerId` twice in the case that the map shuffle locations is a `DefaultMapShuffleLocations`. --- .../api/shuffle/MapShuffleLocations.java | 39 ++++++ .../spark/api/shuffle/ShuffleLocation.java | 25 ++++ .../api/shuffle/ShuffleMapOutputWriter.java | 3 +- .../sort/BypassMergeSortShuffleWriter.java | 20 ++- .../sort/DefaultMapShuffleLocations.java | 76 +++++++++++ .../shuffle/sort/UnsafeShuffleWriter.java | 10 +- .../io/DefaultShuffleExecutorComponents.java | 2 +- .../io/DefaultShuffleMapOutputWriter.java | 10 +- .../sort/io/DefaultShuffleWriteSupport.java | 8 +- .../org/apache/spark/MapOutputTracker.scala | 29 ++-- .../apache/spark/scheduler/MapStatus.scala | 125 +++++++++++++++--- .../spark/serializer/KryoSerializer.scala | 6 +- .../shuffle/BlockStoreShuffleReader.scala | 12 +- .../shuffle/sort/SortShuffleWriter.scala | 7 +- .../apache/spark/storage/BlockManagerId.scala | 4 +- .../sort/UnsafeShuffleWriterSuite.java | 7 +- .../apache/spark/MapOutputTrackerSuite.scala | 39 +++--- .../scala/org/apache/spark/ShuffleSuite.scala | 6 +- .../spark/scheduler/DAGSchedulerSuite.scala | 81 +++++++----- .../spark/scheduler/MapStatusSuite.scala | 28 ++-- .../scheduler/SchedulerIntegrationSuite.scala | 3 +- .../serializer/KryoSerializerSuite.scala | 7 +- .../BlockStoreShuffleReaderSuite.scala | 8 +- .../BlockStoreShuffleReaderBenchmark.scala | 4 +- ...ypassMergeSortShuffleWriterBenchmark.scala | 4 +- .../BypassMergeSortShuffleWriterSuite.scala | 6 +- .../sort/SortShuffleWriterBenchmark.scala | 6 +- .../sort/UnsafeShuffleWriterBenchmark.scala | 4 +- .../DefaultShuffleMapOutputWriterSuite.scala | 9 +- 29 files changed, 463 insertions(+), 125 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java new file mode 100644 index 000000000000..b0aed4d08d38 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.io.Serializable; + +/** + * Represents metadata about where shuffle blocks were written in a single map task. + *

+ * This is optionally returned by shuffle writers. The inner shuffle locations may + * be accessed by shuffle readers. Shuffle locations are only necessary when the + * location of shuffle blocks needs to be managed by the driver; shuffle plugins + * may choose to use an external database or other metadata management systems to + * track the locations of shuffle blocks instead. + */ +@Experimental +public interface MapShuffleLocations extends Serializable { + + /** + * Get the location for a given shuffle block written by this map task. + */ + ShuffleLocation getLocationForBlock(int reduceId); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java new file mode 100644 index 000000000000..87eb497098e0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +/** + * Marker interface representing a location of a shuffle block. Implementations of shuffle readers + * and writers are expected to cast this down to an implementation-specific representation. + */ +public interface ShuffleLocation { +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java index 5119e34803a8..181701175d35 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.Optional; /** * :: Experimental :: @@ -31,7 +32,7 @@ public interface ShuffleMapOutputWriter { ShufflePartitionWriter getNextPartitionWriter() throws IOException; - void commitAllPartitions() throws IOException; + Optional commitAllPartitions() throws IOException; void abort(Throwable error) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index aef133fe7d46..434286175e41 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -25,6 +25,8 @@ import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.None$; import scala.Option; import scala.Product2; @@ -134,8 +136,11 @@ public void write(Iterator> records) throws IOException { try { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + Optional blockLocs = mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + blockLocs.orNull(), + partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -168,8 +173,11 @@ public void write(Iterator> records) throws IOException { } partitionLengths = writePartitionedData(mapOutputWriter); - mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + Optional mapLocations = mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + mapLocations.orNull(), + partitionLengths); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -178,6 +186,10 @@ public void write(Iterator> records) throws IOException { } throw e; } + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + DefaultMapShuffleLocations.get(blockManager.shuffleServerId()), + partitionLengths); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java new file mode 100644 index 000000000000..ffd97c0f2660 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; + +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.ShuffleLocation; +import org.apache.spark.storage.BlockManagerId; + +import java.util.Objects; + +public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation { + + /** + * We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be + * feasible. + */ + private static final LoadingCache + DEFAULT_SHUFFLE_LOCATIONS_CACHE = + CacheBuilder.newBuilder() + .maximumSize(BlockManagerId.blockManagerIdCacheSize()) + .build(new CacheLoader() { + @Override + public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) { + return new DefaultMapShuffleLocations(blockManagerId); + } + }); + + private final BlockManagerId location; + + public DefaultMapShuffleLocations(BlockManagerId blockManagerId) { + this.location = blockManagerId; + } + + public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) { + return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId); + } + + @Override + public ShuffleLocation getLocationForBlock(int reduceId) { + return this; + } + + public BlockManagerId getBlockManagerId() { + return location; + } + + @Override + public boolean equals(Object other) { + return other instanceof DefaultMapShuffleLocations + && Objects.equals(((DefaultMapShuffleLocations) other).location, location); + } + + @Override + public int hashCode() { + return Objects.hashCode(location); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index b5ca6c8a50ce..95c4577cb770 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -23,6 +23,8 @@ import java.nio.channels.WritableByteChannel; import java.util.Iterator; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -221,6 +223,7 @@ void closeAndWriteOutput() throws IOException { final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); final long[] partitionLengths; + Optional mapLocations; try { try { partitionLengths = mergeSpills(spills, mapWriter); @@ -231,7 +234,7 @@ void closeAndWriteOutput() throws IOException { } } } - mapWriter.commitAllPartitions(); + mapLocations = mapWriter.commitAllPartitions(); } catch (Exception e) { try { mapWriter.abort(e); @@ -240,7 +243,10 @@ void closeAndWriteOutput() throws IOException { } throw e; } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + mapLocations.orNull(), + partitionLengths); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 76e87a674025..f7ec202ef4b9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -46,6 +46,6 @@ public ShuffleWriteSupport writes() { throw new IllegalStateException( "Executor components must be initialized before getting writers."); } - return new DefaultShuffleWriteSupport(sparkConf, blockResolver); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index c84158e1891d..7eb0d56776de 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -24,6 +24,10 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; +import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +53,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final int bufferSize; private int currPartitionId = 0; private long currChannelPosition; + private final BlockManagerId shuffleServerId; private final File outputFile; private File outputTempFile; @@ -61,11 +66,13 @@ public DefaultShuffleMapOutputWriter( int shuffleId, int mapId, int numPartitions, + BlockManagerId shuffleServerId, ShuffleWriteMetricsReporter metrics, IndexShuffleBlockResolver blockResolver, SparkConf sparkConf) { this.shuffleId = shuffleId; this.mapId = mapId; + this.shuffleServerId = shuffleServerId; this.metrics = metrics; this.blockResolver = blockResolver; this.bufferSize = @@ -90,10 +97,11 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException { } @Override - public void commitAllPartitions() throws IOException { + public Optional commitAllPartitions() throws IOException { cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + return Optional.of(DefaultMapShuffleLocations.get(shuffleServerId)); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java index f8fadd0ecfa6..86f158349568 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java @@ -22,17 +22,21 @@ import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShuffleWriteSupport; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManagerId; public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { private final SparkConf sparkConf; private final IndexShuffleBlockResolver blockResolver; + private final BlockManagerId shuffleServerId; public DefaultShuffleWriteSupport( SparkConf sparkConf, - IndexShuffleBlockResolver blockResolver) { + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { this.sparkConf = sparkConf; this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; } @Override @@ -41,7 +45,7 @@ public ShuffleMapOutputWriter createMapOutputWriter( int mapId, int numPartitions) { return new DefaultShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, + shuffleId, mapId, numPartitions, shuffleServerId, TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1d4b1ef9c9a1..74975019e748 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal +import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation} import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -281,9 +282,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } // For testing - def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } /** @@ -295,8 +296,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -645,8 +646,8 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -682,12 +683,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private val fetching = new HashSet[Int] // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. - override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: @@ -871,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -883,7 +885,8 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) + splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247..a61f9bd14ef2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,7 +24,9 @@ import scala.collection.mutable import org.roaringbitmap.RoaringBitmap import org.apache.spark.SparkEnv +import org.apache.spark.api.shuffle.MapShuffleLocations import org.apache.spark.internal.config +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -33,7 +35,16 @@ import org.apache.spark.util.Utils * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { - /** Location where this task was run. */ + + /** + * Locations where this task stored shuffle blocks. + * + * May be null if the MapOutputTracker is not tracking the location of shuffle blocks, leaving it + * up to the implementation of shuffle plugins to do so. + */ + def mapShuffleLocations: MapShuffleLocations + + /** Location where the task was run. */ def location: BlockManagerId /** @@ -56,11 +67,31 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + // A temporary concession to the fact that we only expect implementations of shuffle provided by + // Spark to be storing shuffle locations in the driver, meaning we want to introduce as little + // serialization overhead as possible in such default cases. + // + // If more similar cases arise, consider adding a serialization API for these shuffle locations. + private val DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 0 + private val NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 1 + + /** + * Visible for testing. + */ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + apply(loc, DefaultMapShuffleLocations.get(loc), uncompressedSizes) + } + + def apply( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus( + loc, mapShuffleLocs, uncompressedSizes) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus( + loc, mapShuffleLocs, uncompressedSizes) } } @@ -91,41 +122,89 @@ private[spark] object MapStatus { math.pow(LOG_BASE, compressedSize & 0xFF).toLong } } -} + def writeLocations( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + out: ObjectOutput): Unit = { + if (mapShuffleLocs != null) { + out.writeBoolean(true) + if (mapShuffleLocs.isInstanceOf[DefaultMapShuffleLocations] + && mapShuffleLocs.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId == loc) { + out.writeByte(MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) + } else { + out.writeByte(MapStatus.NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) + out.writeObject(mapShuffleLocs) + } + } else { + out.writeBoolean(false) + } + loc.writeExternal(out) + } + + def readLocations(in: ObjectInput): (BlockManagerId, MapShuffleLocations) = { + if (in.readBoolean()) { + val locId = in.readByte() + if (locId == MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) { + val blockManagerId = BlockManagerId(in) + (blockManagerId, DefaultMapShuffleLocations.get(blockManagerId)) + } else { + val mapShuffleLocations = in.readObject().asInstanceOf[MapShuffleLocations] + val blockManagerId = BlockManagerId(in) + (blockManagerId, mapShuffleLocations) + } + } else { + val blockManagerId = BlockManagerId(in) + (blockManagerId, null) + } + } +} /** * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is * represented using a single byte. * - * @param loc location where the task is being executed. + * @param loc Location were the task is being executed. + * @param mapShuffleLocs locations where the task stored its shuffle blocks - may be null. * @param compressedSizes size of the blocks, indexed by reduce partition id. */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, + private[this] var mapShuffleLocs: MapShuffleLocations, private[this] var compressedSizes: Array[Byte]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null, null.asInstanceOf[Array[Byte]]) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this( + loc: BlockManagerId, + mapShuffleLocations: MapShuffleLocations, + uncompressedSizes: Array[Long]) { + this( + loc, + mapShuffleLocations, + uncompressedSizes.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc + override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - loc.writeExternal(out) + MapStatus.writeLocations(loc, mapShuffleLocs, out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - loc = BlockManagerId(in) + val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in) + loc = deserializedLoc + mapShuffleLocs = deserializedMapShuffleLocs val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -138,6 +217,7 @@ private[spark] class CompressedMapStatus( * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed + * @param mapShuffleLocs location where the task stored shuffle blocks - may be null * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty * @param avgSize average size of the non-empty and non-huge blocks @@ -145,6 +225,7 @@ private[spark] class CompressedMapStatus( */ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, + private[this] var mapShuffleLocs: MapShuffleLocations, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, @@ -155,10 +236,12 @@ private[spark] class HighlyCompressedMapStatus private ( require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc + override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -172,7 +255,7 @@ private[spark] class HighlyCompressedMapStatus private ( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - loc.writeExternal(out) + MapStatus.writeLocations(loc, mapShuffleLocs, out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -183,7 +266,9 @@ private[spark] class HighlyCompressedMapStatus private ( } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - loc = BlockManagerId(in) + val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in) + loc = deserializedLoc + mapShuffleLocs = deserializedMapShuffleLocs emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -199,7 +284,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -239,7 +327,12 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + new HighlyCompressedMapStatus( + loc, + mapShuffleLocs, + numNonEmptyBlocks, + emptyBlocks, + avgSize, + hugeBlockSizes) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index eef19973e8d7..fd8d2cd4e1af 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,7 +31,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSe import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} -import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.serializers.{ExternalizableSerializer, JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap @@ -151,6 +151,8 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[CompressedMapStatus], new ExternalizableSerializer()) + kryo.register(classOf[HighlyCompressedMapStatus], new ExternalizableSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) @@ -482,8 +484,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[CompressedMapStatus], - classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Boolean]], diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c7843710413d..4d559556360c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -20,7 +20,8 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -47,7 +48,14 @@ private[spark] class BlockStoreShuffleReader[K, C]( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) + .map { + case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) => + (loc.getBlockManagerId, blocks) + case _ => + throw new UnsupportedOperationException("Not allowed to using non-default map shuffle" + + " locations yet.") + }, serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 62316f384b64..1fcae684b005 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,8 +67,11 @@ private[spark] class SortShuffleWriter[K, V, C]( val mapOutputWriter = writeSupport.createMapOutputWriter( dep.shuffleId, mapId, dep.partitioner.numPartitions) val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - mapOutputWriter.commitAllPartitions() - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + val mapLocations = mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus( + blockManager.shuffleServerId, + mapLocations.orNull(), + partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d188bdd912e5..97b99e08d9ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -132,12 +132,14 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } + val blockManagerIdCacheSize = 10000 + /** * The max cache size is hardcoded to 10000, since the size of a BlockManagerId * object is about 48B, the total memory cost should be below 1MB which is feasible. */ val blockManagerIdCache = CacheBuilder.newBuilder() - .maximumSize(10000) + .maximumSize(blockManagerIdCacheSize) .build(new CacheLoader[BlockManagerId, BlockManagerId]() { override def load(id: BlockManagerId) = id }) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 012dc5d21bce..5f0de31bd25e 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -172,6 +172,8 @@ public void setUp() throws IOException { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + when(blockManager.shuffleServerId()).thenReturn(BlockManagerId.apply( + "0", "localhost", 9099, Option.empty())); TaskContext$.MODULE$.setTaskContext(taskContext); } @@ -188,8 +190,7 @@ private UnsafeShuffleWriter createWriter( taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver) - ); + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); } private void assertSpillFilesWereCleanedUp() { @@ -550,7 +551,7 @@ public void testPeakMemoryUsed() throws Exception { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver)); + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d86975964b55..0a77c4f6d583 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MA import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { @@ -67,10 +68,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getMapSizesByExecutorId(10, 0) + val statuses = tracker.getMapSizesByShuffleLocation(10, 0) assert(statuses.toSet === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + Seq( + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), + ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() @@ -90,11 +94,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(tracker.getMapSizesByShuffleLocation(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) + assert(tracker.getMapSizesByShuffleLocation(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -121,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. - intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByShuffleLocation(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -143,24 +147,26 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) slaveTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === + Seq( + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() @@ -261,8 +267,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // being sent. masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => + val bmId = BlockManagerId("999", "mps", 1000) masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + bmId, + DefaultMapShuffleLocations.get(bmId), + Array.fill[Long](4000000)(0))) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -315,11 +324,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByShuffleLocation(10, 0, 4).toSeq === Seq( - (BlockManagerId("a", "hostA", 1000), + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), - (BlockManagerId("b", "hostB", 1000), + (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) ) ) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 90c790cefcca..83026c002f1b 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -73,7 +73,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } @@ -112,7 +112,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -137,7 +137,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index e74f4627db9b..8c0d7baccf85 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -30,12 +30,14 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.shuffle.MapShuffleLocations import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -701,8 +703,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -728,8 +730,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -767,11 +771,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(ExecutorLost("exec-hostA", event)) if (expectFileLoss) { intercept[MetadataFetchFailedException] { - mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0) + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0) } } else { - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) } } } @@ -1064,8 +1068,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( @@ -1194,10 +1200,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 1) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( @@ -1387,8 +1397,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi Success, makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1542,7 +1552,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi reduceIdx <- reduceIdxs } { // this would throw an exception if the map status hadn't been registered - val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + val statuses = mapOutputTracker.getMapSizesByShuffleLocation(stage, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) @@ -1594,7 +1604,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // check that we have all the map output for stage 0 (0 until reduceRdd.partitions.length).foreach { reduceIdx => - val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + val statuses = mapOutputTracker.getMapSizesByShuffleLocation(0, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) @@ -1793,8 +1803,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) // Make sure that the reduce stage was now submitted. assert(taskSets.size === 3) @@ -2056,8 +2066,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2102,8 +2112,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2266,8 +2276,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure @@ -2282,8 +2292,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result // Stage 1 should now be running as task set 3; make its first task succeed @@ -2291,8 +2302,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(3), Seq( (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostD"))) assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2330,8 +2341,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. @@ -2356,8 +2367,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. @@ -2921,6 +2932,10 @@ object DAGSchedulerSuite { def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) + + def makeShuffleLocation(host: String): MapShuffleLocations = { + DefaultMapShuffleLocations.get(makeBlockManagerId(host)) + } } object FailThisAttempt { diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index c1e7fb9a1db1..3c786c0927bc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.internal.config import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -61,7 +62,11 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val bmId = BlockManagerId("a", "b", 10) + val status = MapStatus( + bmId, + DefaultMapShuffleLocations.get(bmId), + sizes) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -75,7 +80,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, null, sizes) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,11 +91,13 @@ class MapStatusSuite extends SparkFunSuite { test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) - val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val bmId = BlockManagerId("a", "b", 10) + val loc = DefaultMapShuffleLocations.get(bmId) + val status = MapStatus(bmId, loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - assert(status1.location == loc) + assert(status1.location == loc.getBlockManagerId) + assert(status1.mapShuffleLocations == loc) for (i <- 0 until 3000) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -108,11 +115,13 @@ class MapStatusSuite extends SparkFunSuite { val sizes = (0L to 3000L).toArray val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length - val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val bmId = BlockManagerId("a", "b", 10) + val loc = DefaultMapShuffleLocations.get(bmId) + val status = MapStatus(bmId, loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - assert(status1.location == loc) + assert(status1.location === bmId) + assert(status1.mapShuffleLocations === loc) for (i <- 0 until threshold) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -165,7 +174,8 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val bmId = BlockManagerId("exec-0", "host-0", 100) + val status1 = MapStatus(bmId, DefaultMapShuffleLocations.get(bmId), sizes) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index aa6db8d0423a..83305a96e679 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -192,7 +192,8 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa shuffleId <- shuffleIds reduceIdx <- (0 until nParts) } { - val statuses = taskScheduler.mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceIdx) + val statuses = taskScheduler.mapOutputTracker.getMapSizesByShuffleLocation( + shuffleId, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 16eec7e0bea1..c523d0cb9ce8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -36,8 +36,9 @@ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Kryo._ import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ThreadUtils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") @@ -350,8 +351,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) + val bmId = BlockManagerId("exec-1", "host", 1234) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize(HighlyCompressedMapStatus( + bmId, DefaultMapShuffleLocations.get(bmId), blockSizes)) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6d2ef17a7a79..b3073addb7cc 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -102,14 +103,17 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + when(mapOutputTracker.getMapSizesByShuffleLocation( + shuffleId, reduceId, reduceId + 1)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + Seq( + (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) + .toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 2690f1a515fc..b39e37c1e384 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -193,13 +193,13 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { dataBlockId = remoteBlockManagerId } - when(mapOutputTracker.getMapSizesByExecutorId(0, 0, 1)) + when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1)) .thenReturn { val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((dataBlockId, shuffleBlockIdsAndSizes)).toIterator + Seq((DefaultMapShuffleLocations.get(dataBlockId), shuffleBlockIdsAndSizes)).toIterator } when(dependency.serializer).thenReturn(serializer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 69fe03e75606..0b3394e88d9f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId /** * Benchmark to measure performance for aggregate primitives. @@ -46,9 +47,10 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) + val shuffleWriteSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 7090)) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") - val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 6683858830cd..05d10e9f63d0 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyString} +import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -140,7 +140,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte metricsSystem = null, taskMetrics = taskMetrics)) - writeSupport = new DefaultShuffleWriteSupport(conf, blockResolver) + writeSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 7090)) } override def afterEach(): Unit = { @@ -203,7 +204,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - // TODO(ifilonenko): MAKE THIS PASS test("write with some empty partitions with transferTo") { def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 32257b0cc4b5..b0ff15cb1f79 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId /** * Benchmark to measure performance for aggregate primitives. @@ -77,7 +78,10 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) - val writeSupport = new DefaultShuffleWriteSupport(defaultConf, blockResolver) + val writeSupport = new DefaultShuffleWriteSupport( + defaultConf, + blockResolver, + BlockManagerId("0", "localhost", 9099)) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index 20bf3eac95d8..0e659ff7cc5f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId /** * Benchmark to measure performance for aggregate primitives. @@ -43,7 +44,8 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) - val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) + val shuffleWriteSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 9099)) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala index 22d52924a7c7..d704f72015ce 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -87,7 +88,13 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft }).when(blockResolver) .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) mapOutputWriter = new DefaultShuffleMapOutputWriter( - 0, 0, NUM_PARTITIONS, shuffleWriteMetrics, blockResolver, conf) + 0, + 0, + NUM_PARTITIONS, + BlockManagerId("0", "localhost", 9099), + shuffleWriteMetrics, + blockResolver, + conf) } private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { From 31dd8a48a3b53b79a31b6c49b1e47f581c921a36 Mon Sep 17 00:00:00 2001 From: mccheah Date: Fri, 19 Apr 2019 15:09:34 -0700 Subject: [PATCH 10/17] [SPARK-25299] Move shuffle writers back to being given specific partition ids (#540) We originally made the shuffle map output writer API behave like an iterator in fetching the "next" partition writer. However, the shuffle writer implementations tend to skip opening empty partitions. If we used an iterator-like API though we would be tied down to opening a partition writer for every single partition, even if some of them are empty. Here, we go back to using specific partition identifiers to give us more freedom to avoid needing to create writers for empty partitions. --- .../api/shuffle/ShuffleMapOutputWriter.java | 2 +- .../sort/BypassMergeSortShuffleWriter.java | 2 +- .../shuffle/sort/UnsafeShuffleWriter.java | 16 ++------- .../io/DefaultShuffleMapOutputWriter.java | 10 ++++-- .../util/collection/ExternalSorter.scala | 36 ++----------------- .../DefaultShuffleMapOutputWriterSuite.scala | 8 ++--- 6 files changed, 17 insertions(+), 57 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java index 181701175d35..062cf4ff0fba 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java @@ -30,7 +30,7 @@ */ @Experimental public interface ShuffleMapOutputWriter { - ShufflePartitionWriter getNextPartitionWriter() throws IOException; + ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException; Optional commitAllPartitions() throws IOException; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 434286175e41..b339738de5ad 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -216,7 +216,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro boolean copyThrewException = true; ShufflePartitionWriter writer = null; try { - writer = mapOutputWriter.getNextPartitionWriter(); + writer = mapOutputWriter.getPartitionWriter(i); if (!file.exists()) { copyThrewException = false; } else { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 95c4577cb770..e4175f985d91 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -291,18 +291,6 @@ private long[] mergeSpills(SpillInfo[] spills, long[] partitionLengths = new long[numPartitions]; try { if (spills.length == 0) { - // The contract we are working under states that we will open a partition writer for - // each partition, regardless of number of spills - for (int i = 0; i < numPartitions; i++) { - ShufflePartitionWriter writer = null; - try { - writer = mapWriter.getNextPartitionWriter(); - } finally { - if (writer != null) { - writer.close(); - } - } - } return partitionLengths; } else { // There are multiple spills to merge, so none of these spill files' lengths were counted @@ -378,7 +366,7 @@ private long[] mergeSpillsWithFileStream( boolean copyThrewExecption = true; ShufflePartitionWriter writer = null; try { - writer = mapWriter.getNextPartitionWriter(); + writer = mapWriter.getPartitionWriter(partition); OutputStream partitionOutput = null; try { // Shield the underlying output stream from close() calls, so that we can close the @@ -457,7 +445,7 @@ private long[] mergeSpillsWithTransferTo( boolean copyThrewExecption = true; ShufflePartitionWriter writer = null; try { - writer = mapWriter.getNextPartitionWriter(); + writer = mapWriter.getPartitionWriter(partition); WritableByteChannel channel = writer.toChannel(); for (int i = 0; i < spills.length; i++) { long partitionLengthInSpill = 0L; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index 7eb0d56776de..926c3b943399 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -51,7 +51,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; private final int bufferSize; - private int currPartitionId = 0; + private int lastPartitionId = -1; private long currChannelPosition; private final BlockManagerId shuffleServerId; @@ -84,7 +84,11 @@ public DefaultShuffleMapOutputWriter( } @Override - public ShufflePartitionWriter getNextPartitionWriter() throws IOException { + public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException { + if (partitionId <= lastPartitionId) { + throw new IllegalArgumentException("Partitions should be requested in increasing order."); + } + lastPartitionId = partitionId; if (outputTempFile == null) { outputTempFile = Utils.tempFileWith(outputFile); } @@ -93,7 +97,7 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException { } else { currChannelPosition = 0L; } - return new DefaultShufflePartitionWriter(currPartitionId++); + return new DefaultShufflePartitionWriter(partitionId); } @Override diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 8ccc1dfc9b3f..df5ce73b9acf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -721,17 +721,6 @@ private[spark] class ExternalSorter[K, V, C]( lengths } - private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = { - var partitionWriter: ShufflePartitionWriter = null - try { - partitionWriter = mapOutputWriter.getNextPartitionWriter - } finally { - if (partitionWriter != null) { - partitionWriter.close() - } - } - } - /** * Write all the data added into this ExternalSorter into a map output writer that pushes bytes * to some arbitrary backing store. This is called by the SortShuffleWriter. @@ -742,26 +731,16 @@ private[spark] class ExternalSorter[K, V, C]( shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { // Track location of each range in the map output val lengths = new Array[Long](numPartitions) - var nextPartitionId = 0 if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext()) { val partitionId = it.nextPartition() - // The contract for the plugin is that we will ask for a writer for every partition - // even if it's empty. However, the external sorter will return non-contiguous - // partition ids. So this loop "backfills" the empty partitions that form the gaps. - - // The algorithm as a whole is correct because the partition ids are returned by the - // iterator in ascending order. - for (emptyPartition <- nextPartitionId until partitionId) { - writeEmptyPartition(mapOutputWriter) - } var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null try { - partitionWriter = mapOutputWriter.getNextPartitionWriter + partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) partitionPairsWriter = new ShufflePartitionPairsWriter( partitionWriter, @@ -783,7 +762,6 @@ private[spark] class ExternalSorter[K, V, C]( if (partitionWriter != null) { lengths(partitionId) = partitionWriter.getNumBytesWritten } - nextPartitionId = partitionId + 1 } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. @@ -794,14 +772,11 @@ private[spark] class ExternalSorter[K, V, C]( // The algorithm as a whole is correct because the partition ids are returned by the // iterator in ascending order. - for (emptyPartition <- nextPartitionId until id) { - writeEmptyPartition(mapOutputWriter) - } val blockId = ShuffleBlockId(shuffleId, mapId, id) var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null try { - partitionWriter = mapOutputWriter.getNextPartitionWriter + partitionWriter = mapOutputWriter.getPartitionWriter(id) partitionPairsWriter = new ShufflePartitionPairsWriter( partitionWriter, serializerManager, @@ -821,16 +796,9 @@ private[spark] class ExternalSorter[K, V, C]( if (partitionWriter != null) { lengths(id) = partitionWriter.getNumBytesWritten } - nextPartitionId = id + 1 } } - // The iterator may have stopped short of opening a writer for every partition. So fill in the - // remaining empty partitions. - for (emptyPartition <- nextPartitionId until numPartitions) { - writeEmptyPartition(mapOutputWriter) - } - context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala index d704f72015ce..420b0d4d2f67 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -140,7 +140,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("writing to an outputstream") { (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getNextPartitionWriter + val writer = mapOutputWriter.getPartitionWriter(p) val stream = writer.toStream() data(p).foreach { i => stream.write(i)} stream.close() @@ -159,7 +159,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("writing to a channel") { (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getNextPartitionWriter + val writer = mapOutputWriter.getPartitionWriter(p) val channel = writer.toChannel() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() @@ -179,7 +179,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("copyStreams with an outputstream") { (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getNextPartitionWriter + val writer = mapOutputWriter.getPartitionWriter(p) val stream = writer.toStream() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() @@ -200,7 +200,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("copyStreamsWithNIO with a channel") { (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getNextPartitionWriter + val writer = mapOutputWriter.getPartitionWriter(p) val channel = writer.toChannel() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() From 1f6f6e676c5fa68f1498f4f9979bbbea56035582 Mon Sep 17 00:00:00 2001 From: mccheah Date: Fri, 19 Apr 2019 15:19:57 -0700 Subject: [PATCH 11/17] [SPARK-25299] Don't set map status twice in bypass merge sort shuffle writer (#541) --- .../spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index b339738de5ad..22386c39aca0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -186,10 +186,6 @@ public void write(Iterator> records) throws IOException { } throw e; } - mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), - DefaultMapShuffleLocations.get(blockManager.shuffleServerId()), - partitionLengths); } @VisibleForTesting From 09138c084a22eeaea815a6bfb243e55efa85c956 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 30 Apr 2019 14:06:54 -0700 Subject: [PATCH 12/17] merge conflicts --- .../spark/api/shuffle/ShuffleBlockInfo.java | 79 +++++++++++++ .../shuffle/ShuffleExecutorComponents.java | 2 + .../spark/api/shuffle/ShuffleLocation.java | 3 +- .../spark/api/shuffle/ShuffleReadSupport.java | 38 ++++++ .../io/DefaultShuffleExecutorComponents.java | 25 +++- .../org/apache/spark/MapOutputTracker.scala | 21 ++-- .../apache/spark/executor/TaskMetrics.scala | 12 +- .../shuffle/BlockStoreShuffleReader.scala | 90 +++++++++----- .../io/DefaultShuffleReadSupport.scala | 111 ++++++++++++++++++ .../shuffle/sort/SortShuffleManager.scala | 6 +- .../storage/ShuffleBlockFetcherIterator.scala | 19 ++- .../apache/spark/MapOutputTrackerSuite.scala | 15 +-- .../scala/org/apache/spark/ShuffleSuite.scala | 2 + .../spark/scheduler/DAGSchedulerSuite.scala | 34 +++--- .../BlockStoreShuffleReaderSuite.scala | 45 ++++--- .../BlockStoreShuffleReaderBenchmark.scala | 13 +- .../ShuffleBlockFetcherIteratorSuite.scala | 41 +++---- .../spark/sql/execution/ShuffledRowRDD.scala | 7 +- 18 files changed, 452 insertions(+), 111 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java create mode 100644 core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java new file mode 100644 index 000000000000..a312831cb628 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.api.java.Optional; + +import java.util.Objects; + +/** + * :: Experimental :: + * An object defining the shuffle block and length metadata associated with the block. + * @since 3.0.0 + */ +public class ShuffleBlockInfo { + private final int shuffleId; + private final int mapId; + private final int reduceId; + private final long length; + private final Optional shuffleLocation; + + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length, + Optional shuffleLocation) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.length = length; + this.shuffleLocation = shuffleLocation; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getMapId() { + return mapId; + } + + public int getReduceId() { + return reduceId; + } + + public long getLength() { + return length; + } + + public Optional getShuffleLocation() { + return shuffleLocation; + } + + @Override + public boolean equals(Object other) { + return other instanceof ShuffleBlockInfo + && shuffleId == ((ShuffleBlockInfo) other).shuffleId + && mapId == ((ShuffleBlockInfo) other).mapId + && reduceId == ((ShuffleBlockInfo) other).reduceId + && length == ((ShuffleBlockInfo) other).length + && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation); + } + + @Override + public int hashCode() { + return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation); + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java index 4fc20bad9938..8baa3bf6f859 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -30,4 +30,6 @@ public interface ShuffleExecutorComponents { void initializeExecutor(String appId, String execId); ShuffleWriteSupport writes(); + + ShuffleReadSupport reads(); } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java index 87eb497098e0..d06c11b3c01e 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -21,5 +21,4 @@ * Marker interface representing a location of a shuffle block. Implementations of shuffle readers * and writers are expected to cast this down to an implementation-specific representation. */ -public interface ShuffleLocation { -} +public interface ShuffleLocation {} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java new file mode 100644 index 000000000000..9cd8fde09064 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.io.IOException; +import java.io.InputStream; + +/** + * :: Experimental :: + * An interface for reading shuffle records. + * @since 3.0.0 + */ +@Experimental +public interface ShuffleReadSupport { + /** + * Returns an underlying {@link Iterable} that will iterate + * through shuffle data, given an iterable for the shuffle blocks to fetch. + */ + Iterable getPartitionReaders(Iterable blockMetadata) + throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index f7ec202ef4b9..91a5d7f7945e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -17,11 +17,15 @@ package org.apache.spark.shuffle.sort.io; +import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleReadSupport; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; import org.apache.spark.storage.BlockManager; public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { @@ -29,6 +33,8 @@ public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponen private final SparkConf sparkConf; private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; + private MapOutputTracker mapOutputTracker; + private SerializerManager serializerManager; public DefaultShuffleExecutorComponents(SparkConf sparkConf) { this.sparkConf = sparkConf; @@ -37,15 +43,30 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) { @Override public void initializeExecutor(String appId, String execId) { blockManager = SparkEnv.get().blockManager(); + mapOutputTracker = SparkEnv.get().mapOutputTracker(); + serializerManager = SparkEnv.get().serializerManager(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); } @Override public ShuffleWriteSupport writes() { + checkInitialized(); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); + } + + @Override + public ShuffleReadSupport reads() { + checkInitialized(); + return new DefaultShuffleReadSupport(blockManager, + mapOutputTracker, + serializerManager, + sparkConf); + } + + private void checkInitialized() { if (blockResolver == null) { throw new IllegalStateException( - "Executor components must be initialized before getting writers."); + "Executor components must be initialized before getting writers."); } - return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 74975019e748..ebddf5ff6f6e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -283,7 +283,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } @@ -297,7 +297,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -647,7 +647,7 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -684,7 +684,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -873,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -885,9 +885,14 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) - splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += + if (status.mapShuffleLocations == null) { + splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) + } else { + val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) + splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ea79c7310349..df30fd5c7f67 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -56,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable { private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] + private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics = + Predef.identity[TempShuffleReadMetrics] /** * Time taken on the executor to deserialize this task. @@ -187,11 +189,17 @@ class TaskMetrics private[spark] () extends Serializable { * be lost. */ private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized { - val readMetrics = new TempShuffleReadMetrics - tempShuffleReadMetrics += readMetrics + val tempShuffleMetrics = new TempShuffleReadMetrics + val readMetrics = _decorFunc(tempShuffleMetrics) + tempShuffleReadMetrics += tempShuffleMetrics readMetrics } + private[spark] def decorateTempShuffleReadMetrics( + decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized { + _decorFunc = decorFunc + } + /** * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`. * This is expected to be called on executor heartbeat and at the end of a task. diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4d559556360c..530c3694ad1e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,11 +17,18 @@ package org.apache.spark.shuffle +import java.io.InputStream + +import scala.collection.JavaConverters._ + import org.apache.spark._ +import org.apache.spark.api.java.Optional +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -35,41 +42,68 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, + shuffleReadSupport: ShuffleReadSupport, serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + sparkConf: SparkConf = SparkEnv.get.conf) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency + private val compressionCodec = CompressionCodec.createCodec(sparkConf) + + private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS) + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) - .map { - case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) => - (loc.getBlockManagerId, blocks) - case _ => - throw new UnsupportedOperationException("Not allowed to using non-default map shuffle" + - " locations yet.") - }, - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - readMetrics).toCompletionIterator + val streamsIterator = + shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { + override def iterator: Iterator[ShuffleBlockInfo] = { + mapOutputTracker + .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) + .flatMap { shuffleLocationInfo => + shuffleLocationInfo._2.map { blockInfo => + val block = blockInfo._1.asInstanceOf[ShuffleBlockId] + new ShuffleBlockInfo( + block.shuffleId, + block.mapId, + block.reduceId, + blockInfo._2, + Optional.ofNullable(shuffleLocationInfo._1.orNull)) + } + } + } + }.asJava).iterator() - val serializerInstance = dep.serializer.newInstance() + val retryingWrappedStreams = new Iterator[InputStream] { + override def hasNext: Boolean = streamsIterator.hasNext - // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + override def next(): InputStream = { + var returnStream: InputStream = null + while (streamsIterator.hasNext && returnStream == null) { + if (shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) { + // The default implementation checks for corrupt streams, so it will already have + // decompressed/decrypted the bytes + returnStream = streamsIterator.next() + } else { + val nextStream = streamsIterator.next() + returnStream = if (compressShuffle) { + compressionCodec.compressedInputStream( + serializerManager.wrapForEncryption(nextStream)) + } else { + serializerManager.wrapForEncryption(nextStream) + } + } + } + if (returnStream == null) { + throw new IllegalStateException("Expected shuffle reader iterator to return a stream") + } + returnStream + } + } + + val serializerInstance = dep.serializer.newInstance() + val recordIter = retryingWrappedStreams.flatMap { wrappedStream => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala new file mode 100644 index 000000000000..9b9b8508e88a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.io + +import java.io.InputStream + +import scala.collection.JavaConverters._ + +import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} +import org.apache.spark.internal.config +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} + +class DefaultShuffleReadSupport( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + serializerManager: SerializerManager, + conf: SparkConf) extends ShuffleReadSupport { + + private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) + private val maxBlocksInFlightPerAddress = + conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) + private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) + + override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + java.lang.Iterable[InputStream] = { + + val iterableToReturn = if (blockMetadata.asScala.isEmpty) { + Iterable.empty + } else { + val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) + .foldLeft(Int.MaxValue, 0) { + case ((min, max), elem) => (math.min(min, elem), math.max(max, elem)) + } + val shuffleId = blockMetadata.asScala.head.getShuffleId + new ShuffleBlockFetcherIterable( + TaskContext.get(), + blockManager, + serializerManager, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt, + shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(), + minReduceId, + maxReduceId, + shuffleId, + mapOutputTracker + ) + } + iterableToReturn.asJava + } +} + +private class ShuffleBlockFetcherIterable( + context: TaskContext, + blockManager: BlockManager, + serializerManager: SerializerManager, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + detectCorruption: Boolean, + shuffleMetrics: ShuffleReadMetricsReporter, + minReduceId: Int, + maxReduceId: Int, + shuffleId: Int, + mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] { + + override def iterator: Iterator[InputStream] = { + new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1) + .map { shuffleLocationInfo => + val defaultShuffleLocation = shuffleLocationInfo._1 + .get.asInstanceOf[DefaultMapShuffleLocations] + (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) + }, + serializerManager.wrapStream, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorruption, + shuffleMetrics).toCompletionIterator + } + +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 849050556c56..38495ae523d8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -124,7 +124,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context, metrics) + startPartition, + endPartition, + context, + metrics, + shuffleExecutorComponents.reads()) } /** Get a writer for a given partition. Called on executors by map tasks. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c89d5cc971d2..22fc4da97a5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -76,7 +76,7 @@ final class ShuffleBlockFetcherIterator( detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) - extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { + extends Iterator[InputStream] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -399,7 +399,7 @@ final class ShuffleBlockFetcherIterator( * * Throws a FetchFailedException if the next block could not be fetched. */ - override def next(): (BlockId, InputStream) = { + override def next(): InputStream = { if (!hasNext) { throw new NoSuchElementException() } @@ -497,7 +497,6 @@ final class ShuffleBlockFetcherIterator( in.close() } } - case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) } @@ -510,6 +509,7 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException() } currentResult = result.asInstanceOf[SuccessFetchResult] +<<<<<<< HEAD (currentResult.blockId, new BufferReleasingInputStream( input, @@ -517,10 +517,19 @@ final class ShuffleBlockFetcherIterator( currentResult.blockId, currentResult.address, detectCorrupt && streamCompressedOrEncrypted)) +======= + val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] + new BufferReleasingInputStream(input, this) + } + + // for testing only + def getCurrentBlock(): ShuffleBlockId = { + currentResult.blockId.asInstanceOf[ShuffleBlockId] +>>>>>>> b35d23845c... [SPARK-25299] shuffle reader API (#523) } - def toCompletionIterator: Iterator[(BlockId, InputStream)] = { - CompletionIterator[(BlockId, InputStream), this.type](this, + def toCompletionIterator: Iterator[InputStream] = { + CompletionIterator[InputStream, this.type](this, onCompleteCallback.onComplete(context)) } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 0a77c4f6d583..8fcbc845d1a7 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -71,9 +71,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val statuses = tracker.getMapSizesByShuffleLocation(10, 0) assert(statuses.toSet === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -155,7 +155,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -324,12 +324,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByShuffleLocation(10, 0, 4).toSeq === + assert(tracker.getMapSizesByShuffleLocation(10, 0, 4) + .map(x => (x._1.get, x._2)).toSeq === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), - Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), - Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))), + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))) ) ) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 83026c002f1b..1d2713151f50 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -409,12 +409,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val taskContext = new TaskContextImpl( 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) + TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8c0d7baccf85..9875b675400a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -704,7 +704,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -732,7 +732,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // we can see both result blocks now assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) @@ -775,7 +775,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } else { assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) } } } @@ -1070,7 +1070,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1202,11 +1202,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 1) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1398,7 +1398,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1804,7 +1804,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // Make sure that the reduce stage was now submitted. assert(taskSets.size === 3) @@ -2067,7 +2067,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2113,7 +2113,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2277,7 +2277,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure @@ -2293,7 +2293,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -2303,7 +2303,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostD"))) + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostD"))) assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2342,7 +2342,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. @@ -2368,7 +2368,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + Set(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. @@ -2936,6 +2936,10 @@ object DAGSchedulerSuite { def makeShuffleLocation(host: String): MapShuffleLocations = { DefaultMapShuffleLocations.get(makeBlockManagerId(host)) } + + def makeMaybeShuffleLocation(host: String): Option[MapShuffleLocations] = { + Some(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) + } } object FailThisAttempt { diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index b3073addb7cc..6468914bf318 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -21,13 +21,19 @@ import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.internal.config +import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.BlockId /** * Wrapper for a managed buffer that keeps track of how many times retain and release are called. @@ -79,11 +85,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() - val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + val compressionCodec = CompressionCodec.createCodec(testConf) + val compressedOutputStream = compressionCodec.compressedOutputStream(byteOutputStream) + val serializationStream = serializer.newInstance().serializeStream(compressedOutputStream) (0 until keyValuePairsPerMap).foreach { i => serializationStream.writeKey(i) serializationStream.writeValue(2*i) } + compressedOutputStream.close() // Setup the mocked BlockManager to return RecordingManagedBuffers. val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) @@ -102,19 +111,20 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. - val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByShuffleLocation( - shuffleId, reduceId, reduceId + 1)).thenReturn { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq( - (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) - .toIterator + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) } + val blocksToRetrieve = Seq( + (Option.apply(DefaultMapShuffleLocations.get(localBlockManagerId)), shuffleBlockIdsAndSizes)) + val mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)) + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + blocksToRetrieve.iterator + } + }) // Create a mocked shuffle handle to pass into HashShuffleReader. val shuffleHandle = { @@ -128,19 +138,23 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val serializerManager = new SerializerManager( serializer, new SparkConf() - .set(config.SHUFFLE_COMPRESS, false) + .set(config.SHUFFLE_COMPRESS, true) .set(config.SHUFFLE_SPILL_COMPRESS, false)) val taskContext = TaskContext.empty() + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + + val shuffleReadSupport = + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, + shuffleReadSupport, serializerManager, - blockManager, mapOutputTracker) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) @@ -151,5 +165,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext assert(buffer.callsToRetain === 1) assert(buffer.callsToRelease === 1) } + TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index b39e37c1e384..4dc1251a4ca8 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -34,10 +34,10 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network.BlockTransferService import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -199,21 +199,28 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((DefaultMapShuffleLocations.get(dataBlockId), shuffleBlockIdsAndSizes)).toIterator + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + .toIterator } when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) + val readSupport = new DefaultShuffleReadSupport( + blockManager, + mapOutputTracker, + serializerManager, + defaultConf) + new BlockStoreShuffleReader[String, String]( shuffleHandle, 0, 1, taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), + readSupport, serializerManager, - blockManager, mapOutputTracker ) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a1c298ae9446..f0fe829cddd0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -22,14 +22,13 @@ import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.Semaphore -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future - import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ @@ -127,7 +126,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, inputStream) = iterator.next() + val inputStream = iterator.next() + val blockId = iterator.getCurrentBlock() // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) @@ -203,11 +203,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.close() // close() first block's input stream + iterator.next().close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2 + val subIter = iterator.next() // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -427,7 +427,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + iterator.next() + val id1 = iterator.getCurrentBlock() assert(id1 === ShuffleBlockId(0, 0, 0)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) @@ -447,6 +448,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } sem.acquire() + intercept[FetchFailedException] { iterator.next() } } @@ -558,16 +560,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, true, taskContext.taskMetrics.createTempShuffleReadMetrics()) - val (id, st) = iterator.next() - // Check that the test setup is correct -- make sure we have a concatenated stream. - assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) - - val dst = new DataInputStream(st) - for (i <- 1 to 2500) { - assert(i === dst.readInt()) - } - assert(dst.read() === -1) - dst.close() + // Blocks should be returned without exceptions. + iterator.next() + val blockId1 = iterator.getCurrentBlock() + iterator.next() + val blockId2 = iterator.getCurrentBlock() + assert(Set(blockId1, blockId2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) } test("retry corrupt blocks (disabled)") { @@ -626,11 +624,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + iterator.next() + val id1 = iterator.getCurrentBlock() assert(id1 === ShuffleBlockId(0, 0, 0)) - val (id2, _) = iterator.next() + iterator.next() + val id2 = iterator.getCurrentBlock() assert(id2 === ShuffleBlockId(0, 1, 0)) - val (id3, _) = iterator.next() + iterator.next() + val id3 = iterator.getCurrentBlock() assert(id3 === ShuffleBlockId(0, 2, 0)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 079ff25fcb67..22cfbf506c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -156,10 +156,11 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] - val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. - val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + context.taskMetrics().decorateTempShuffleReadMetrics( + tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics)) + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -168,7 +169,7 @@ class ShuffledRowRDD( shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, context, - sqlMetricsReporter) + tempMetrics) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } From 89ecbf008888c9b7403ad60965910f97ceb2783b Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 30 Apr 2019 18:17:36 -0700 Subject: [PATCH 13/17] [SPARK-25299] fix reader benchmarks (#544) Fix the stubbing of the reader benchmark tests --- .../BlockStoreShuffleReaderBenchmark.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 4dc1251a4ca8..4f5bb264170d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -24,9 +24,12 @@ import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import scala.util.Random import org.apache.spark.{Aggregator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} @@ -194,14 +197,17 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { } when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1)) - .thenReturn { - val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => - val shuffleBlockId = ShuffleBlockId(0, mapId, 0) - (shuffleBlockId, dataFileLength) + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => + val shuffleBlockId = ShuffleBlockId(0, mapId, 0) + (shuffleBlockId, dataFileLength) + } + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + .toIterator } - Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) - .toIterator - } + }) when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) From 9e34cd37fe878b42555abb9d53e614023b06697d Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 7 May 2019 16:38:49 -0700 Subject: [PATCH 14/17] merge conflicts --- .../spark/api/shuffle/ShuffleDataIO.java | 3 + .../api/shuffle/ShuffleDriverComponents.java | 33 +++++++++ .../shuffle/ShuffleExecutorComponents.java | 4 +- .../shuffle/sort/io/DefaultShuffleDataIO.java | 7 ++ .../io/DefaultShuffleExecutorComponents.java | 4 +- .../DefaultShuffleDriverComponents.java | 54 ++++++++++++++ .../org/apache/spark/ContextCleaner.scala | 8 ++- .../scala/org/apache/spark/SparkContext.scala | 13 +++- .../shuffle/sort/SortShuffleManager.scala | 9 ++- .../spark/InternalAccumulatorSuite.scala | 3 +- .../ShuffleDriverComponentsSuite.scala | 71 +++++++++++++++++++ 11 files changed, 201 insertions(+), 8 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java create mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java index 4cb40f6dd00b..dd7c0ac7320c 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java @@ -27,5 +27,8 @@ */ @Experimental public interface ShuffleDataIO { + String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; + + ShuffleDriverComponents driver(); ShuffleExecutorComponents executor(); } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java new file mode 100644 index 000000000000..6a0ec8d44fd4 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; +import java.util.Map; + +public interface ShuffleDriverComponents { + + /** + * @return additional SparkConf values necessary for the executors. + */ + Map initializeApplication(); + + void cleanupApplication() throws IOException; + + void removeShuffleData(int shuffleId, boolean blocking) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java index 8baa3bf6f859..a5fa032bf651 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -19,6 +19,8 @@ import org.apache.spark.annotation.Experimental; +import java.util.Map; + /** * :: Experimental :: * An interface for building shuffle support for Executors @@ -27,7 +29,7 @@ */ @Experimental public interface ShuffleExecutorComponents { - void initializeExecutor(String appId, String execId); + void initializeExecutor(String appId, String execId, Map extraConfigs); ShuffleWriteSupport writes(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java index 906600c0f15f..7c124c1fe68b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java @@ -18,8 +18,10 @@ package org.apache.spark.shuffle.sort.io; import org.apache.spark.SparkConf; +import org.apache.spark.api.shuffle.ShuffleDriverComponents; import org.apache.spark.api.shuffle.ShuffleExecutorComponents; import org.apache.spark.api.shuffle.ShuffleDataIO; +import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; public class DefaultShuffleDataIO implements ShuffleDataIO { @@ -33,4 +35,9 @@ public DefaultShuffleDataIO(SparkConf sparkConf) { public ShuffleExecutorComponents executor() { return new DefaultShuffleExecutorComponents(sparkConf); } + + @Override + public ShuffleDriverComponents driver() { + return new DefaultShuffleDriverComponents(); + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 91a5d7f7945e..3b5f9670d64d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -28,6 +28,8 @@ import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; import org.apache.spark.storage.BlockManager; +import java.util.Map; + public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { private final SparkConf sparkConf; @@ -41,7 +43,7 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) { } @Override - public void initializeExecutor(String appId, String execId) { + public void initializeExecutor(String appId, String execId, Map extraConfigs) { blockManager = SparkEnv.get().blockManager(); mapOutputTracker = SparkEnv.get().mapOutputTracker(); serializerManager = SparkEnv.get().serializerManager(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java new file mode 100644 index 000000000000..a3eddc8ec930 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.lifecycle; + +import com.google.common.collect.ImmutableMap; +import org.apache.spark.SparkEnv; +import org.apache.spark.api.shuffle.ShuffleDriverComponents; +import org.apache.spark.storage.BlockManagerMaster; + +import java.io.IOException; +import java.util.Map; + +public class DefaultShuffleDriverComponents implements ShuffleDriverComponents { + + private BlockManagerMaster blockManagerMaster; + + @Override + public Map initializeApplication() { + blockManagerMaster = SparkEnv.get().blockManager().master(); + return ImmutableMap.of(); + } + + @Override + public void cleanupApplication() { + // do nothing + } + + @Override + public void removeShuffleData(int shuffleId, boolean blocking) throws IOException { + checkInitialized(); + blockManagerMaster.removeShuffle(shuffleId, blocking); + } + + private void checkInitialized() { + if (blockManagerMaster == null) { + throw new IllegalStateException("Driver components must be initialized before using"); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 305ec46a364a..fa28e54116d2 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled import scala.collection.JavaConverters._ +import org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -58,7 +59,9 @@ private class CleanupTaskWeakReference( * to be processed when the associated object goes out of scope of the application. Actual * cleanup is performed in a separate daemon thread. */ -private[spark] class ContextCleaner(sc: SparkContext) extends Logging { +private[spark] class ContextCleaner( + sc: SparkContext, + shuffleDriverComponents: ShuffleDriverComponents) extends Logging { /** * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they @@ -222,7 +225,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) - blockManagerMaster.removeShuffle(shuffleId, blocking) + shuffleDriverComponents.removeShuffleData(shuffleId, blocking) listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { @@ -270,7 +273,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4abb18d4aaa7..63919fe3d314 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} @@ -213,6 +214,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _shutdownHookRef: AnyRef = _ private var _statusStore: AppStatusStore = _ private var _heartbeater: Heartbeater = _ + private var _shuffleDriverComponents: ShuffleDriverComponents = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -486,6 +488,14 @@ class SparkContext(config: SparkConf) extends Logging { executorEnvs ++= _conf.getExecutorEnv executorEnvs("SPARK_USER") = sparkUser + val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") + _shuffleDriverComponents = maybeIO.head.driver() + _shuffleDriverComponents.initializeApplication().asScala.foreach { + case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) } + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) _heartbeatReceiver = env.rpcEnv.setupEndpoint( @@ -555,7 +565,7 @@ class SparkContext(config: SparkConf) extends Logging { _cleaner = if (_conf.get(CLEANER_REFERENCE_TRACKING)) { - Some(new ContextCleaner(this)) + Some(new ContextCleaner(this, _shuffleDriverComponents)) } else { None } @@ -1911,6 +1921,7 @@ class SparkContext(config: SparkConf) extends Logging { } _heartbeater = null } + _shuffleDriverComponents.cleanupApplication() if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 38495ae523d8..1c4ced695133 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + import org.apache.spark._ import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.internal.{config, Logging} @@ -225,7 +227,12 @@ private[spark] object SortShuffleManager extends Logging { classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") val executorComponents = maybeIO.head.executor() - executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId) + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX) + .toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) executorComponents } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 62824a5bec9d..28cbeeda7a88 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -210,7 +210,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { /** * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. */ - private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { + private class SaveAccumContextCleaner(sc: SparkContext) extends + ContextCleaner(sc, null) { private val accumsRegistered = new ArrayBuffer[Long] override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala new file mode 100644 index 000000000000..dbb954945a8b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.util + +import com.google.common.collect.ImmutableMap + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport} +import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport + +class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { + test(s"test serialization of shuffle initialization conf to executors") { + val testConf = new SparkConf() + .setAppName("testing") + .setMaster("local-cluster[2,1,1024]") + .set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO") + + sc = new SparkContext(testConf) + + sc.parallelize(Seq((1, "one"), (2, "two"), (3, "three")), 3) + .groupByKey() + .collect() + } +} + +class TestShuffleDriverComponents extends ShuffleDriverComponents { + override def initializeApplication(): util.Map[String, String] = + ImmutableMap.of("test-key", "test-value") + + override def cleanupApplication(): Unit = {} + + override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {} +} + +class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { + override def driver(): ShuffleDriverComponents = new TestShuffleDriverComponents() + + override def executor(): ShuffleExecutorComponents = + new TestShuffleExecutorComponents(sparkConf) +} + +class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { + override def initializeExecutor(appId: String, execId: String, + extraConfigs: util.Map[String, String]): Unit = { + assert(extraConfigs.get("test-key") == "test-value") + } + + override def writes(): ShuffleWriteSupport = { + val blockManager = SparkEnv.get.blockManager + val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) + new DefaultShuffleWriteSupport(sparkConf, blockResolver) + } +} From 10bc73e7827909d6747cacda51cbe3c7a9e66ca2 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 7 May 2019 16:38:49 -0700 Subject: [PATCH 15/17] [SPARK-25299] Driver lifecycle api (#533) Introduce driver shuffle lifecycle APIs From 353a2b50c2628d1bc6a41c2c33c8ced83951c864 Mon Sep 17 00:00:00 2001 From: mccheah Date: Thu, 9 May 2019 11:40:55 -0700 Subject: [PATCH 16/17] [SPARK-25299] Fix semantic merge conflict that broke the build. (#546) --- .../spark/shuffle/ShuffleDriverComponentsSuite.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index dbb954945a8b..e8372c045860 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -22,8 +22,9 @@ import java.util import com.google.common.collect.ImmutableMap import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport} +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { @@ -66,6 +67,13 @@ class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecuto override def writes(): ShuffleWriteSupport = { val blockManager = SparkEnv.get.blockManager val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) - new DefaultShuffleWriteSupport(sparkConf, blockResolver) + new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + } + + override def reads(): ShuffleReadSupport = { + val blockManager = SparkEnv.get.blockManager + val mapOutputTracker = SparkEnv.get.mapOutputTracker + val serializerManager = SparkEnv.get.serializerManager + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) } } From 9c4f1e3bfd900e8f46cb40bd0d44b215f0af3abb Mon Sep 17 00:00:00 2001 From: mccheah Date: Fri, 24 May 2019 12:15:03 -0700 Subject: [PATCH 17/17] [SPARK-25299] Propose a new NIO transfer API for partition writing. (#535) * Propose a new NIO transfer API for partition writing. This solves the consistency and resource leakage concerns with the first iteration of thie API, where it would not be obvious that the streamable resources created by ShufflePartitionWriter needed to be closed by ShuffleParittionWriter#close as opposed to closing the resources directly. This introduces the following adjustments: - Channel-based writes are separated out to their own module, SupportsTransferTo. This allows the transfer-to APIs to be modified independently, and users that only provide output streams can ignore the NIO APIs entirely. This also allows us to mark the base ShufflePartitionWriter as a stable API eventually while keeping the NIO APIs marked as experimental or developer-api. - We add APIs that explicitly encodes the notion of transferring bytes from one source to another. The partition writer returns an instance of TransferrableWritableByteChannel, which has APIs for accepting a TransferrableReadableByteChannel and can tell the readable byte channel to transfer its bytes out to some destination sink. - The resources returned by ShufflePartitionWriter are always closed. Internally, DefaultMapOutputWriter keeps resources open until commitAllPartitions() is called. * Migrate unsafe shuffle writer to use new byte channel API. * More sane implementation for unsafe * Fix style * Address comments * Fix imports * Fix build * Fix more build problems * Address comments. --- .../api/shuffle/ShufflePartitionWriter.java | 42 ++------ .../api/shuffle/ShuffleWriteSupport.java | 2 +- .../spark/api/shuffle/SupportsTransferTo.java | 53 +++++++++ .../TransferrableWritableByteChannel.java | 54 ++++++++++ .../sort/BypassMergeSortShuffleWriter.java | 79 +++++++------- ...faultTransferrableWritableByteChannel.java | 51 +++++++++ .../shuffle/sort/UnsafeShuffleWriter.java | 90 +++++++--------- .../io/DefaultShuffleMapOutputWriter.java | 102 +++++++++++------- .../shuffle/sort/SortShuffleManager.scala | 2 - .../util/collection/ExternalSorter.scala | 3 - .../ShufflePartitionPairsWriter.scala | 16 +-- .../sort/UnsafeShuffleWriterSuite.java | 6 +- ...ypassMergeSortShuffleWriterBenchmark.scala | 1 - .../BypassMergeSortShuffleWriterSuite.scala | 5 - .../sort/UnsafeShuffleWriterBenchmark.scala | 4 +- .../DefaultShuffleMapOutputWriterSuite.scala | 32 +++--- 16 files changed, 331 insertions(+), 211 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java index 6a53803e5d11..74c928b0b9c8 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java @@ -17,13 +17,10 @@ package org.apache.spark.api.shuffle; -import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; -import org.apache.http.annotation.Experimental; +import org.apache.spark.annotation.Experimental; /** * :: Experimental :: @@ -32,43 +29,16 @@ * @since 3.0.0 */ @Experimental -public interface ShufflePartitionWriter extends Closeable { +public interface ShufflePartitionWriter { /** - * Returns an underlying {@link OutputStream} that can write bytes to the underlying data store. - *

- * Note that this stream itself is not closed by the caller; close the stream in the - * implementation of this interface's {@link #close()}. + * Opens and returns an underlying {@link OutputStream} that can write bytes to the underlying + * data store. */ - OutputStream toStream() throws IOException; + OutputStream openStream() throws IOException; /** - * Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data - * store. - *

- * Note that this channel itself is not closed by the caller; close the channel in the - * implementation of this interface's {@link #close()}. - */ - default WritableByteChannel toChannel() throws IOException { - return Channels.newChannel(toStream()); - } - - /** - * Get the number of bytes written by this writer's stream returned by {@link #toStream()} or - * the channel returned by {@link #toChannel()}. + * Get the number of bytes written by this writer's stream returned by {@link #openStream()}. */ long getNumBytesWritten(); - - /** - * Close all resources created by this ShufflePartitionWriter, via calls to {@link #toStream()} - * or {@link #toChannel()}. - *

- * This must always close any stream returned by {@link #toStream()}. - *

- * Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel} - * that does not itself need to be closed up front; only the underlying output stream given by - * {@link #toStream()} must be closed. - */ - @Override - void close() throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java index 6c69d5db9fd0..7e2b6cf4133f 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java @@ -19,7 +19,7 @@ import java.io.IOException; -import org.apache.http.annotation.Experimental; +import org.apache.spark.annotation.Experimental; /** * :: Experimental :: diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java new file mode 100644 index 000000000000..866b61d0bafd --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * Indicates that partition writers can transfer bytes directly from input byte channels to + * output channels that stream data to the underlying shuffle partition storage medium. + *

+ * This API is separated out for advanced users because it only needs to be used for + * specific low-level optimizations. The idea is that the returned channel can transfer bytes + * from the input file channel out to the backing storage system without copying data into + * memory. + *

+ * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead. + * + * @since 3.0.0 + */ +@Experimental +public interface SupportsTransferTo extends ShufflePartitionWriter { + + /** + * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from + * input byte channels to the underlying shuffle data store. + */ + TransferrableWritableByteChannel openTransferrableChannel() throws IOException; + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}. + */ + @Override + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java new file mode 100644 index 000000000000..18234d7c4c94 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.Closeable; +import java.io.IOException; + +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * Represents an output byte channel that can copy bytes from input file channels to some + * arbitrary storage system. + *

+ * This API is provided for advanced users who can transfer bytes from a file channel to + * some output sink without copying data into memory. Most users should not need to use + * this functionality; this is primarily provided for the built-in shuffle storage backends + * that persist shuffle files on local disk. + *

+ * For a simpler alternative, see {@link ShufflePartitionWriter}. + * + * @since 3.0.0 + */ +@Experimental +public interface TransferrableWritableByteChannel extends Closeable { + + /** + * Copy all bytes from the source readable byte channel into this byte channel. + * + * @param source File to transfer bytes from. Do not call anything on this channel other than + * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. + * @param transferStartPosition Start position of the input file to transfer from. + * @param numBytesToTransfer Number of bytes to transfer from the given source. + */ + void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer) + throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 22386c39aca0..128b90429209 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,12 +21,10 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.channels.Channels; import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; -import org.apache.spark.api.java.Optional; -import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.None$; import scala.Option; import scala.Product2; @@ -38,19 +36,22 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShufflePartitionWriter; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; -import org.apache.spark.Partitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkConf; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -90,7 +91,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int mapId; private final Serializer serializer; private final ShuffleWriteSupport shuffleWriteSupport; - private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -107,7 +107,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, SparkConf conf, @@ -124,7 +123,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - this.shuffleBlockResolver = shuffleBlockResolver; this.shuffleWriteSupport = shuffleWriteSupport; } @@ -209,40 +207,43 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); - boolean copyThrewException = true; - ShufflePartitionWriter writer = null; - try { - writer = mapOutputWriter.getPartitionWriter(i); - if (!file.exists()) { - copyThrewException = false; - } else { - if (transferToEnabled) { - WritableByteChannel outputChannel = writer.toChannel(); - FileInputStream in = new FileInputStream(file); - try (FileChannel inputChannel = in.getChannel()) { - Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size()); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - } else { - OutputStream tempOutputStream = writer.toStream(); - FileInputStream in = new FileInputStream(file); - try { - Utils.copyStream(in, tempOutputStream, false, false); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); + if (file.exists()) { + boolean copyThrewException = true; + if (transferToEnabled) { + FileInputStream in = new FileInputStream(file); + TransferrableWritableByteChannel outputChannel = null; + try (FileChannel inputChannel = in.getChannel()) { + if (writer instanceof SupportsTransferTo) { + outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel(); + } else { + // Use default transferrable writable channel anyways in order to have parity with + // UnsafeShuffleWriter. + outputChannel = new DefaultTransferrableWritableByteChannel( + Channels.newChannel(writer.openStream())); } + outputChannel.transferFrom(inputChannel, 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + Closeables.close(outputChannel, copyThrewException); } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); + } else { + FileInputStream in = new FileInputStream(file); + OutputStream outputStream = null; + try { + outputStream = writer.openStream(); + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + Closeables.close(outputStream, copyThrewException); } } - } finally { - Closeables.close(writer, copyThrewException); + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } - lengths[i] = writer.getNumBytesWritten(); } } finally { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java new file mode 100644 index 000000000000..64ce851e392d --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.util.Utils; + +/** + * This is used when transferTo is enabled but the shuffle plugin hasn't implemented + * {@link org.apache.spark.api.shuffle.SupportsTransferTo}. + *

+ * This default implementation exists as a convenience to the unsafe shuffle writer and + * the bypass merge sort shuffle writers. + */ +public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel { + + private final WritableByteChannel delegate; + + public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) { + this.delegate = delegate; + } + + @Override + public void transferFrom( + FileChannel source, long transferStartPosition, long numBytesToTransfer) { + Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e4175f985d91..5dd0821e10f5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -17,14 +17,12 @@ package org.apache.spark.shuffle.sort; +import java.nio.channels.Channels; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; import java.util.Iterator; -import org.apache.spark.api.java.Optional; -import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -39,14 +37,17 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShufflePartitionWriter; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; -import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -54,11 +55,9 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -71,7 +70,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; - private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; @@ -107,7 +105,6 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream public UnsafeShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, SerializedShuffleHandle handle, int mapId, @@ -123,7 +120,6 @@ public UnsafeShuffleWriter( " reduce partitions"); } this.blockManager = blockManager; - this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); @@ -364,45 +360,37 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { boolean copyThrewExecption = true; - ShufflePartitionWriter writer = null; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + OutputStream partitionOutput = null; try { - writer = mapWriter.getPartitionWriter(partition); - OutputStream partitionOutput = null; - try { - // Shield the underlying output stream from close() calls, so that we can close the - // higher level streams to make sure all data is really flushed and internal state - // is cleaned - partitionOutput = new CloseShieldOutputStream(writer.toStream()); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - try { - partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionOutput = writer.openStream(); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = null; + try { + partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream( partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream( - partitionInputStream); - } - ByteStreams.copy(partitionInputStream, partitionOutput); - } finally { - partitionInputStream.close(); } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); } - copyThrewExecption = false; } - } finally { - Closeables.close(partitionOutput, copyThrewExecption); + copyThrewExecption = false; } } finally { - Closeables.close(writer, copyThrewExecption); + Closeables.close(partitionOutput, copyThrewExecption); } long numBytesWritten = writer.getNumBytesWritten(); partitionLengths[partition] = numBytesWritten; @@ -443,26 +431,26 @@ private long[] mergeSpillsWithTransferTo( } for (int partition = 0; partition < numPartitions; partition++) { boolean copyThrewExecption = true; - ShufflePartitionWriter writer = null; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + TransferrableWritableByteChannel partitionChannel = null; try { - writer = mapWriter.getPartitionWriter(partition); - WritableByteChannel channel = writer.toChannel(); + partitionChannel = writer instanceof SupportsTransferTo ? + ((SupportsTransferTo) writer).openTransferrableChannel() + : new DefaultTransferrableWritableByteChannel( + Channels.newChannel(writer.openStream())); for (int i = 0; i < spills.length; i++) { long partitionLengthInSpill = 0L; partitionLengthInSpill += spills[i].partitionLengths[partition]; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - channel, - spillInputChannelPositions[i], - partitionLengthInSpill); - copyThrewExecption = false; + partitionChannel.transferFrom( + spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill); spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } + copyThrewExecption = false; } finally { - Closeables.close(writer, copyThrewExecption); + Closeables.close(partitionChannel, copyThrewExecption); } long numBytes = writer.getNumBytesWritten(); partitionLengths[partition] = numBytes; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index 926c3b943399..e83db4e4bcef 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -24,18 +24,21 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; -import org.apache.spark.api.java.Optional; -import org.apache.spark.api.shuffle.MapShuffleLocations; -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.api.shuffle.SupportsTransferTo; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; +import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.storage.BlockManagerId; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.util.Utils; @@ -151,70 +154,70 @@ private void initChannel() throws IOException { } } - private class DefaultShufflePartitionWriter implements ShufflePartitionWriter { + private class DefaultShufflePartitionWriter implements SupportsTransferTo { private final int partitionId; - private PartitionWriterStream stream = null; + private PartitionWriterStream partStream = null; + private PartitionWriterChannel partChannel = null; private DefaultShufflePartitionWriter(int partitionId) { this.partitionId = partitionId; } @Override - public OutputStream toStream() throws IOException { - if (outputFileChannel != null) { - throw new IllegalStateException("Requested an output channel for a previous write but" + - " now an output stream has been requested. Should not be using both channels" + - " and streams to write."); + public OutputStream openStream() throws IOException { + if (partStream == null) { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + partStream = new PartitionWriterStream(partitionId); } - initStream(); - stream = new PartitionWriterStream(); - return stream; + return partStream; } @Override - public FileChannel toChannel() throws IOException { - if (stream != null) { - throw new IllegalStateException("Requested an output stream for a previous write but" + - " now an output channel has been requested. Should not be using both channels" + - " and streams to write."); + public TransferrableWritableByteChannel openTransferrableChannel() throws IOException { + if (partChannel == null) { + if (partStream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. Should not be using both channels" + + " and streams to write."); + } + initChannel(); + partChannel = new PartitionWriterChannel(partitionId); } - initChannel(); - return outputFileChannel; + return partChannel; } @Override public long getNumBytesWritten() { - if (outputFileChannel != null && stream == null) { + if (partChannel != null) { try { - long newPosition = outputFileChannel.position(); - return newPosition - currChannelPosition; - } catch (Exception e) { - log.error("The partition which failed is: {}", partitionId, e); - throw new IllegalStateException("Failed to calculate position of file channel", e); + return partChannel.getCount(); + } catch (IOException e) { + throw new RuntimeException(e); } - } else if (stream != null) { - return stream.getCount(); + } else if (partStream != null) { + return partStream.getCount(); } else { // Assume an empty partition if stream and channel are never created return 0; } } - - @Override - public void close() { - if (stream != null) { - // Closing is a no-op. - stream.close(); - } - partitionLengths[partitionId] = getNumBytesWritten(); - } } private class PartitionWriterStream extends OutputStream { + private final int partitionId; private int count = 0; private boolean isClosed = false; + PartitionWriterStream(int partitionId) { + this.partitionId = partitionId; + } + public int getCount() { return count; } @@ -236,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException { @Override public void close() { isClosed = true; + partitionLengths[partitionId] = count; } private void verifyNotClosed() { @@ -244,4 +248,24 @@ private void verifyNotClosed() { } } } + + private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel { + + private final int partitionId; + + PartitionWriterChannel(int partitionId) { + super(outputFileChannel); + this.partitionId = partitionId; + } + + public long getCount() throws IOException { + long writtenPosition = outputFileChannel.position(); + return writtenPosition - currChannelPosition; + } + + @Override + public void close() throws IOException { + partitionLengths[partitionId] = getCount(); + } + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 1c4ced695133..947753f6b40e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -146,7 +146,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), unsafeShuffleHandle, mapId, @@ -157,7 +156,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, env.conf, diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index df5ce73b9acf..14d34e1c47c8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -755,9 +755,6 @@ private[spark] class ExternalSorter[K, V, C]( if (partitionPairsWriter != null) { partitionPairsWriter.close() } - if (partitionWriter != null) { - partitionWriter.close() - } } if (partitionWriter != null) { lengths(partitionId) = partitionWriter.getNumBytesWritten diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala index 6f19a2323efd..8538a78b377c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala @@ -54,25 +54,17 @@ private[spark] class ShufflePartitionPairsWriter( } private def open(): Unit = { - // The contract is that the partition writer is expected to close its own streams, but - // the compressor will only flush the stream when it is specifically closed. So we want to - // close objOut to flush the compressed bytes to the partition writer stream, but we don't want - // to close the partition output stream in the process. - partitionStream = new CloseShieldOutputStream(partitionWriter.toStream) + partitionStream = partitionWriter.openStream wrappedStream = serializerManager.wrapStream(blockId, partitionStream) objOut = serializerInstance.serializeStream(wrappedStream) } override def close(): Unit = { if (isOpen) { - // Closing objOut should propagate close to all inner layers - // We can't close wrappedStream explicitly because closing objOut and closing wrappedStream - // causes problems when closing compressed output streams twice. objOut.close() objOut = null wrappedStream = null partitionStream = null - partitionWriter.close() isOpen = false updateBytesWritten() } @@ -96,10 +88,4 @@ private[spark] class ShufflePartitionPairsWriter( writeMetrics.incBytesWritten(bytesWrittenDiff) curNumBytesWritten = numBytesWritten } - - private class CloseShieldOutputStream(delegate: OutputStream) - extends FilterOutputStream(delegate) { - - override def close(): Unit = flush() - } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 5f0de31bd25e..5ea0907277eb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -183,8 +183,7 @@ private UnsafeShuffleWriter createWriter( conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter<>( blockManager, - shuffleBlockResolver, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, @@ -544,8 +543,7 @@ public void testPeakMemoryUsed() throws Exception { final UnsafeShuffleWriter writer = new UnsafeShuffleWriter<>( blockManager, - shuffleBlockResolver, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 0b3394e88d9f..dbd73f2688df 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -54,7 +54,6 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, - blockResolver, shuffleHandle, 0, conf, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 05d10e9f63d0..8061e32bc6b6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -158,7 +158,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, @@ -184,7 +183,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId transferConf, @@ -209,7 +207,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, @@ -245,7 +242,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, @@ -268,7 +264,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index 0e659ff7cc5f..7066ba8fb44d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -50,15 +50,13 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( blockManager, - blockResolver, taskMemoryManager, shuffleHandle, 0, taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport - ) + shuffleWriteSupport) } def writeBenchmarkWithSmallDataset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala index 420b0d4d2f67..1f4ef0f20399 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.sort.io import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.math.BigInteger import java.nio.ByteBuffer +import java.nio.channels.{Channels, WritableByteChannel} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} @@ -31,10 +32,12 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.api.shuffle.SupportsTransferTo import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -141,14 +144,13 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("writing to an outputstream") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.toStream() + val stream = writer.openStream() data(p).foreach { i => stream.write(i)} stream.close() intercept[IllegalStateException] { stream.write(p) } assert(writer.getNumBytesWritten() == D_LEN) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray @@ -160,15 +162,23 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("writing to a channel") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.toChannel() + val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() intBuffer.put(data(p)) - assert(channel.isOpen) - channel.write(byteBuffer) + val numBytes = byteBuffer.remaining() + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + val outputTempFileStream = new FileOutputStream(outputTempFile) + Utils.copyStream( + new ByteBufferInputStream(byteBuffer), + outputTempFileStream, + closeStreams = true) + val tempFileInput = new FileInputStream(outputTempFile) + channel.transferFrom(tempFileInput.getChannel, 0L, numBytes) // Bytes require * 4 + channel.close() + tempFileInput.close() assert(writer.getNumBytesWritten == D_LEN * 4) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray @@ -180,7 +190,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("copyStreams with an outputstream") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.toStream() + val stream = writer.openStream() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() intBuffer.put(data(p)) @@ -189,7 +199,6 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft in.close() stream.close() assert(writer.getNumBytesWritten == D_LEN * 4) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray @@ -201,7 +210,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("copyStreamsWithNIO with a channel") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.toChannel() + val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() intBuffer.put(data(p)) @@ -209,10 +218,9 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft out.write(byteBuffer.array()) out.close() val in = new FileInputStream(tempFile) - Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4) - in.close() + channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining()) + channel.close() assert(writer.getNumBytesWritten == D_LEN * 4) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray