11use crate :: identity:: Identity ;
22use spacetimedb_lib:: GlobalTxId ;
33use std:: collections:: { HashMap , HashSet } ;
4+ use std:: future:: Future ;
45use std:: sync:: atomic:: { AtomicBool , Ordering } ;
56use std:: sync:: { Arc , Mutex } ;
6- use tokio:: sync:: Notify ;
7+ use tokio:: sync:: { watch , Notify } ;
78
89#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
910pub enum GlobalTxRole {
@@ -28,18 +29,21 @@ pub struct GlobalTxSession {
2829 pub role : GlobalTxRole ,
2930 pub coordinator_identity : Identity ,
3031 wounded : AtomicBool ,
32+ wounded_tx : watch:: Sender < bool > ,
3133 state : Mutex < GlobalTxState > ,
3234 prepare_id : Mutex < Option < String > > ,
3335 participants : Mutex < HashMap < Identity , String > > ,
3436}
3537
3638impl GlobalTxSession {
3739 fn new ( tx_id : GlobalTxId , role : GlobalTxRole , coordinator_identity : Identity ) -> Self {
40+ let ( wounded_tx, _) = watch:: channel ( false ) ;
3841 Self {
3942 tx_id,
4043 role,
4144 coordinator_identity,
4245 wounded : AtomicBool :: new ( false ) ,
46+ wounded_tx,
4347 state : Mutex :: new ( GlobalTxState :: Running ) ,
4448 prepare_id : Mutex :: new ( None ) ,
4549 participants : Mutex :: new ( HashMap :: new ( ) ) ,
@@ -51,7 +55,15 @@ impl GlobalTxSession {
5155 }
5256
5357 pub fn wound ( & self ) -> bool {
54- !self . wounded . swap ( true , Ordering :: SeqCst )
58+ let was_fresh = !self . wounded . swap ( true , Ordering :: SeqCst ) ;
59+ if was_fresh {
60+ let _ = self . wounded_tx . send ( true ) ;
61+ }
62+ was_fresh
63+ }
64+
65+ pub fn subscribe_wounded ( & self ) -> watch:: Receiver < bool > {
66+ self . wounded_tx . subscribe ( )
5567 }
5668
5769 pub fn state ( & self ) -> GlobalTxState {
@@ -100,11 +112,39 @@ impl Default for LockState {
100112 }
101113}
102114
103- #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
104- pub enum AcquireDisposition {
105- Acquired ,
106- Wound ( GlobalTxId ) ,
107- Wait ,
115+ pub enum AcquireDisposition < ' a > {
116+ Acquired ( GlobalTxLockGuard < ' a > ) ,
117+ Cancelled ,
118+ }
119+
120+ pub struct GlobalTxLockGuard < ' a > {
121+ manager : & ' a GlobalTxManager ,
122+ tx_id : Option < GlobalTxId > ,
123+ }
124+
125+ impl < ' a > GlobalTxLockGuard < ' a > {
126+ fn new ( manager : & ' a GlobalTxManager , tx_id : GlobalTxId ) -> Self {
127+ Self {
128+ manager,
129+ tx_id : Some ( tx_id) ,
130+ }
131+ }
132+
133+ pub fn tx_id ( & self ) -> GlobalTxId {
134+ self . tx_id . expect ( "lock guard must always have a tx_id before drop" )
135+ }
136+
137+ pub fn disarm ( mut self ) {
138+ self . tx_id = None ;
139+ }
140+ }
141+
142+ impl Drop for GlobalTxLockGuard < ' _ > {
143+ fn drop ( & mut self ) {
144+ if let Some ( tx_id) = self . tx_id . take ( ) {
145+ self . manager . release ( & tx_id) ;
146+ }
147+ }
108148}
109149
110150#[ derive( Default ) ]
@@ -174,6 +214,10 @@ impl GlobalTxManager {
174214 self . get_session ( tx_id) . map ( |s| s. is_wounded ( ) ) . unwrap_or ( false )
175215 }
176216
217+ pub fn subscribe_wounded ( & self , tx_id : & GlobalTxId ) -> Option < watch:: Receiver < bool > > {
218+ self . get_session ( tx_id) . map ( |s| s. subscribe_wounded ( ) )
219+ }
220+
177221 pub fn wound ( & self , tx_id : & GlobalTxId ) -> Option < Arc < GlobalTxSession > > {
178222 let session = self . get_session ( tx_id) ?;
179223 let _ = session. wound ( ) ;
@@ -183,30 +227,56 @@ impl GlobalTxManager {
183227 Some ( session)
184228 }
185229
186- pub async fn acquire ( & self , tx_id : GlobalTxId ) -> AcquireDisposition {
230+ pub async fn acquire < F , Fut > ( & self , tx_id : GlobalTxId , mut on_wound : F ) -> AcquireDisposition < ' _ >
231+ where
232+ F : FnMut ( GlobalTxId ) -> Fut ,
233+ Fut : Future < Output = ( ) > ,
234+ {
235+ let mut wounded_rx = match self . subscribe_wounded ( & tx_id) {
236+ Some ( rx) => rx,
237+ None => return AcquireDisposition :: Cancelled ,
238+ } ;
187239 loop {
188- let waiter = {
240+ // self.is_wounded(&tx_id)
241+ if * wounded_rx. borrow ( ) {
242+ self . remove_waiter ( & tx_id) ;
243+ return AcquireDisposition :: Cancelled ;
244+ }
245+
246+ let ( waiter, owner_to_wound) = {
189247 let mut state = self . lock_state . lock ( ) . unwrap ( ) ;
190248 match state. owner {
191249 None => {
192250 state. owner = Some ( tx_id) ;
193251 state. waiting . remove ( & tx_id) ;
194- return AcquireDisposition :: Acquired ;
252+ return AcquireDisposition :: Acquired ( GlobalTxLockGuard :: new ( self , tx_id ) ) ;
195253 }
196254 Some ( owner) if owner == tx_id => {
197255 state. waiting . remove ( & tx_id) ;
198- return AcquireDisposition :: Acquired ;
256+ return AcquireDisposition :: Acquired ( GlobalTxLockGuard :: new ( self , tx_id ) ) ;
199257 }
200258 Some ( owner) => {
201259 state. waiting . insert ( tx_id) ;
202- if tx_id < owner && state. wounded_owners . insert ( owner) {
203- return AcquireDisposition :: Wound ( owner) ;
204- }
205- self . lock_notify . notified ( )
260+ let owner_to_wound = ( tx_id < owner && state. wounded_owners . insert ( owner) ) . then_some ( owner) ;
261+ ( self . lock_notify . notified ( ) , owner_to_wound)
206262 }
207263 }
208264 } ;
209- waiter. await ;
265+
266+ if let Some ( owner) = owner_to_wound {
267+ let _ = self . wound ( & owner) ;
268+ on_wound ( owner) . await ;
269+ }
270+
271+ tokio:: select! {
272+ changed = wounded_rx. changed( ) , if !* wounded_rx. borrow( ) => {
273+ if changed. is_ok( ) && * wounded_rx. borrow( ) {
274+ self . remove_waiter( & tx_id) ;
275+ return AcquireDisposition :: Cancelled ;
276+ }
277+ }
278+ _ = waiter => { }
279+ }
210280 }
211281 }
212282
@@ -219,6 +289,10 @@ impl GlobalTxManager {
219289 }
220290 state. waiting . remove ( tx_id) ;
221291 }
292+
293+ fn remove_waiter ( & self , tx_id : & GlobalTxId ) {
294+ self . lock_state . lock ( ) . unwrap ( ) . waiting . remove ( tx_id) ;
295+ }
222296}
223297
224298#[ cfg( test) ]
@@ -227,8 +301,8 @@ mod tests {
227301 use crate :: identity:: Identity ;
228302 use spacetimedb_lib:: { GlobalTxId , Timestamp } ;
229303 use std:: sync:: Arc ;
230- use tokio:: runtime:: Runtime ;
231304 use std:: time:: Duration ;
305+ use tokio:: runtime:: Runtime ;
232306
233307 fn tx_id ( ts : i64 , db_byte : u8 , nonce : u32 ) -> GlobalTxId {
234308 GlobalTxId :: new (
@@ -240,54 +314,83 @@ mod tests {
240314
241315 #[ test]
242316 fn older_requester_wounds_younger_owner ( ) {
243- let manager = GlobalTxManager :: default ( ) ;
317+ let manager = Arc :: new ( GlobalTxManager :: default ( ) ) ;
244318 let younger = tx_id ( 20 , 2 , 0 ) ;
245319 let older = tx_id ( 10 , 1 , 0 ) ;
246320 manager. ensure_session (
247321 younger,
248322 super :: GlobalTxRole :: Participant ,
249323 younger. creator_db ,
250324 ) ;
325+ manager. ensure_session ( older, super :: GlobalTxRole :: Participant , older. creator_db ) ;
251326
252327 let rt = Runtime :: new ( ) . unwrap ( ) ;
253- assert_eq ! ( rt. block_on( manager. acquire( younger) ) , AcquireDisposition :: Acquired ) ;
254- assert_eq ! ( rt. block_on( manager. acquire( older) ) , AcquireDisposition :: Wound ( younger) ) ;
255- assert ! ( manager. wound( & younger) . is_some( ) ) ;
328+ let younger_guard = match rt. block_on ( manager. acquire ( younger, |_| async { } ) ) {
329+ AcquireDisposition :: Acquired ( guard) => guard,
330+ AcquireDisposition :: Cancelled => panic ! ( "younger tx should acquire immediately" ) ,
331+ } ;
332+
333+ let manager_for_task = manager. clone ( ) ;
334+ let older_task = rt. spawn ( async move {
335+ match manager_for_task. acquire ( older, |_| async { } ) . await {
336+ AcquireDisposition :: Acquired ( _guard) => true ,
337+ AcquireDisposition :: Cancelled => false ,
338+ }
339+ } ) ;
340+ std:: thread:: sleep ( Duration :: from_millis ( 10 ) ) ;
256341 assert ! ( manager. is_wounded( & younger) ) ;
342+ drop ( younger_guard) ;
343+ assert ! ( matches!(
344+ rt. block_on( older_task) . expect( "task should complete" ) ,
345+ true
346+ ) ) ;
257347 }
258348
259349 #[ test]
260350 fn younger_requester_waits_behind_older_owner ( ) {
261351 let manager = GlobalTxManager :: default ( ) ;
262352 let older = tx_id ( 10 , 1 , 0 ) ;
263353 let younger = tx_id ( 20 , 2 , 0 ) ;
354+ manager. ensure_session ( older, super :: GlobalTxRole :: Participant , older. creator_db ) ;
355+ manager. ensure_session ( younger, super :: GlobalTxRole :: Participant , younger. creator_db ) ;
264356 let rt = Runtime :: new ( ) . unwrap ( ) ;
265357
266- assert_eq ! ( rt. block_on( manager. acquire( older) ) , AcquireDisposition :: Acquired ) ;
358+ let older_guard = match rt. block_on ( manager. acquire ( older, |_| async { } ) ) {
359+ AcquireDisposition :: Acquired ( guard) => guard,
360+ AcquireDisposition :: Cancelled => panic ! ( "older tx should acquire immediately" ) ,
361+ } ;
267362 let wait = rt. block_on ( async {
268- tokio:: time:: timeout ( Duration :: from_millis ( 25 ) , manager. acquire ( younger) ) . await
363+ tokio:: time:: timeout ( Duration :: from_millis ( 25 ) , manager. acquire ( younger, |_| async { } ) ) . await
269364 } ) ;
270365 assert ! ( wait. is_err( ) ) ;
366+ drop ( older_guard) ;
271367 }
272368
273369 #[ test]
274370 fn waiter_acquires_after_release ( ) {
275371 let manager = Arc :: new ( GlobalTxManager :: default ( ) ) ;
276372 let owner = tx_id ( 10 , 1 , 0 ) ;
277373 let waiter = tx_id ( 20 , 2 , 0 ) ;
374+ manager. ensure_session ( owner, super :: GlobalTxRole :: Participant , owner. creator_db ) ;
375+ manager. ensure_session ( waiter, super :: GlobalTxRole :: Participant , waiter. creator_db ) ;
278376 let rt = Runtime :: new ( ) . unwrap ( ) ;
279377
280- assert_eq ! ( rt. block_on( manager. acquire( owner) ) , AcquireDisposition :: Acquired ) ;
378+ let owner_guard = match rt. block_on ( manager. acquire ( owner, |_| async { } ) ) {
379+ AcquireDisposition :: Acquired ( guard) => guard,
380+ AcquireDisposition :: Cancelled => panic ! ( "owner should acquire immediately" ) ,
381+ } ;
281382
282383 let manager_for_thread = manager. clone ( ) ;
283384 let handle = std:: thread:: spawn ( move || {
284385 let rt = Runtime :: new ( ) . unwrap ( ) ;
285- assert_eq ! ( rt. block_on( manager_for_thread. acquire( waiter) ) , AcquireDisposition :: Acquired ) ;
286- manager_for_thread. release ( & waiter) ;
386+ match rt. block_on ( manager_for_thread. acquire ( waiter, |_| async { } ) ) {
387+ AcquireDisposition :: Acquired ( _guard) => { }
388+ AcquireDisposition :: Cancelled => panic ! ( "waiter should acquire after release" ) ,
389+ }
287390 } ) ;
288391
289392 std:: thread:: sleep ( Duration :: from_millis ( 25 ) ) ;
290- manager . release ( & owner ) ;
393+ drop ( owner_guard ) ;
291394 handle. join ( ) . unwrap ( ) ;
292395 }
293396
@@ -303,4 +406,61 @@ mod tests {
303406 assert ! ( manager. wound( & tx_id) . is_some( ) ) ;
304407 assert ! ( session. is_wounded( ) ) ;
305408 }
409+
410+ #[ test]
411+ fn wound_subscription_notifies_waiter ( ) {
412+ let manager = GlobalTxManager :: default ( ) ;
413+ let tx_id = tx_id ( 10 , 1 , 0 ) ;
414+ let _session = manager. ensure_session ( tx_id, super :: GlobalTxRole :: Coordinator , tx_id. creator_db ) ;
415+ let mut wounded_rx = manager. subscribe_wounded ( & tx_id) . expect ( "session should exist" ) ;
416+
417+ let rt = Runtime :: new ( ) . unwrap ( ) ;
418+ rt. block_on ( async {
419+ let notifier = async {
420+ if !* wounded_rx. borrow ( ) {
421+ wounded_rx. changed ( ) . await . expect ( "sender should still exist" ) ;
422+ }
423+ * wounded_rx. borrow ( )
424+ } ;
425+
426+ let trigger = async {
427+ tokio:: time:: sleep ( Duration :: from_millis ( 10 ) ) . await ;
428+ manager. wound ( & tx_id) . expect ( "session should still exist" ) ;
429+ } ;
430+
431+ let ( wounded, ( ) ) = tokio:: join!( notifier, trigger) ;
432+ assert ! ( wounded) ;
433+ } ) ;
434+ }
435+
436+ #[ test]
437+ fn wounded_waiter_is_cancelled ( ) {
438+ let manager = Arc :: new ( GlobalTxManager :: default ( ) ) ;
439+ let owner = tx_id ( 10 , 1 , 0 ) ;
440+ let waiter = tx_id ( 20 , 2 , 0 ) ;
441+ manager. ensure_session ( owner, super :: GlobalTxRole :: Participant , owner. creator_db ) ;
442+ manager. ensure_session ( waiter, super :: GlobalTxRole :: Participant , waiter. creator_db ) ;
443+
444+ let rt = Runtime :: new ( ) . unwrap ( ) ;
445+ let owner_guard = match rt. block_on ( manager. acquire ( owner, |_| async { } ) ) {
446+ AcquireDisposition :: Acquired ( guard) => guard,
447+ AcquireDisposition :: Cancelled => panic ! ( "owner should acquire immediately" ) ,
448+ } ;
449+
450+ let manager_for_task = manager. clone ( ) ;
451+ let waiter_task = rt. spawn ( async move {
452+ matches ! (
453+ manager_for_task. acquire( waiter, |_| async { } ) . await ,
454+ AcquireDisposition :: Cancelled
455+ )
456+ } ) ;
457+ std:: thread:: sleep ( Duration :: from_millis ( 10 ) ) ;
458+ manager. wound ( & waiter) . expect ( "waiter session should exist" ) ;
459+ drop ( owner_guard) ;
460+
461+ assert ! ( matches!(
462+ rt. block_on( waiter_task) . expect( "task should complete" ) ,
463+ true
464+ ) ) ;
465+ }
306466}
0 commit comments