diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 02b28b72fb0e7..f1daf62ad4d1f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
     val childThreadException = new AtomicReference[Throwable](null)
 
     // Start a thread to print the process's stderr to ours
-    new Thread(s"stderr reader for $command") {
+    val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         val err = proc.getErrorStream
         try {
@@ -128,10 +128,11 @@ private[spark] class PipedRDD[T: ClassTag](
           err.close()
         }
       }
-    }.start()
+    }
+    stderrReaderThread.start()
 
     // Start a thread to feed the process input from our parent's iterator
-    new Thread(s"stdin writer for $command") {
+    val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         TaskContext.setTaskContext(context)
         val out = new PrintWriter(new BufferedWriter(
@@ -156,7 +157,28 @@ private[spark] class PipedRDD[T: ClassTag](
           out.close()
         }
       }
-    }.start()
+    }
+    stdinWriterThread.start()
+
+    // Interrupt the stdin writer and stderr reader threads when the corresponding task finishes.
+    // Otherwise, these threads could outlive the task's lifetime. For example:
+    //   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
+    //   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
+    // Here the iterator generated by the PipedRDD is never consumed. If the parent RDD's iterator
+    // takes a long time to produce its elements (a ShuffledRDD's shuffle, for example), the stdin
+    // writer thread may consume significant memory and CPU time even after the task has finished.
+    context.addTaskCompletionListener[Unit] { _ =>
+      if (proc.isAlive) {
+        proc.destroy()
+      }
+
+      if (stdinWriterThread.isAlive) {
+        stdinWriterThread.interrupt()
+      }
+      if (stderrReaderThread.isAlive) {
+        stderrReaderThread.interrupt()
+      }
+    }
 
     // Return an iterator that read lines from the process's stdout
     val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
@@ -219,4 +241,7 @@ private object PipedRDD {
     }
     buf
   }
+
+  val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
+  val STDERR_READER_THREAD_PREFIX = "stderr reader for"
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 1a0eb250e7cdc..69739a2e58481 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.File
 
+import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.io.Codec
 
@@ -83,6 +84,29 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("stdin writer thread should exit when the task finishes") {
+    assume(TestUtils.testCommandAvailable("cat"))
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
+      val obj = new Object()
+      obj.synchronized {
+        obj.wait() // block the stdin writer thread, which consumes this iterator
+      }
+      x
+    }
+
+    val piped = nums.pipe(Seq("cat"))
+
+    val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)
+
+    assert(result.collect().length === 0)
+
+    // The completion listener should have interrupted the stdin writer thread.
+    val stdinWriterThread = Thread.getAllStackTraces.keySet().asScala
+      .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
+
+    assert(stdinWriterThread.isEmpty)
+  }
+
   test("advanced pipe") {
     assume(TestUtils.testCommandAvailable("cat"))
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
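
Note on the approach: the fix relies on TaskContext.addTaskCompletionListener, whose callback runs when the task ends for any reason (success, failure, or kill), so it fires even when the PipedRDD iterator is dropped unconsumed. The standalone sketch below illustrates the same cleanup pattern; the CompletionListenerSketch object, its helper thread, and the blocking loop are hypothetical stand-ins for PipedRDD's stdin writer, not code from this patch.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Standalone sketch of the cleanup pattern used by the patch: a task-scoped
// helper thread is registered for interruption in a task-completion listener
// so it cannot outlive its task.
object CompletionListenerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[1]").setAppName("completion-listener-sketch"))
    try {
      sc.parallelize(1 to 4, 1).mapPartitions { iter =>
        val context = TaskContext.get()
        // Hypothetical helper thread standing in for PipedRDD's stdin writer.
        val helper = new Thread("helper thread for sketch") {
          override def run(): Unit = {
            try {
              while (true) Thread.sleep(1000) // simulate work that blocks indefinitely
            } catch {
              case _: InterruptedException => // expected: interrupted when the task completes
            }
          }
        }
        helper.start()
        // Runs whether the task succeeds, fails, or is killed. Without it,
        // `helper` would outlive the task whenever the returned iterator is
        // dropped unconsumed, as in mapPartitions(_ => Iterator.empty).
        context.addTaskCompletionListener[Unit] { _ =>
          if (helper.isAlive) {
            helper.interrupt()
          }
        }
        iter.take(0) // return an empty iterator, like abnormalRDD in the comment above
      }.count()
    } finally {
      sc.stop()
    }
  }
}

Interrupting from the listener matters because the stdin writer's own try/finally cannot run while the thread is blocked inside the parent iterator; only an external interrupt (together with destroying the child process) reliably unblocks it.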