Skip to content

Commit e924090

Browse files
committed
CABI: improve and add cooperative thread built-ins
1 parent 1ae9b76 commit e924090

2 files changed

Lines changed: 243 additions & 48 deletions

File tree

design/mvp/canonical-abi/definitions.py

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
class Trap(BaseException): pass
1818
class CoreWebAssemblyException(BaseException): pass
19+
class ThreadExit(BaseException): pass
1920

2021
def trap():
2122
raise Trap()
@@ -304,7 +305,7 @@ class ComponentInstance:
304305
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
305306
threads: Table[Thread]
306307
may_leave: bool
307-
may_block: bool
308+
sync_before_return: bool
308309
backpressure: int
309310
exclusive: Optional[Task]
310311
num_waiting_to_enter: int
@@ -316,11 +317,17 @@ def __init__(self, store, parent = None):
316317
self.handles = Table()
317318
self.threads = Table()
318319
self.may_leave = True
319-
self.may_block = True
320+
self.sync_before_return = False
320321
self.backpressure = 0
321322
self.exclusive = None
322323
self.num_waiting_to_enter = 0
323324

325+
def ready_threads(self) -> list[Thread]:
326+
return [t for t in self.threads.array if t and t.waiting() and t.ready()]
327+
328+
def may_block(self):
329+
return not self.sync_before_return or len(self.ready_threads()) > 0
330+
324331
def reflexive_ancestors(self) -> set[ComponentInstance]:
325332
s = set()
326333
inst = self
@@ -497,7 +504,10 @@ def ready(self):
497504
def __init__(self, task, thread_func):
498505
def cont_func(cancelled):
499506
assert(self.running() and not cancelled)
500-
thread_func()
507+
try:
508+
thread_func()
509+
except ThreadExit:
510+
pass
501511
return None
502512
self.cont = cont_new(cont_func)
503513
self.ready_func = None
@@ -507,7 +517,7 @@ def cont_func(cancelled):
507517
self.storage = [0,0]
508518
assert(self.suspended())
509519

510-
def resume_later(self):
520+
def unsuspend(self):
511521
assert(self.suspended())
512522
self.ready_func = lambda: True
513523
self.task.inst.store.waiting.append(self)
@@ -517,18 +527,25 @@ def resume(self, cancelled):
517527
assert(not self.running() and (self.cancellable or not cancelled))
518528
if self.waiting():
519529
assert(cancelled or self.ready())
520-
self.ready_func = None
521-
self.task.inst.store.waiting.remove(self)
530+
self.stop_waiting()
522531
thread = self
523532
while thread is not None:
524533
cont = thread.cont
525534
thread.cont = None
526535
(thread.cont, switch_to) = resume(cont, cancelled, thread)
536+
if switch_to is None and self.task.inst.sync_before_return:
537+
switch_to = random.choice(self.task.inst.ready_threads())
538+
switch_to.stop_waiting()
527539
thread = switch_to
528540
cancelled = Cancelled.FALSE
529541

542+
def stop_waiting(self):
543+
assert(self.waiting())
544+
self.ready_func = None
545+
self.task.inst.store.waiting.remove(self)
546+
530547
def suspend(self, cancellable) -> Cancelled:
531-
assert(self.running() and self.task.inst.may_block)
548+
assert(self.running() and self.task.inst.may_block())
532549
if self.task.deliver_pending_cancel(cancellable):
533550
return Cancelled.TRUE
534551
self.cancellable = cancellable
@@ -537,7 +554,7 @@ def suspend(self, cancellable) -> Cancelled:
537554
return cancelled
538555

539556
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
540-
assert(self.running() and self.task.inst.may_block)
557+
assert(self.running() and self.task.inst.may_block())
541558
if self.task.deliver_pending_cancel(cancellable):
542559
return Cancelled.TRUE
543560
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -548,7 +565,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
548565

549566
def yield_until(self, ready_func, cancellable) -> Cancelled:
550567
assert(self.running())
551-
if self.task.inst.may_block:
568+
if self.task.inst.may_block():
552569
return self.wait_until(ready_func, cancellable)
553570
else:
554571
assert(ready_func())
@@ -557,7 +574,7 @@ def yield_until(self, ready_func, cancellable) -> Cancelled:
557574
def yield_(self, cancellable) -> Cancelled:
558575
return self.yield_until(lambda: True, cancellable)
559576

560-
def switch_to(self, cancellable, other: Thread) -> Cancelled:
577+
def suspend_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
561578
assert(self.running() and other.suspended())
562579
if self.task.deliver_pending_cancel(cancellable):
563580
return Cancelled.TRUE
@@ -566,11 +583,31 @@ def switch_to(self, cancellable, other: Thread) -> Cancelled:
566583
assert(self.running() and (cancellable or not cancelled))
567584
return cancelled
568585

