[SPARK-51099][PYTHON] Add logs when the Python worker looks stuck
### What changes were proposed in this pull request?

Adds logs when the Python worker looks stuck.

- Spark conf: `spark.python.worker.idleTimeoutSeconds` (default `0`, which means no timeout)
    The time (in seconds) Spark will wait for activity (e.g., data transfer or communication) from a Python worker before considering it potentially idle or unresponsive. When the timeout is triggered, Spark logs the network-related status for debugging purposes. However, the Python worker remains active and continues waiting for communication. The default is `0`, which means no timeout.
- SQL conf: `spark.sql.execution.pyspark.udf.idleTimeoutSeconds`
    The same as `spark.python.worker.idleTimeoutSeconds`, but as a runtime conf for Python UDFs; it falls back to `spark.python.worker.idleTimeoutSeconds` when unset.

For example:

```py
import time
from pyspark.sql import functions as sf

spark.conf.set('spark.sql.execution.pyspark.udf.idleTimeoutSeconds', '1s')

@sf.udf
def f(x):
    time.sleep(2)
    return str(x)

spark.range(1).select(f("id")).show()
```

will show a warning message:

```
... WARN PythonUDFWithNamedArgumentsRunner: Idle timeout reached for Python worker (timeout: 1 seconds). No data received from the worker process: handle.map(_.isAlive) = Some(true), channel.isConnected = true, channel.isBlocking = false, selector.isOpen = true, selectionKey.isValid = true, selectionKey.interestOps = 1, hasInputs = false
```
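
For the plain Spark conf (`spark.python.worker.idleTimeoutSeconds`), which applies to Python workers generally (the SQL conf above overrides it at runtime for UDFs), a minimal sketch of setting it when building the session — the `5s` threshold, local master, and sleep duration below are illustrative assumptions, not part of this change:

```py
import time
from pyspark.sql import SparkSession

# Static conf: read when Python workers are created, so set it at startup
# (illustrative 5-second threshold).
spark = (
    SparkSession.builder
    .master("local[2]")
    .config("spark.python.worker.idleTimeoutSeconds", "5s")
    .getOrCreate()
)

def slow(x):
    time.sleep(6)  # stays idle past the threshold
    return x

# The job still succeeds and the worker keeps running; Spark only logs the
# idle-timeout warning with the worker's network status, as in the message above.
spark.sparkContext.parallelize([1, 2, 3]).map(slow).collect()
```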

### Why are the changes needed?

To make it easier to monitor Python workers.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually checked the logs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49818 from ueshin/issues/SPARK-51099/pythonrunner_logging.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
ueshin committed Feb 7, 2025
1 parent 5a925c6 commit 8d18df3
Showing 13 changed files with 163 additions and 48 deletions.

@@ -635,8 +635,16 @@ private[spark] object LogKeys {
case object PYTHON_EXEC extends LogKey
case object PYTHON_PACKAGES extends LogKey
case object PYTHON_VERSION extends LogKey
case object PYTHON_WORKER_CHANNEL_IS_BLOCKING_MODE extends LogKey
case object PYTHON_WORKER_CHANNEL_IS_CONNECTED extends LogKey
case object PYTHON_WORKER_HAS_INPUTS extends LogKey
case object PYTHON_WORKER_IDLE_TIMEOUT extends LogKey
case object PYTHON_WORKER_IS_ALIVE extends LogKey
case object PYTHON_WORKER_MODULE extends LogKey
case object PYTHON_WORKER_RESPONSE extends LogKey
case object PYTHON_WORKER_SELECTION_KEY_INTERESTS extends LogKey
case object PYTHON_WORKER_SELECTION_KEY_IS_VALID extends LogKey
case object PYTHON_WORKER_SELECTOR_IS_OPEN extends LogKey
case object QUANTILES extends LogKey
case object QUERY_CACHE_VALUE extends LogKey
case object QUERY_HINT extends LogKey
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -144,7 +144,7 @@ class SparkEnv (
workerModule: String,
daemonModule: String,
envVars: Map[String, String],
useDaemon: Boolean): (PythonWorker, Option[Int]) = {
useDaemon: Boolean): (PythonWorker, Option[ProcessHandle]) = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
val workerFactory = pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(
@@ -163,7 +163,7 @@
pythonExec: String,
workerModule: String,
envVars: Map[String, String],
useDaemon: Boolean): (PythonWorker, Option[Int]) = {
useDaemon: Boolean): (PythonWorker, Option[ProcessHandle]) = {
createPythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, useDaemon)
}
@@ -172,7 +172,7 @@
pythonExec: String,
workerModule: String,
daemonModule: String,
envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
envVars: Map[String, String]): (PythonWorker, Option[ProcessHandle]) = {
val useDaemon = conf.get(Python.PYTHON_USE_DAEMON)
createPythonWorker(
pythonExec, workerModule, daemonModule, envVars, useDaemon)
69 changes: 57 additions & 12 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -23,15 +23,16 @@ import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files => JavaFiles, Path}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

import scala.jdk.CollectionConverters._
import scala.util.{Success, Try}
import scala.util.control.NonFatal

import org.apache.spark._
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext}
import org.apache.spark.internal.LogKeys.TASK_NAME
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
import org.apache.spark.internal.config.Python._
@@ -90,13 +91,44 @@ private[spark] object PythonEvalType {
}
}

private[spark] object BasePythonRunner {
private[spark] object BasePythonRunner extends Logging {

private[spark] lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

private[spark] def faultHandlerLogPath(pid: Int): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}

private[spark] def pythonWorkerStatusMessageWithContext(
handle: Option[ProcessHandle],
worker: PythonWorker,
hasInputs: Boolean): MessageWithContext = {
log"handle.map(_.isAlive) = " +
log"${MDC(LogKeys.PYTHON_WORKER_IS_ALIVE, handle.map(_.isAlive))}, " +
log"channel.isConnected = " +
log"${MDC(LogKeys.PYTHON_WORKER_CHANNEL_IS_CONNECTED, worker.channel.isConnected)}, " +
log"channel.isBlocking = " +
log"${
MDC(LogKeys.PYTHON_WORKER_CHANNEL_IS_BLOCKING_MODE,
worker.channel.isBlocking)
}, " +
(if (!worker.channel.isBlocking) {
log"selector.isOpen = " +
log"${MDC(LogKeys.PYTHON_WORKER_SELECTOR_IS_OPEN, worker.selector.isOpen)}, " +
log"selectionKey.isValid = " +
log"${
MDC(LogKeys.PYTHON_WORKER_SELECTION_KEY_IS_VALID,
worker.selectionKey.isValid)
}, " +
(Try(worker.selectionKey.interestOps()) match {
case Success(ops) =>
log"selectionKey.interestOps = " +
log"${MDC(LogKeys.PYTHON_WORKER_SELECTION_KEY_INTERESTS, ops)}, "
case _ => log""
})
} else log"") +
log"hasInputs = ${MDC(LogKeys.PYTHON_WORKER_HAS_INPUTS, hasInputs)}"
}
}

/**
@@ -112,6 +144,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val jobArtifactUUID: Option[String],
protected val metrics: Map[String, AccumulatorV2[Long, Long]])
extends Logging {
import BasePythonRunner._

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

@@ -122,6 +155,7 @@
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val idleTimeoutSeconds: Long = conf.get(PYTHON_WORKER_IDLE_TIMEOUT_SECONDS)
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false

@@ -216,14 +250,14 @@
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
if (faultHandlerEnabled) {
envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
envVars.put("PYTHON_FAULTHANDLER_DIR", faultHandlerLogDir.toString)
}
// allow the user to set the batch size for the BatchedSerializer on UDFs
envVars.put("PYTHON_UDF_BATCH_SIZE", batchSizeForPythonUDF.toString)

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
val (worker: PythonWorker, handle: Option[ProcessHandle]) = env.createPythonWorker(
pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
@@ -256,10 +290,11 @@
}

// Return an iterator that read lines from the process's stdout
val dataIn = new DataInputStream(
new BufferedInputStream(new ReaderInputStream(worker, writer), bufferSize))
val dataIn = new DataInputStream(new BufferedInputStream(
new ReaderInputStream(worker, writer, handle, idleTimeoutSeconds),
bufferSize))
val stdoutIterator = newReaderIterator(
dataIn, writer, startTime, env, worker, pid, releasedOrClosed, context)
dataIn, writer, startTime, env, worker, handle.map(_.pid.toInt), releasedOrClosed, context)
new InterruptibleIterator(context, stdoutIterator)
}

@@ -572,8 +607,8 @@
throw writer.exception.get

case e: IOException if faultHandlerEnabled && pid.isDefined &&
JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
val path = BasePythonRunner.faultHandlerLogPath(pid.get)
JavaFiles.exists(faultHandlerLogPath(pid.get)) =>
val path = faultHandlerLogPath(pid.get)
val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
JavaFiles.deleteIfExists(path)
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", e)
@@ -640,7 +675,11 @@
}
}

class ReaderInputStream(worker: PythonWorker, writer: Writer) extends InputStream {
class ReaderInputStream(
worker: PythonWorker,
writer: Writer,
handle: Option[ProcessHandle],
idleTimeoutSeconds: Long) extends InputStream {
private[this] var writerIfbhThreadLocalValue: Object = null
private[this] val temp = new Array[Byte](1)
private[this] val bufferStream = new DirectByteBufferOutputStream()
Expand Down Expand Up @@ -699,7 +738,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
val buf = ByteBuffer.wrap(b, off, len)
var n = 0
while (n == 0) {
worker.selector.select()
val selected = worker.selector.select(TimeUnit.SECONDS.toMillis(idleTimeoutSeconds))
if (selected == 0) {
logWarning(log"Idle timeout reached for Python worker (timeout: " +
log"${MDC(LogKeys.PYTHON_WORKER_IDLE_TIMEOUT, idleTimeoutSeconds)} seconds). " +
log"No data received from the worker process: " +
pythonWorkerStatusMessageWithContext(handle, worker, hasInput || buffer.hasRemaining))
}
if (worker.selectionKey.isReadable) {
n = worker.channel.read(buf)
}

@@ -27,6 +27,7 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._

import org.apache.spark._
import org.apache.spark.errors.SparkCoreErrors
Expand All @@ -35,11 +36,36 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}

case class PythonWorker(channel: SocketChannel, selector: Selector, selectionKey: SelectionKey) {
def stop(): Unit = {
Option(selectionKey).foreach(_.cancel())
selector.close()
channel.close()
case class PythonWorker(channel: SocketChannel) {

private[this] var selectorOpt: Option[Selector] = None
private[this] var selectionKeyOpt: Option[SelectionKey] = None

def selector: Selector = selectorOpt.orNull
def selectionKey: SelectionKey = selectionKeyOpt.orNull

private def closeSelector(): Unit = {
selectionKeyOpt.foreach(_.cancel())
selectorOpt.foreach(_.close())
}

def refresh(): this.type = synchronized {
closeSelector()
if (channel.isBlocking) {
selectorOpt = None
selectionKeyOpt = None
} else {
val selector = Selector.open()
selectorOpt = Some(selector)
selectionKeyOpt =
Some(channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE))
}
this
}

def stop(): Unit = synchronized {
closeSelector()
Option(channel).foreach(_.close())
}
}

@@ -93,7 +119,7 @@ private[spark] class PythonWorkerFactory(
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))

def create(): (PythonWorker, Option[Int]) = {
def create(): (PythonWorker, Option[ProcessHandle]) = {
if (useDaemon) {
self.synchronized {
// Pull from idle workers until we one that is alive, otherwise create a new one.
@@ -102,8 +128,7 @@
val workerHandle = daemonWorkers(worker)
if (workerHandle.isAlive()) {
try {
worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE)
return (worker, Some(workerHandle.pid().toInt))
return (worker.refresh(), Some(workerHandle))
} catch {
case c: CancelledKeyException => /* pass */
}
@@ -124,9 +149,9 @@
* processes itself to avoid the high cost of forking from Java. This currently only works
* on UNIX-based systems.
*/
private def createThroughDaemon(): (PythonWorker, Option[Int]) = {
private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {

def createWorker(): (PythonWorker, Option[Int]) = {
def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
// These calls are blocking.
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
@@ -138,12 +163,9 @@
)
authHelper.authToServer(socketChannel.socket())
socketChannel.configureBlocking(false)
val selector = Selector.open()
val selectionKey = socketChannel.register(selector,
SelectionKey.OP_READ | SelectionKey.OP_WRITE)
val worker = PythonWorker(socketChannel, selector, selectionKey)
val worker = PythonWorker(socketChannel)
daemonWorkers.put(worker, processHandle)
(worker, Some(pid))
(worker.refresh(), Some(processHandle))
}

self.synchronized {
Expand All @@ -167,7 +189,8 @@ private[spark] class PythonWorkerFactory(
/**
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
*/
private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = {
private[spark] def createSimpleWorker(
blockingMode: Boolean): (PythonWorker, Option[ProcessHandle]) = {
var serverSocketChannel: ServerSocketChannel = null
try {
serverSocketChannel = ServerSocketChannel.open()
@@ -219,17 +242,11 @@
if (!blockingMode) {
socketChannel.configureBlocking(false)
}
val selector = Selector.open()
val selectionKey = if (blockingMode) {
null
} else {
socketChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)
}
val worker = PythonWorker(socketChannel, selector, selectionKey)
val worker = PythonWorker(socketChannel)
self.synchronized {
simpleWorkers.put(worker, workerProcess)
}
return (worker, Some(pid))
(worker.refresh(), ProcessHandle.of(pid).toScala)
} catch {
case e: Exception =>
throw new SparkException("Python worker failed to connect back.", e)
@@ -239,7 +256,6 @@
serverSocketChannel.close()
}
}
null
}

private def startDaemon(): Unit = {
14 changes: 14 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -69,4 +69,18 @@ private[spark] object Python {
.version("3.2.0")
.booleanConf
.createWithDefault(false)

private val PYTHON_WORKER_IDLE_TIMEOUT_SECONDS_KEY = "spark.python.worker.idleTimeoutSeconds"

val PYTHON_WORKER_IDLE_TIMEOUT_SECONDS = ConfigBuilder(PYTHON_WORKER_IDLE_TIMEOUT_SECONDS_KEY)
.doc("The time (in seconds) Spark will wait for activity " +
"(e.g., data transfer or communication) from a Python worker before considering it " +
"potentially idle or unresponsive. When the timeout is triggered, " +
"Spark will log the network-related status for debugging purposes. " +
"However, the Python worker will remain active and continue waiting for communication. " +
"The default is `0` that means no timeout.")
.version("4.0.0")
.timeConf(TimeUnit.SECONDS)
.checkValue(_ >= 0, "The idle timeout should be 0 or positive.")
.createWithDefault(0)
}

@@ -3359,6 +3359,14 @@ object SQLConf {
.version("4.0.0")
.fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED)

val PYTHON_UDF_WORKER_IDLE_TIMEOUT_SECONDS =
buildConf("spark.sql.execution.pyspark.udf.idleTimeoutSeconds")
.doc(
s"Same as ${Python.PYTHON_WORKER_IDLE_TIMEOUT_SECONDS.key} for Python execution with " +
"DataFrame and SQL. It can change during runtime.")
.version("4.0.0")
.fallbackConf(Python.PYTHON_WORKER_IDLE_TIMEOUT_SECONDS)

val PYSPARK_PLOT_MAX_ROWS =
buildConf("spark.sql.pyspark.plotting.max_rows")
.doc("The visual limit on plots. If set to 1000 for top-n-based plots (pie, bar, barh), " +
@@ -6271,6 +6279,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)

def pythonUDFWorkerIdleTimeoutSeconds: Long = getConf(PYTHON_UDF_WORKER_IDLE_TIMEOUT_SECONDS)

def pythonUDFArrowConcurrencyLevel: Option[Int] = getConf(PYTHON_UDF_ARROW_CONCURRENCY_LEVEL)

def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)

@@ -75,6 +75,7 @@ class ApplyInPandasWithStatePythonRunner(
funcs.head._1.funcs.head.pythonExec)

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
override val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds

private val sqlConf = SQLConf.get


@@ -47,6 +47,7 @@ abstract class BaseArrowPythonRunner(
funcs.head._1.funcs.head.pythonExec)

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
override val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds

override val errorOnDuplicatedFieldNames: Boolean = true


@@ -56,6 +56,7 @@ class ArrowPythonUDTFRunner(
funcs.head.funcs.head.pythonExec)

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
override val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds

override val errorOnDuplicatedFieldNames: Boolean = true

