Skip to content

Commit 2c80db1

Browse files
committed
Add stack-safe async loop support with trampoline pattern
- Add AsyncTrampoline class to prevent stack overflow in loops by converting callback recursion into iterative execution - Add thenRunWhileLoop method to AsyncRunnable to support while-loop semantics where condition is checked before body execution - Integrate trampoline into AsyncCallbackLoop by making LoopingCallback implement Runnable to avoid per-iteration lambda allocation JAVA-6120
1 parent 7a271ca commit 2c80db1

5 files changed

Lines changed: 324 additions & 5 deletions

File tree

driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,35 @@ default AsyncRunnable thenRunRetryingWhile(
243243
});
244244
}
245245

246+
/**
247+
* This method is equivalent to a while loop, where the condition is checked before each iteration.
248+
* If the condition returns {@code false} on the first check, the body is never executed.
249+
*
250+
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop
251+
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
252+
* @return the composition of this and the looping branch
253+
* @see AsyncCallbackLoop
254+
*/
255+
default AsyncRunnable thenRunWhileLoop(final BooleanSupplier whileCheck, final AsyncRunnable loopBodyRunnable) {
256+
return thenRun(finalCallback -> {
257+
LoopState loopState = new LoopState();
258+
new AsyncCallbackLoop(loopState, iterationCallback -> {
259+
260+
if (loopState.breakAndCompleteIf(() -> !whileCheck.getAsBoolean(), iterationCallback)) {
261+
return;
262+
}
263+
loopBodyRunnable.finish((result, t) -> {
264+
if (t != null) {
265+
iterationCallback.completeExceptionally(t);
266+
return;
267+
}
268+
iterationCallback.complete(iterationCallback);
269+
});
270+
271+
}).run(finalCallback);
272+
});
273+
}
274+
246275
/**
247276
* This method is equivalent to a do-while loop, where the loop body is executed first and
248277
* then the condition is checked to determine whether the loop should continue.
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal.async;
18+
19+
import com.mongodb.annotations.NotThreadSafe;
20+
import com.mongodb.assertions.Assertions;
21+
import com.mongodb.lang.Nullable;
22+
23+
/**
24+
* A trampoline that converts recursive callback invocations into an iterative loop,
25+
* preventing stack overflow in async loops.
26+
*
27+
* <p>When async loop iterations complete synchronously on the same thread, callback
28+
* recursion occurs: each iteration's {@code callback.onResult()} immediately triggers
29+
* the next iteration, causing unbounded stack growth. For example, a 1000-iteration
30+
* loop would create > 1000 stack frames and cause {@code StackOverflowError}.</p>
31+
*
32+
* <p>The trampoline intercepts this recursion: instead of executing the next iteration
33+
* immediately (which would deepen the stack), it enqueues the work and returns, allowing
34+
* the stack to unwind. A flat loop at the top then processes enqueued work iteratively,
35+
* maintaining constant stack depth regardless of iteration count.</p>
36+
*
37+
* <p>Since async chains are sequential, at most one task is pending at any time.
38+
* The trampoline uses a single slot rather than a queue.</p>
39+
*
40+
* The first call on a thread becomes the "trampoline owner" and runs the drain loop.
41+
* Subsequent (re-entrant) calls on the same thread enqueue their work and return immediately.</p>
42+
*
43+
* <p>This class is not part of the public API and may be removed or changed at any time</p>
44+
*/
45+
@NotThreadSafe
46+
public final class AsyncTrampoline {
47+
48+
private static final ThreadLocal<Bounce> TRAMPOLINE = new ThreadLocal<>();
49+
50+
private AsyncTrampoline() {
51+
}
52+
53+
/**
54+
* Execute work through the trampoline. If no trampoline is active, become the owner
55+
* and drain all enqueued work. If a trampoline is already active, enqueue and return.
56+
*/
57+
public static void run(final Runnable work) {
58+
Bounce bounce = TRAMPOLINE.get();
59+
if (bounce != null) {
60+
// Re-entrant, enqueue and return
61+
bounce.enqueue(work);
62+
} else {
63+
// Become the trampoline owner.
64+
bounce = new Bounce();
65+
TRAMPOLINE.set(bounce);
66+
try {
67+
work.run();
68+
// drain any re-entrant work iteratively
69+
while (bounce.hasWork()) {
70+
bounce.runNext();
71+
}
72+
} finally {
73+
TRAMPOLINE.remove();
74+
}
75+
}
76+
}
77+
78+
/**
79+
* A single-slot container for deferred work.
80+
* At most one task is pending at any time in a sequential async chain.
81+
*/
82+
@NotThreadSafe
83+
private static final class Bounce {
84+
@Nullable
85+
private Runnable work;
86+
87+
void enqueue(final Runnable task) {
88+
if (this.work != null) {
89+
throw new AssertionError("Trampoline slot already occupied. "
90+
+ "It could happen if there are multiple concurrent operations in a sequential async chain.");
91+
}
92+
this.work = task;
93+
}
94+
95+
boolean hasWork() {
96+
return work != null;
97+
}
98+
99+
void runNext() {
100+
Runnable task = this.work;
101+
this.work = null;
102+
Assertions.assertNotNull(task);
103+
task.run();
104+
}
105+
}
106+
}

driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package com.mongodb.internal.async.function;
1717

1818
import com.mongodb.annotations.NotThreadSafe;
19+
import com.mongodb.internal.async.AsyncTrampoline;
1920
import com.mongodb.internal.async.SingleResultCallback;
2021
import com.mongodb.lang.Nullable;
2122

@@ -58,15 +59,21 @@ public void run(final SingleResultCallback<Void> callback) {
5859

5960
/**
6061
* This callback is allowed to be completed more than once.
62+
* Also implements {@linkplain Runnable} to avoid lambda allocation per iteration when using trampoline.
6163
*/
6264
@NotThreadSafe
63-
private class LoopingCallback implements SingleResultCallback<Void> {
65+
private class LoopingCallback implements SingleResultCallback<Void>, Runnable {
6466
private final SingleResultCallback<Void> wrapped;
6567

6668
LoopingCallback(final SingleResultCallback<Void> callback) {
6769
wrapped = callback;
6870
}
6971

72+
@Override
73+
public void run() {
74+
body.run(this);
75+
}
76+
7077
@Override
7178
public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
7279
if (t != null) {
@@ -80,7 +87,7 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
8087
return;
8188
}
8289
if (continueLooping) {
83-
body.run(this);
90+
AsyncTrampoline.run(this);
8491
} else {
8592
wrapped.onResult(result, null);
8693
}

driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import static com.mongodb.assertions.Assertions.assertNotNull;
2828
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
29+
import static org.junit.jupiter.api.Assertions.assertEquals;
2930

3031
abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
3132
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
@@ -723,6 +724,120 @@ void testTryCatchTestAndRethrow() {
723724
});
724725
}
725726

727+
@Test
728+
void testWhile() {
729+
// last iteration: 3 < 3 = 1
730+
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
731+
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
732+
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
733+
assertBehavesSameVariations(10,
734+
() -> {
735+
int counter = 0;
736+
while (counter < 3 && plainTest(counter)) {
737+
counter++;
738+
sync(counter);
739+
}
740+
},
741+
(callback) -> {
742+
MutableValue<Integer> counter = new MutableValue<>(0);
743+
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
744+
counter.set(counter.get() + 1);
745+
async(counter.get(), c2);
746+
}).finish(callback);
747+
});
748+
}
749+
750+
@Test
751+
void testWhileWithThenRun() {
752+
// while: last iteration: 3 < 3 = 1
753+
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
754+
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
755+
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
756+
// trailing sync: 1(exception) + 1(success) = 2
757+
// 6(while exception) + 4(while success) * 2(trailing sync) = 14
758+
assertBehavesSameVariations(14,
759+
() -> {
760+
int counter = 0;
761+
while (counter < 3 && plainTest(counter)) {
762+
counter++;
763+
sync(counter);
764+
}
765+
sync(counter + 1);
766+
},
767+
(callback) -> {
768+
MutableValue<Integer> counter = new MutableValue<>(0);
769+
beginAsync().thenRun(c -> {
770+
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
771+
counter.set(counter.get() + 1);
772+
async(counter.get(), c2);
773+
}).finish(c);
774+
}).thenRun(c -> {
775+
async(counter.get() + 1, c);
776+
}).finish(callback);
777+
});
778+
}
779+
780+
@Test
781+
void testNestedWhileLoops() {
782+
// inner while: 4 success + 6 exception = 10
783+
// last inner iteration: 3 < 3 = 1
784+
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 1(transition to next iteration) = 12
785+
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 12(transition to next iteration) = 56
786+
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 56(transition to next iteration) = 232
787+
assertBehavesSameVariations(232,
788+
() -> {
789+
int outer = 0;
790+
while (outer < 3 && plainTest(outer)) {
791+
int inner = 0;
792+
while (inner < 3 && plainTest(inner)) {
793+
sync(outer + inner);
794+
inner++;
795+
}
796+
outer++;
797+
}
798+
},
799+
(callback) -> {
800+
MutableValue<Integer> outer = new MutableValue<>(0);
801+
beginAsync().thenRunWhileLoop(() -> outer.get() < 3 && plainTest(outer.get()), c -> {
802+
MutableValue<Integer> inner = new MutableValue<>(0);
803+
beginAsync().thenRunWhileLoop(
804+
() -> inner.get() < 3 && plainTest(inner.get()),
805+
c2 -> {
806+
beginAsync().thenRun(c3 -> {
807+
async(outer.get() + inner.get(), c3);
808+
}).thenRun(c3 -> {
809+
inner.set(inner.get() + 1);
810+
c3.complete(c3);
811+
}).finish(c2);
812+
}
813+
).thenRun(c2 -> {
814+
outer.set(outer.get() + 1);
815+
c2.complete(c2);
816+
}).finish(c);
817+
}).finish(callback);
818+
});
819+
}
820+
821+
@Test
822+
void testWhileLoopStackConstant() {
823+
int depthWith100 = maxStackDepthForIterations(100);
824+
int depthWith10000 = maxStackDepthForIterations(10_000);
825+
assertEquals(depthWith100, depthWith10000, "Stack depth should be constant regardless of iteration count (trampoline)");
826+
}
827+
828+
private int maxStackDepthForIterations(final int iterations) {
829+
MutableValue<Integer> counter = new MutableValue<>(0);
830+
MutableValue<Integer> maxDepth = new MutableValue<>(0);
831+
beginAsync().thenRunWhileLoop(() -> counter.get() < iterations, c -> {
832+
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
833+
counter.set(counter.get() + 1);
834+
c.complete(c);
835+
}).finish((v, t) -> {});
836+
837+
assertEquals(iterations, counter.get());
838+
return maxDepth.get();
839+
}
840+
726841
@Test
727842
void testRetryLoop() {
728843
assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1,
@@ -768,6 +883,65 @@ void testDoWhileLoop() {
768883
});
769884
}
770885