569-
def yield_to(self, cancellable, other: Thread) -> Cancelled:
586+
def yield_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
570587
assert(self.running() and other.suspended())
571588
self.ready_func = lambda: True
572589
self.task.inst.store.waiting.append(self)
573-
return self.switch_to(cancellable, other)
590+
return self.suspend_to_suspended(cancellable, other)
591+
592+
def suspend_then_promote(self, cancellable, other: Thread) -> ResumeArg:
593+
assert(self.running())
594+
if other.waiting() and other.ready():
595+
other.stop_waiting()
596+
return self.suspend_to_suspended(cancellable, other)
597+
else:
598+
return self.suspend(cancellable)
599+
600+
def yield_then_promote(self, cancellable, other: Thread) -> ResumeArg:
601+
assert(self.running())
602+
if other.waiting() and other.ready():
603+
other.stop_waiting()
604+
return self.yield_to_suspended(cancellable, other)
605+
else:
606+
return self.yield_(cancellable)
607+
608+
def exit(self):
609+
assert(self.running() and self.task.inst.may_block())
610+
raise ThreadExit()
574611

575612
#### Waitable State
576613

@@ -711,8 +748,8 @@ def has_backpressure():
711748
assert(self.inst.exclusive is None)
712749
self.inst.exclusive = self
713750
else:
714-
assert(self.inst.may_block)
715-
self.inst.may_block = False
751+
assert(not self.inst.sync_before_return)
752+
self.inst.sync_before_return = True
716753
self.register_thread(thread)
717754
return True
718755

@@ -763,8 +800,8 @@ def return_(self, result):
763800
trap_if(self.state == Task.State.RESOLVED)
764801
trap_if(self.num_borrows > 0)
765802
if not self.ft.async_:
766-
assert(not self.inst.may_block)
767-
self.inst.may_block = True
803+
assert(self.inst.sync_before_return)
804+
self.inst.sync_before_return = False
768805
assert(result is not None)
769806
self.on_resolve(result)
770807
self.state = Task.State.RESOLVED
@@ -2106,7 +2143,7 @@ def thread_func():
21062143
else:
21072144
event = (EventCode.NONE, 0, 0)
21082145
case CallbackCode.WAIT:
2109-
trap_if(not inst.may_block)
2146+
trap_if(not inst.may_block())
21102147
wset = inst.handles.get(si)
21112148
trap_if(not isinstance(wset, WaitableSet))
21122149
event = wset.wait_until(lambda: not inst.exclusive, cancellable = True)
@@ -2150,7 +2187,7 @@ def call_and_trap_on_throw(callee, args):
21502187
def canon_lower(opts, ft, callee: FuncInst, flat_args):
21512188
thread = current_thread()
21522189
trap_if(not thread.task.inst.may_leave)
2153-
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
2190+
trap_if(not thread.task.inst.may_block() and ft.async_ and not opts.async_)
21542191

21552192
subtask = Subtask()
21562193
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2338,7 +2375,7 @@ def canon_waitable_set_new():
23382375
def canon_waitable_set_wait(cancellable, mem, si, ptr):
23392376
inst = current_thread().task.inst
23402377
trap_if(not inst.may_leave)
2341-
trap_if(not inst.may_block)
2378+
trap_if(not inst.may_block())
23422379
wset = inst.handles.get(si)
23432380
trap_if(not isinstance(wset, WaitableSet))
23442381
event = wset.wait(cancellable)
@@ -2393,7 +2430,7 @@ def canon_waitable_join(wi, si):
23932430
def canon_subtask_cancel(async_, i):
23942431
thread = current_thread()
23952432
trap_if(not thread.task.inst.may_leave)
2396-
trap_if(not thread.task.inst.may_block and not async_)
2433+
trap_if(not thread.task.inst.may_block() and not async_)
23972434
subtask = thread.task.inst.handles.get(i)
23982435
trap_if(not isinstance(subtask, Subtask))
23992436
trap_if(subtask.resolve_delivered())
@@ -2454,7 +2491,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24542491
def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n):
24552492
thread = current_thread()
24562493
trap_if(not thread.task.inst.may_leave)
2457-
trap_if(not thread.task.inst.may_block and not opts.async_)
2494+
trap_if(not thread.task.inst.may_block() and not opts.async_)
24582495

24592496
e = thread.task.inst.handles.get(i)
24602497
trap_if(not isinstance(e, EndT))
@@ -2509,7 +2546,7 @@ def canon_future_write(future_t, opts, i, ptr):
25092546
def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr):
25102547
thread = current_thread()
25112548
trap_if(not thread.task.inst.may_leave)
2512-
trap_if(not thread.task.inst.may_block and not opts.async_)
2549+
trap_if(not thread.task.inst.may_block() and not opts.async_)
25132550

