diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index f1daf62ad4d1f..02b28b72fb0e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
     val childThreadException = new AtomicReference[Throwable](null)
 
     // Start a thread to print the process's stderr to ours
-    val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
+    new Thread(s"stderr reader for $command") {
       override def run(): Unit = {
         val err = proc.getErrorStream
         try {
@@ -128,11 +128,10 @@ private[spark] class PipedRDD[T: ClassTag](
           err.close()
         }
       }
-    }
-    stderrReaderThread.start()
+    }.start()
 
     // Start a thread to feed the process input from our parent's iterator
-    val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
+    new Thread(s"stdin writer for $command") {
       override def run(): Unit = {
         TaskContext.setTaskContext(context)
         val out = new PrintWriter(new BufferedWriter(
@@ -157,28 +156,7 @@ private[spark] class PipedRDD[T: ClassTag](
           out.close()
         }
       }
-    }
-    stdinWriterThread.start()
-
-    // interrupts stdin writer and stderr reader threads when the corresponding task is finished.
-    // Otherwise, these threads could outlive the task's lifetime. For example:
-    //   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
-    //   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
-    // the iterator generated by PipedRDD is never involved. If the parent RDD's iterator takes a
-    // long time to generate(ShuffledRDD's shuffle operation for example), the stdin writer thread
-    // may consume significant memory and CPU time even if task is already finished.
-    context.addTaskCompletionListener[Unit] { _ =>
-      if (proc.isAlive) {
-        proc.destroy()
-      }
-
-      if (stdinWriterThread.isAlive) {
-        stdinWriterThread.interrupt()
-      }
-      if (stderrReaderThread.isAlive) {
-        stderrReaderThread.interrupt()
-      }
-    }
+    }.start()
 
     // Return an iterator that read lines from the process's stdout
     val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
@@ -241,7 +219,4 @@ private object PipedRDD {
     }
     buf
   }
-
-  val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
-  val STDERR_READER_THREAD_PREFIX = "stderr reader for"
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 3966980a11ed0..1e5f3f7719977 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -141,14 +141,7 @@ final class ShuffleBlockFetcherIterator(
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
-   * longer place fetched blocks into [[results]] and the iterator is marked as fully consumed.
-   *
-   * When the iterator is inactive, [[hasNext]] and [[next]] calls will honor that as there are
-   * cases the iterator is still being consumed. For example, ShuffledRDD + PipedRDD if the
-   * subprocess command is failed. The task will be marked as failed, then the iterator will be
-   * cleaned up at task completion, the [[next]] call (called in the stdin writer thread of
-   * PipedRDD if not exited yet) may hang at [[results.take]]. The defensive check in [[hasNext]]
-   * and [[next]] reduces the possibility of such race conditions.
+   * longer place fetched blocks into [[results]].
    */
   @GuardedBy("this")
   private[this] var isZombie = false
@@ -387,7 +380,7 @@ final class ShuffleBlockFetcherIterator(
     logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
   }
 
-  override def hasNext: Boolean = !isZombie && (numBlocksProcessed < numBlocksToFetch)
+  override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
   /**
    * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
@@ -399,7 +392,7 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     if (!hasNext) {
-      throw new NoSuchElementException()
+      throw new NoSuchElementException
     }
 
     numBlocksProcessed += 1
@@ -410,7 +403,7 @@ final class ShuffleBlockFetcherIterator(
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
     // is also corrupt, so the previous stage could be retried.
     // For local shuffle block, throw FailureFetchResult for the first IOException.
-    while (!isZombie && result == null) {
+    while (result == null) {
       val startFetchWait = System.currentTimeMillis()
       result = results.take()
       val stopFetchWait = System.currentTimeMillis()
@@ -504,9 +497,6 @@ final class ShuffleBlockFetcherIterator(
      fetchUpToMaxBytes()
    }
 
-    if (result == null) { // the iterator is already closed/cleaned up.
-      throw new NoSuchElementException()
-    }
    currentResult = result.asInstanceOf[SuccessFetchResult]
    (currentResult.blockId, new BufferReleasingInputStream(input, this))
  }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 69739a2e58481..1a0eb250e7cdc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.rdd
 
 import java.io.File
 
-import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.io.Codec
 
@@ -84,29 +83,6 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
-  test("stdin writer thread should be exited when task is finished") {
-    assume(TestUtils.testCommandAvailable("cat"))
-    val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
-      val obj = new Object()
-      obj.synchronized {
-        obj.wait() // make the thread waits here.
-      }
-      x
-    }
-
-    val piped = nums.pipe(Seq("cat"))
-
-    val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)
-
-    assert(result.collect().length === 0)
-
-    // collect stderr writer threads
-    val stderrWriterThread = Thread.getAllStackTraces.keySet().asScala
-      .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
-
-    assert(stderrWriterThread.isEmpty)
-  }
-
   test("advanced pipe") {
     assume(TestUtils.testCommandAvailable("cat"))
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 98fe9663b6211..6b83243fe496c 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -217,65 +217,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
   }
 
-  test("iterator is all consumed if task completes early") {
-    val blockManager = mock(classOf[BlockManager])
-    val localBmId = BlockManagerId("test-client", "test-client", 1)
-    doReturn(localBmId).when(blockManager).blockManagerId
-
-    // Make sure remote blocks would return
-    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
-    val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
-
-    // Semaphore to coordinate event sequence in two different threads.
-    val sem = new Semaphore(0)
-
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer(new Answer[Unit] {
-        override def answer(invocation: InvocationOnMock): Unit = {
-          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          Future {
-            // Return the first two blocks, and wait till task completion before returning the last
-            listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-            listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
-            sem.acquire()
-            listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
-          }
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
-
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
-
-
-    assert(iterator.hasNext)
-    iterator.next()
-
-    taskContext.markTaskCompleted(None)
-    sem.release()
-    assert(iterator.hasNext === false)
-  }
-
   test("fail all blocks if any of the remote request fails") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
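The cleanup removed from PipedRDD above reduces to a generic pattern: when the owning task finishes, destroy the child process and interrupt the stdin/stderr helper threads so they cannot outlive it. The following is a minimal standalone Scala sketch of that pattern, not Spark code: the object name InterruptOnCompletion, the sleep-based timing, and the plain main method standing in for a task-completion listener are all illustrative assumptions.

import java.io.PrintWriter
import java.util.concurrent.TimeUnit

object InterruptOnCompletion {
  def main(args: Array[String]): Unit = {
    // Child process standing in for the piped command ("cat" here).
    val proc = new ProcessBuilder("cat").start()

    // Writer thread standing in for the stdin writer: it feeds the child's
    // stdin from a deliberately slow parent iterator.
    val stdinWriter = new Thread("stdin writer for cat") {
      override def run(): Unit = {
        val out = new PrintWriter(proc.getOutputStream)
        try {
          Iterator.from(1).foreach { i =>
            Thread.sleep(100) // simulate a slow parent iterator
            out.println(i)
          }
        } catch {
          case _: InterruptedException => // expected when the task finishes
        } finally {
          out.close()
        }
      }
    }
    stdinWriter.start()

    // Stand-in for the removed task-completion hook: once the "task" is done,
    // destroy the process and interrupt the writer so neither outlives it.
    Thread.sleep(500) // pretend the task finished here
    if (proc.isAlive) proc.destroy()
    if (stdinWriter.isAlive) stdinWriter.interrupt()

    stdinWriter.join(TimeUnit.SECONDS.toMillis(5))
    println(s"writer alive after cleanup: ${stdinWriter.isAlive}")
  }
}

Without the final destroy/interrupt step, the writer thread would keep blocking on the slow iterator indefinitely, which is the leak scenario described in the removed PipedRDD comment and exercised by the removed PipedRDDSuite test.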