886+
@Test
887+
void testNestedDoWhileLoops() {
888+
// inner do-while: 3 success + 5 exception = 8
889+
// last outer iteration: 3 < 3 = 1
890+
// 5(inner exception) + 3(inner success) * 1(transition to next iteration) = 8
891+
// 5(inner exception) + 3(inner success) * 1(outer plainTest exception) + 1(outer plainTest false) + 8(transition to next iteration) = 35
892+
// 5(inner exception) + 3(inner success) * 1(outer plainTest exception) + 1(outer plainTest false) + 35(transition to next iteration) = 116
893+
assertBehavesSameVariations(116,
894+
() -> {
895+
int outer = 0;
896+
do {
897+
int inner = 0;
898+
do {
899+
sync(outer + inner);
900+
inner++;
901+
} while (inner < 3 && plainTest(inner));
902+
outer++;
903+
} while (outer < 3 && plainTest(outer));
904+
},
905+
(callback) -> {
906+
MutableValue<Integer> outer = new MutableValue<>(0);
907+
beginAsync().thenRunDoWhileLoop(c -> {
908+
MutableValue<Integer> inner = new MutableValue<>(0);
909+
beginAsync().thenRunDoWhileLoop(c2 -> {
910+
beginAsync().thenRun(c3 -> {
911+
async(outer.get() + inner.get(), c3);
912+
}).thenRun(c3 -> {
913+
inner.set(inner.get() + 1);
914+
c3.complete(c3);
915+
}).finish(c2);
916+
}, () -> inner.get() < 3 && plainTest(inner.get())
917+
).thenRun(c2 -> {
918+
outer.set(outer.get() + 1);
919+
c2.complete(c2);
920+
}).finish(c);
921+
}, () -> outer.get() < 3 && plainTest(outer.get())).finish(callback);
922+
});
923+
}
924+
925+
@Test
926+
void testDoWhileLoopStackConstant() {
927+
int depthWith100 = maxDoWhileStackDepthForIterations(100);
928+
int depthWith10000 = maxDoWhileStackDepthForIterations(10_000);
929+
assertEquals(depthWith100, depthWith10000,
930+
"Stack depth should be constant regardless of iteration count");
931+
}
932+
933+
private int maxDoWhileStackDepthForIterations(final int iterations) {
934+
MutableValue<Integer> counter = new MutableValue<>(0);
935+
MutableValue<Integer> maxDepth = new MutableValue<>(0);
936+
beginAsync().thenRunDoWhileLoop(c -> {
937+
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
938+
counter.set(counter.get() + 1);
939+
c.complete(c);
940+
}, () -> counter.get() < iterations).finish((v, t) -> {});
941+
assertEquals(iterations, counter.get());
942+
return maxDepth.get();
943+
}
944+
771945
@Test
772946
void testFinallyWithPlainInsideTry() {
773947
// (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6

0 commit comments

Comments
 (0)