25142551
e = thread.task.inst.handles.get(i)
25152552
trap_if(not isinstance(e, EndT))
@@ -2562,7 +2599,7 @@ def canon_future_cancel_write(future_t, async_, i):
25622599
def cancel_copy(EndT, event_code, stream_or_future_t, async_, i):
25632600
thread = current_thread()
25642601
trap_if(not thread.task.inst.may_leave)
2565-
trap_if(not thread.task.inst.may_block and not async_)
2602+
trap_if(not thread.task.inst.may_block() and not async_)
25662603
e = thread.task.inst.handles.get(i)
25672604
trap_if(not isinstance(e, EndT))
25682605
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2630,22 +2667,22 @@ def thread_func():
26302667
task.register_thread(new_thread)
26312668
return [new_thread.index]
26322669

2633-
### 🧵 `canon thread.resume-later`
2670+
### 🧵 `canon thread.unsuspend`
26342671

2635-
def canon_thread_resume_later(i):
2672+
def canon_thread_unsuspend(i):
26362673
thread = current_thread()
26372674
trap_if(not thread.task.inst.may_leave)
26382675
other_thread = thread.task.inst.threads.get(i)
26392676
trap_if(not other_thread.suspended())
2640-
other_thread.resume_later()
2677+
other_thread.unsuspend()
26412678
return []
26422679

26432680
### 🧵 `canon thread.suspend`
26442681

26452682
def canon_thread_suspend(cancellable):
26462683
thread = current_thread()
26472684
trap_if(not thread.task.inst.may_leave)
2648-
trap_if(not thread.task.inst.may_block)
2685+
trap_if(not thread.task.inst.may_block())
26492686
cancelled = thread.suspend(cancellable)
26502687
return [cancelled]
26512688

@@ -2657,26 +2694,54 @@ def canon_thread_yield(cancellable):
26572694
cancelled = thread.yield_(cancellable)
26582695
return [cancelled]
26592696

2660-
### 🧵 `canon thread.switch-to`
2697+
### 🧵 `canon thread.suspend-to-suspended`
26612698

2662-
def canon_thread_switch_to(cancellable, i):
2699+
def canon_thread_suspend_to_suspended(cancellable, i):
26632700
thread = current_thread()
26642701
trap_if(not thread.task.inst.may_leave)
26652702
other_thread = thread.task.inst.threads.get(i)
26662703
trap_if(not other_thread.suspended())
2667-
cancelled = thread.switch_to(cancellable, other_thread)
2704+
cancelled = thread.suspend_to_suspended(cancellable, other_thread)
26682705
return [cancelled]
26692706

2670-
### 🧵 `canon thread.yield-to`
2707+
### 🧵 `canon thread.yield-to-suspended`
26712708

2672-
def canon_thread_yield_to(cancellable, i):
2709+
def canon_thread_yield_to_suspended(cancellable, i):
26732710
thread = current_thread()
26742711
trap_if(not thread.task.inst.may_leave)
26752712
other_thread = thread.task.inst.threads.get(i)
26762713
trap_if(not other_thread.suspended())
2677-
cancelled = thread.yield_to(cancellable, other_thread)
2714+
cancelled = thread.yield_to_suspended(cancellable, other_thread)
2715+
return [cancelled]
2716+
2717+
### 🧵 `canon thread.suspend-then-promote`
2718+
2719+
def canon_thread_suspend_then_promote(cancellable, i):
2720+
thread = current_thread()
2721+
trap_if(not thread.task.inst.may_leave)
2722+
trap_if(not thread.task.inst.may_block())
2723+
other_thread = thread.task.inst.threads.get(i)
2724+
cancelled = thread.suspend_then_promote(cancellable, other_thread)
26782725
return [cancelled]
26792726

2727+
### 🧵 `canon thread.yield-then-promote`
2728+
2729+
def canon_thread_yield_then_promote(cancellable, i):
2730+
thread = current_thread()
2731+
trap_if(not thread.task.inst.may_leave)
2732+
other_thread = thread.task.inst.threads.get(i)
2733+
cancelled = thread.yield_then_promote(cancellable, other_thread)
2734+
return [cancelled]
2735+
2736+
### 🧵 `canon thread.exit`
2737+
2738+
def canon_thread_exit():
2739+
thread = current_thread()
2740+
trap_if(not thread.task.inst.may_leave)
2741+
trap_if(not thread.task.inst.may_block())
2742+
thread.exit()
2743+
assert(False)
2744+
26802745
### 📝 `canon error-context.new`
26812746

26822747
@dataclass

0 commit comments

Comments
 (0)