
Commit 4db309d

cloud-fan authored and HyukjinKwon committed
[SPARK-56330][CORE][FOLLOWUP] Use dedicated loop for TaskInterruptListeners to avoid blocking completion listeners
### What changes were proposed in this pull request?

Followup to #55151. Changed `invokeTaskInterruptListeners` in `TaskContextImpl` to use a dedicated loop instead of routing through the shared `invokeListeners()` method. Also removed the now-unnecessary `markTaskFailedOnListenerError` parameter from `invokeListeners`, since it is always `true` (only completion and failure listeners use this method now).

### Why are the changes needed?

The current implementation routes `TaskInterruptListener`s through `invokeListeners()`, which uses `listenerInvocationThread` to serialize all listener invocations (completion, failure, and interrupt). This creates a problem when `markInterrupted()` is called on the kill thread while completion listeners need to run on the task thread:

1. The kill thread enters `invokeListeners()` for interrupt callbacks and acquires `listenerInvocationThread`.
2. While the interrupt listener runs, the task thread completes and calls `markTaskCompleted()` → `invokeTaskCompletionListeners()` → `invokeListeners()` for completion callbacks.
3. `invokeListeners()` sees `listenerInvocationThread` is held by the kill thread and returns immediately, **silently skipping all completion listeners**.

This causes resource leaks because cleanup logic registered via `addTaskCompletionListener` (e.g., closing file handles, releasing caches, freeing native memory) never executes.

The fix uses a dedicated loop for interrupt listeners with independent serialization (synchronized access to the callback stack, but no `listenerInvocationThread` gate), so interrupt listeners and completion/failure listeners can run on different threads without blocking each other.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a regression test `"SPARK-56330: task completion during interrupt listener execution"` that:

1. Registers a completion listener and a blocking interrupt listener.
2. Interrupts the task from a separate thread and waits for the interrupt listener to start.
3. Marks the task completed on the main thread while the interrupt listener is still running.
4. Verifies the completion listener was called (it would be silently skipped with the shared `invokeListeners()` approach).

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code

Closes #55292 from cloud-fan/SPARK-56330-followup.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
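The three-step scenario above can be reproduced in miniature. The following standalone Java sketch is a simplified, hypothetical model of a single invocation-thread gate (not Spark's actual `TaskContextImpl`): a "kill thread" holds the gate while a blocking listener runs, and a second caller's listener is silently skipped.

```java
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical model of why one shared "listenerInvocationThread" gate
// can silently skip listeners; names and structure are illustrative only.
public class SharedGateDemo {
    private Thread invokerThread = null;  // models listenerInvocationThread

    // Returns true if the listener ran, false if it was skipped because
    // another thread already holds the gate.
    public boolean invoke(Runnable listener) {
        synchronized (this) {
            if (invokerThread != null && invokerThread != Thread.currentThread()) {
                return false;  // another thread owns the gate: silently skip
            }
            invokerThread = Thread.currentThread();
        }
        try {
            listener.run();
            return true;
        } finally {
            synchronized (this) { invokerThread = null; }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        SharedGateDemo gate = new SharedGateDemo();
        Semaphore started = new Semaphore(0);
        Semaphore release = new Semaphore(0);
        AtomicBoolean completionRan = new AtomicBoolean(false);

        // "Kill thread": runs a blocking interrupt listener while holding the gate.
        Thread killThread = new Thread(() -> gate.invoke(() -> {
            started.release();
            release.acquireUninterruptibly();
        }));
        killThread.start();
        started.acquire();

        // "Task thread": tries to run a completion listener while the gate is held.
        boolean ran = gate.invoke(() -> completionRan.set(true));
        System.out.println("completion listener ran: " + ran);  // prints: false

        release.release();
        killThread.join();
    }
}
```

The fix in this commit removes exactly this gate for interrupt listeners, so a completion listener arriving on another thread is no longer turned away.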
1 parent 638ca41 · commit 4db309d

2 files changed: 84 additions & 45 deletions


core/src/main/scala/org/apache/spark/TaskContextImpl.scala (49 additions & 45 deletions)
```diff
@@ -77,7 +77,7 @@ private[spark] class TaskContextImpl(
   @transient private val onInterruptCallbacks = new Stack[TaskInterruptListener]
 
   /**
-   * The thread currently executing task completion, failure, or interrupt listeners, if any.
+   * The thread currently executing task completion or failure listeners, if any.
    *
    * `invokeListeners()` uses this to ensure listeners are called sequentially.
    */
@@ -131,12 +131,11 @@ private[spark] class TaskContextImpl(
 
   override def addTaskInterruptListener(listener: TaskInterruptListener): this.type = {
     synchronized {
-      // If there is already a thread invoking listeners, adding the new listener to
-      // `onInterruptCallbacks` will cause that thread to execute the new listener, and the call to
-      // `invokeTaskInterruptListeners()` below will be a no-op.
+      // If another thread is already running `invokeTaskInterruptListeners`, adding the new
+      // listener to `onInterruptCallbacks` will cause that thread to execute it (the loop pops
+      // listeners under the TaskContext lock).
       //
-      // If there is no such thread, the call to `invokeTaskInterruptListeners()` below will execute
-      // all listeners, including the new listener.
+      // Otherwise, `invokeTaskInterruptListeners()` below will execute all listeners.
       onInterruptCallbacks.push(listener)
       reasonIfKilled
     }.foreach { reason =>
@@ -172,47 +171,58 @@ private[spark] class TaskContextImpl(
   private def invokeTaskCompletionListeners(error: Option[Throwable]): Unit = {
     // It is safe to access the reference to `onCompleteCallbacks` without holding the TaskContext
     // lock. `invokeListeners()` acquires the lock before accessing the contents.
-    invokeListeners(
-      onCompleteCallbacks,
-      "TaskCompletionListener",
-      error,
-      markTaskFailedOnListenerError = true) {
+    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
       _.onTaskCompletion(this)
     }
   }
 
   private def invokeTaskFailureListeners(error: Throwable): Unit = {
     // It is safe to access the reference to `onFailureCallbacks` without holding the TaskContext
     // lock. `invokeListeners()` acquires the lock before accessing the contents.
-    invokeListeners(
-      onFailureCallbacks,
-      "TaskFailureListener",
-      Option(error),
-      markTaskFailedOnListenerError = true) {
+    invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
       _.onTaskFailure(this, error)
     }
   }
 
   private def invokeTaskInterruptListeners(reason: String, error: Throwable): Unit = {
-    // It is safe to access the reference to `onInterruptCallbacks` without holding the TaskContext
-    // lock. `invokeListeners()` acquires the lock before accessing the contents.
-    // Do not call `markTaskFailed` per listener here: the first failure would win in
-    // `failureCauseOpt` and mask the aggregate `TaskCompletionListenerException` that
-    // `markInterrupted` records after swallowing the thrown exception (SPARK-56330).
-    invokeListeners(
-      onInterruptCallbacks,
-      "TaskInterruptListener",
-      Option(error),
-      markTaskFailedOnListenerError = false) {
-      _.onTaskInterrupted(this, reason)
+    // Do not use `invokeListeners()`. That method uses `listenerInvocationThread` to serialize
+    // all listener invocations, which would prevent task completion or failure listeners from
+    // running if the task completes or fails while executing an interrupt listener (causing
+    // resource leaks such as task-managed resources not being freed).
+    //
+    // Instead, directly execute the task interrupt listeners with independent serialization.
+    // Exceptions are collected and thrown as a TaskCompletionListenerException so that
+    // `markInterrupted` can catch and record the aggregate failure (SPARK-56330).
+    def getNextListener(): Option[TaskInterruptListener] = synchronized {
+      if (onInterruptCallbacks.empty()) {
+        None
+      } else {
+        Some(onInterruptCallbacks.pop())
+      }
+    }
+    val listenerExceptions = new ArrayBuffer[Throwable](2)
+    var listenerOption: Option[TaskInterruptListener] = None
+    while ({listenerOption = getNextListener(); listenerOption.nonEmpty}) {
+      try {
+        listenerOption.get.onTaskInterrupted(this, reason)
+      } catch {
+        case e: Throwable =>
+          listenerExceptions += e
+          logError(log"Error in TaskInterruptListener", e)
+      }
+    }
+    if (listenerExceptions.nonEmpty) {
+      val exception = new TaskCompletionListenerException(
+        listenerExceptions.map(_.getMessage).toSeq, Option(error))
+      listenerExceptions.foreach(exception.addSuppressed)
+      throw exception
     }
   }
 
   private def invokeListeners[T](
       listeners: Stack[T],
       name: String,
-      error: Option[Throwable],
-      markTaskFailedOnListenerError: Boolean)(
+      error: Option[Throwable])(
       callback: T => Unit): Unit = {
     // This method is subject to two constraints:
     //
@@ -255,8 +265,7 @@ private[spark] class TaskContextImpl(
             callback(listener)
           } catch {
             case e: Throwable =>
-              // A listener failed. For completion/failure listeners, temporarily clear
-              // listenerInvocationThread and markTaskFailed so nested TaskContext calls can run.
+              // A listener failed. Temporarily clear the listenerInvocationThread and markTaskFailed.
               //
               // One of the following cases applies (#3 being the interesting one):
               //
@@ -290,20 +299,15 @@ private[spark] class TaskContextImpl(
               //    failed, and now another completion listener has failed. Then our call to
               //    [[markTaskFailed]] here will have no effect and we simply resume running the
               //    remaining completion handlers.
-              //
-              // Task interrupt listeners skip per-listener [[markTaskFailed]]; see
-              // [[invokeTaskInterruptListeners]].
-              if (markTaskFailedOnListenerError) {
-                try {
-                  listenerInvocationThread = None
-                  markTaskFailed(e)
-                } catch {
-                  case t: Throwable => e.addSuppressed(t)
-                } finally {
-                  synchronized {
-                    if (listenerInvocationThread.isEmpty) {
-                      listenerInvocationThread = Some(Thread.currentThread())
-                    }
+              try {
+                listenerInvocationThread = None
+                markTaskFailed(e)
+              } catch {
+                case t: Throwable => e.addSuppressed(t)
+              } finally {
+                synchronized {
+                  if (listenerInvocationThread.isEmpty) {
+                    listenerInvocationThread = Some(Thread.currentThread())
                   }
                 }
               }
```
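The dedicated loop introduced in `invokeTaskInterruptListeners` (pop one listener under the lock, run it outside the lock, collect failures, rethrow them as one aggregate exception with the originals suppressed) can be sketched in standalone Java. The `InterruptLoop` and `InterruptListener` names are hypothetical; this is a simplified model of the Scala change, not Spark code.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Optional;

// Hypothetical stand-in for a stack of TaskInterruptListener callbacks.
interface InterruptListener {
    void onInterrupted(String reason);
}

public class InterruptLoop {
    private final Deque<InterruptListener> callbacks = new ArrayDeque<>();

    public synchronized void addListener(InterruptListener l) {
        callbacks.push(l);
    }

    // Pop one listener under the lock, or empty if none remain. Listeners
    // registered concurrently are picked up by whichever thread is looping.
    private synchronized Optional<InterruptListener> nextListener() {
        return callbacks.isEmpty() ? Optional.empty() : Optional.of(callbacks.pop());
    }

    // Dedicated loop: runs listeners outside the lock, collects failures,
    // and rethrows them as one aggregate exception at the end.
    public void invokeInterruptListeners(String reason) {
        List<Throwable> errors = new ArrayList<>();
        Optional<InterruptListener> listener;
        while ((listener = nextListener()).isPresent()) {
            try {
                listener.get().onInterrupted(reason);
            } catch (Throwable t) {
                errors.add(t);  // keep going; one bad listener must not stop the rest
            }
        }
        if (!errors.isEmpty()) {
            RuntimeException aggregate =
                new RuntimeException(errors.size() + " interrupt listener(s) failed");
            errors.forEach(aggregate::addSuppressed);
            throw aggregate;
        }
    }

    public static void main(String[] args) {
        InterruptLoop ctx = new InterruptLoop();
        ctx.addListener(r -> System.out.println("listener 1: " + r));
        ctx.addListener(r -> { throw new IllegalStateException("boom"); });
        try {
            ctx.invokeInterruptListeners("killed");  // prints: listener 1: killed
        } catch (RuntimeException e) {
            System.out.println(e.getMessage()
                + " (suppressed: " + e.getSuppressed().length + ")");
        }
    }
}
```

Because `nextListener()` holds the lock only long enough to pop, no thread-identity gate is ever held while a listener body runs, which is the property the fix needs.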

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala (35 additions & 0 deletions)
```diff
@@ -199,6 +199,41 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(invocations == 2)
   }
 
+  test("SPARK-56330: task completion during interrupt listener execution") {
+    val context = TaskContext.empty()
+    val completionListener = mock(classOf[TaskCompletionListener])
+    context.addTaskCompletionListener(completionListener)
+
+    // Add a task interrupt listener that blocks until released.
+    val interruptListenerStarted = new Semaphore(0)
+    val interruptListenerRelease = new Semaphore(0)
+    context.addTaskInterruptListener(new TaskInterruptListener {
+      override def onTaskInterrupted(context: TaskContext, reason: String): Unit = {
+        interruptListenerStarted.release()
+        interruptListenerRelease.acquire()
+      }
+    })
+
+    // Interrupt the task from a separate thread and wait until the interrupt listener starts.
+    val interruptThread = new Thread(() => context.markInterrupted("test interrupt"))
+    interruptThread.start()
+    interruptListenerStarted.acquire()
+
+    // While the interrupt listener is running on the interrupt thread, mark the task completed
+    // on this thread. With the dedicated interrupt listener loop, this must NOT be blocked.
+    context.markTaskCompleted(None)
+
+    // The completion listener should have been called even though the interrupt listener is still
+    // running. If `invokeListeners()` were shared between interrupt and completion listeners,
+    // the completion listener would be silently skipped because `listenerInvocationThread` would
+    // be held by the interrupt thread.
+    verify(completionListener, times(1)).onTaskCompletion(any())
+
+    // Release the interrupt listener and join the interrupt thread.
+    interruptListenerRelease.release()
+    interruptThread.join()
+  }
+
   test("FailureListener throws after task body fails") {
     val context = TaskContext.empty()
     val listenerCalls = ArrayBuffer.empty[String]
```
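The regression test relies on a two-semaphore handshake to park a background thread at a known point, assert something on the main thread, and then resume. A generic, standalone Java sketch of that pattern (hypothetical `HandshakeDemo`, not part of the Spark test suite):

```java
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;

// Two-semaphore handshake: `started` tells the main thread the worker has
// reached the interesting point; `release` lets the worker continue.
public class HandshakeDemo {
    public static void main(String[] args) throws InterruptedException {
        Semaphore started = new Semaphore(0);
        Semaphore release = new Semaphore(0);
        AtomicInteger state = new AtomicInteger(0);

        Thread worker = new Thread(() -> {
            state.set(1);                       // reach the interesting point...
            started.release();                  // ...signal the main thread...
            release.acquireUninterruptibly();   // ...and park until released
            state.set(2);
        });
        worker.start();

        started.acquire();  // worker is now deterministically parked mid-operation
        System.out.println("state while worker is parked: " + state.get());  // prints: 1

        release.release();  // let the worker finish
        worker.join();
        System.out.println("state after join: " + state.get());  // prints: 2
    }
}
```

Unlike sleeps or polling, this handshake makes the interleaving deterministic, which is why the Spark test can assert that the completion listener ran while the interrupt listener is provably still blocked.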
