[SPARK-26713][CORE][2.4] Interrupt pipe IO threads in PipedRDD when task is finished

### What changes were proposed in this pull request?
Manually release the stdin writer and stderr reader threads when the task is finished, by interrupting them from a task-completion listener. This is a backport of apache#23638, including apache#25049.
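
In code form, the fix keeps handles to both helper threads and registers a task-completion listener that destroys the child process and interrupts them. Below is a minimal, self-contained sketch of that pattern, not the actual `PipedRDD` code; the object, RDD, and thread names are made up for illustration:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Illustrative sketch only -- not the actual PipedRDD code. A helper thread
// started inside a task is interrupted from a task-completion listener, so it
// cannot outlive the task. All names here are made up.
object CompletionListenerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[1]"))
    try {
      sc.range(0, 4, 1, 1).mapPartitions { iter =>
        val context = TaskContext.get()
        // Stand-in for PipedRDD's stdin writer / stderr reader threads.
        val helper = new Thread("sketch helper thread") {
          override def run(): Unit = {
            try {
              Thread.sleep(Long.MaxValue) // would block forever without the interrupt
            } catch {
              case _: InterruptedException => () // expected once the task completes
            }
          }
        }
        helper.start()
        // Same shape as the listener added to PipedRDD.compute in this patch.
        context.addTaskCompletionListener[Unit] { _ =>
          if (helper.isAlive) {
            helper.interrupt()
          }
        }
        iter
      }.count()
    } finally {
      sc.stop()
    }
  }
}
```

Note that the real change also calls `proc.destroy()` in the listener before interrupting, which closes the child process's pipes and helps any blocked pipe IO finish rather than relying on interruption alone.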

### Why are the changes needed?
This is a bug fix. PipedRDD's IO threads may hang around even after the corresponding task has finished. Without this fix, they leak resources (memory in particular).
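
For context, the leak is easiest to reproduce when the piped output is never consumed. The following spark-shell-style snippet is adapted from the comment added in `PipedRDD.scala` below; the final `count()` is added here only to actually run the tasks, and it assumes a local `cat` binary:

```scala
// Adapted from the code comment added by this patch: the piped output is discarded,
// so before the fix the "stdin writer" thread kept draining the parent iterator
// (holding its memory and CPU) even after the task itself had finished.
val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
abnormalRDD.count() // tasks complete, yet the pipe IO threads used to live on
```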

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Added a new test to `PipedRDDSuite`.
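
The test leans on the thread-name prefixes introduced by this patch: after the piped job finishes, it asserts that no live thread still carries the stdin-writer prefix. The core check (essentially the same lines as in the `PipedRDDSuite` diff below) looks like this:

```scala
import scala.collection.JavaConverters._

// After collecting from the piped RDD, no live thread should still carry the
// "stdin writer for ..." name prefix added by this change.
val leftoverWriter = Thread.getAllStackTraces.keySet().asScala
  .find(_.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX))
assert(leftoverWriter.isEmpty)
```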

Closes apache#25825 from advancedxy/SPARK-26713_for_2.4.

Authored-by: Xianjin YE <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
advancedxy authored and rshkv committed Mar 15, 2021
1 parent 92f7bb7 commit fcf43a2
Showing 2 changed files with 53 additions and 4 deletions.
33 changes: 29 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
     val childThreadException = new AtomicReference[Throwable](null)
 
     // Start a thread to print the process's stderr to ours
-    new Thread(s"stderr reader for $command") {
+    val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         val err = proc.getErrorStream
         try {
@@ -128,10 +128,11 @@
           err.close()
         }
       }
-    }.start()
+    }
+    stderrReaderThread.start()
 
     // Start a thread to feed the process input from our parent's iterator
-    new Thread(s"stdin writer for $command") {
+    val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         TaskContext.setTaskContext(context)
         val out = new PrintWriter(new BufferedWriter(
@@ -156,7 +157,28 @@
           out.close()
         }
       }
-    }.start()
+    }
+    stdinWriterThread.start()
+
+    // interrupts stdin writer and stderr reader threads when the corresponding task is finished.
+    // Otherwise, these threads could outlive the task's lifetime. For example:
+    //   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
+    //   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
+    // the iterator generated by PipedRDD is never involved. If the parent RDD's iterator takes a
+    // long time to generate(ShuffledRDD's shuffle operation for example), the stdin writer thread
+    // may consume significant memory and CPU time even if task is already finished.
+    context.addTaskCompletionListener[Unit] { _ =>
+      if (proc.isAlive) {
+        proc.destroy()
+      }
+
+      if (stdinWriterThread.isAlive) {
+        stdinWriterThread.interrupt()
+      }
+      if (stderrReaderThread.isAlive) {
+        stderrReaderThread.interrupt()
+      }
+    }
 
     // Return an iterator that read lines from the process's stdout
     val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
@@ -219,4 +241,7 @@ private object PipedRDD {
     }
     buf
   }
+
+  val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
+  val STDERR_READER_THREAD_PREFIX = "stderr reader for"
 }
24 changes: 24 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.File
 
+import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.io.Codec
 
@@ -83,6 +84,29 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("stdin writer thread should be exited when task is finished") {
+    assume(TestUtils.testCommandAvailable("cat"))
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
+      val obj = new Object()
+      obj.synchronized {
+        obj.wait() // make the thread waits here.
+      }
+      x
+    }
+
+    val piped = nums.pipe(Seq("cat"))
+
+    val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)
+
+    assert(result.collect().length === 0)
+
+    // collect stderr writer threads
+    val stderrWriterThread = Thread.getAllStackTraces.keySet().asScala
+      .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
+
+    assert(stderrWriterThread.isEmpty)
+  }
+
   test("advanced pipe") {
     assume(TestUtils.testCommandAvailable("cat"))
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
