Revert "[SPARK-26713][CORE] Interrupt pipe IO threads in PipedRDD whe…
Browse files Browse the repository at this point in the history
…n task is finished"

This reverts commit 1280bfd.
rshkv committed Mar 15, 2021
1 parent 352bdb7 commit 92f7bb7
Showing 4 changed files with 8 additions and 126 deletions.
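
For context on the mechanism being reverted: SPARK-26713 kept references to the pipe IO threads and registered a task-completion listener to stop the subprocess and interrupt them. A minimal sketch of that pattern (simplified; startPipeThreads and its parameters are illustrative, not Spark's actual method):

    import org.apache.spark.TaskContext

    // Sketch of the reverted pattern, not the actual PipedRDD code: keep a
    // reference to the helper thread so a task-completion listener can stop
    // the subprocess and interrupt the thread once the task finishes.
    def startPipeThreads(context: TaskContext, proc: Process, command: String): Unit = {
      val stdinWriterThread = new Thread(s"stdin writer for $command") {
        override def run(): Unit = {
          // feed proc.getOutputStream from the parent iterator (elided)
        }
      }
      stdinWriterThread.start()

      context.addTaskCompletionListener[Unit] { _ =>
        if (proc.isAlive) proc.destroy()
        if (stdinWriterThread.isAlive) stdinWriterThread.interrupt()
      }
    }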
33 changes: 4 additions & 29 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
val childThreadException = new AtomicReference[Throwable](null)

// Start a thread to print the process's stderr to ours
-val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
+new Thread(s"stderr reader for $command") {
override def run(): Unit = {
val err = proc.getErrorStream
try {
@@ -128,11 +128,10 @@
err.close()
}
}
-}
-stderrReaderThread.start()
+}.start()

// Start a thread to feed the process input from our parent's iterator
-val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
+new Thread(s"stdin writer for $command") {
override def run(): Unit = {
TaskContext.setTaskContext(context)
val out = new PrintWriter(new BufferedWriter(
@@ -157,28 +156,7 @@
out.close()
}
}
-}
-stdinWriterThread.start()
-
-// interrupts stdin writer and stderr reader threads when the corresponding task is finished.
-// Otherwise, these threads could outlive the task's lifetime. For example:
-//   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
-//   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
-// the iterator generated by PipedRDD is never involved. If the parent RDD's iterator takes a
-// long time to generate(ShuffledRDD's shuffle operation for example), the stdin writer thread
-// may consume significant memory and CPU time even if task is already finished.
-context.addTaskCompletionListener[Unit] { _ =>
-if (proc.isAlive) {
-proc.destroy()
-}
-
-if (stdinWriterThread.isAlive) {
-stdinWriterThread.interrupt()
-}
-if (stderrReaderThread.isAlive) {
-stderrReaderThread.interrupt()
-}
-}
+}.start()

// Return an iterator that read lines from the process's stdout
val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
@@ -241,7 +219,4 @@ private object PipedRDD {
}
buf
}
-
-val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
-val STDERR_READER_THREAD_PREFIX = "stderr reader for"
}
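
The comment removed above documents the leak this listener guarded against. A minimal reproduction of that scenario, taken from the removed comment (assumes a live SparkContext named sc):

    // The PipedRDD iterator is never consumed, so without the interrupt the
    // stdin writer thread can keep driving the parent iterator (for example
    // an expensive shuffle) after the task itself has finished.
    val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
    val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
    abnormalRDD.count()  // tasks complete without ever touching the piped output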
18 changes: 4 additions & 14 deletions core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -141,14 +141,7 @@ final class ShuffleBlockFetcherIterator(

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
-* longer place fetched blocks into [[results]] and the iterator is marked as fully consumed.
-*
-* When the iterator is inactive, [[hasNext]] and [[next]] calls will honor that as there are
-* cases the iterator is still being consumed. For example, ShuffledRDD + PipedRDD if the
-* subprocess command is failed. The task will be marked as failed, then the iterator will be
-* cleaned up at task completion, the [[next]] call (called in the stdin writer thread of
-* PipedRDD if not exited yet) may hang at [[results.take]]. The defensive check in [[hasNext]]
-* and [[next]] reduces the possibility of such race conditions.
+* longer place fetched blocks into [[results]].
*/
@GuardedBy("this")
private[this] var isZombie = false
@@ -387,7 +380,7 @@
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

-override def hasNext: Boolean = !isZombie && (numBlocksProcessed < numBlocksToFetch)
+override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

/**
* Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
@@ -399,7 +392,7 @@
*/
override def next(): (BlockId, InputStream) = {
if (!hasNext) {
-throw new NoSuchElementException()
+throw new NoSuchElementException
}

numBlocksProcessed += 1
@@ -410,7 +403,7 @@
// then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
// is also corrupt, so the previous stage could be retried.
// For local shuffle block, throw FailureFetchResult for the first IOException.
-while (!isZombie && result == null) {
+while (result == null) {
val startFetchWait = System.currentTimeMillis()
result = results.take()
val stopFetchWait = System.currentTimeMillis()
@@ -504,9 +497,6 @@
fetchUpToMaxBytes()
}

-if (result == null) { // the iterator is already closed/cleaned up.
-throw new NoSuchElementException()
-}
currentResult = result.asInstanceOf[SuccessFetchResult]
(currentResult.blockId, new BufferReleasingInputStream(input, this))
}
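
The hunks above also drop the defensive isZombie checks from hasNext and next. A simplified sketch of that defensive-iterator pattern (a generic stand-in, not the real ShuffleBlockFetcherIterator):

    import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

    // Once cleanup flips isZombie, hasNext reports exhaustion and next()
    // refuses to hand out further results instead of blocking forever.
    class DefensiveResultIterator[T >: Null](
        results: LinkedBlockingQueue[T],
        numResultsExpected: Int) extends Iterator[T] {

      @volatile private var isZombie = false
      private var numProcessed = 0

      // Called from a task-completion/cleanup hook in the real code.
      def markZombie(): Unit = { isZombie = true }

      override def hasNext: Boolean = !isZombie && numProcessed < numResultsExpected

      override def next(): T = {
        if (!hasNext) throw new NoSuchElementException
        numProcessed += 1
        var result: T = null
        // The real code blocks on results.take(); polling keeps this sketch
        // from hanging, but either way the zombie check only narrows the
        // race window rather than eliminating it.
        while (!isZombie && result == null) {
          result = results.poll(10, TimeUnit.MILLISECONDS)
        }
        if (result == null) throw new NoSuchElementException // cleaned up while waiting
        result
      }
    }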
24 changes: 0 additions & 24 deletions core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.rdd

import java.io.File

-import scala.collection.JavaConverters._
import scala.collection.Map
import scala.io.Codec

@@ -84,29 +83,6 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
}
}

test("stdin writer thread should be exited when task is finished") {
assume(TestUtils.testCommandAvailable("cat"))
val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
val obj = new Object()
obj.synchronized {
obj.wait() // make the thread waits here.
}
x
}

val piped = nums.pipe(Seq("cat"))

val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)

assert(result.collect().length === 0)

// collect stderr writer threads
val stderrWriterThread = Thread.getAllStackTraces.keySet().asScala
.find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }

assert(stderrWriterThread.isEmpty)
}

test("advanced pipe") {
assume(TestUtils.testCommandAvailable("cat"))
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
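
A condensed version of the leak check the removed test performed (the string literal stands in for the deleted PipedRDD.STDIN_WRITER_THREAD_PREFIX constant):

    import scala.collection.JavaConverters._

    // Enumerate live JVM threads and look for one whose name carries the
    // PipedRDD stdin-writer prefix; a hit means a pipe thread outlived its task.
    val leaked = Thread.getAllStackTraces.keySet().asScala
      .find(_.getName.startsWith("stdin writer for"))
    assert(leaked.isEmpty, s"leaked pipe thread: $leaked")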
59 changes: 0 additions & 59 deletions core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -217,65 +217,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}

test("iterator is all consumed if task completes early") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())

// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first two blocks, and wait till task completion before returning the last
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
sem.acquire()
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}
})

val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator

val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true,
taskContext.taskMetrics.createTempShuffleReadMetrics())


assert(iterator.hasNext)
iterator.next()

taskContext.markTaskCompleted(None)
sem.release()
assert(iterator.hasNext === false)
}

test("fail all blocks if any of the remote request fails") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)