Skip to content

Commit b9bc153

Browse files
committed
WIP fixing lock issues
1 parent e8e861a commit b9bc153

3 files changed

Lines changed: 210 additions & 49 deletions

File tree

crates/core/src/host/global_tx.rs

Lines changed: 187 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use crate::identity::Identity;
22
use spacetimedb_lib::GlobalTxId;
33
use std::collections::{HashMap, HashSet};
4+
use std::future::Future;
45
use std::sync::atomic::{AtomicBool, Ordering};
56
use std::sync::{Arc, Mutex};
6-
use tokio::sync::Notify;
7+
use tokio::sync::{watch, Notify};
78

89
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
910
pub enum GlobalTxRole {
@@ -28,18 +29,21 @@ pub struct GlobalTxSession {
2829
pub role: GlobalTxRole,
2930
pub coordinator_identity: Identity,
3031
wounded: AtomicBool,
32+
wounded_tx: watch::Sender<bool>,
3133
state: Mutex<GlobalTxState>,
3234
prepare_id: Mutex<Option<String>>,
3335
participants: Mutex<HashMap<Identity, String>>,
3436
}
3537

3638
impl GlobalTxSession {
3739
fn new(tx_id: GlobalTxId, role: GlobalTxRole, coordinator_identity: Identity) -> Self {
40+
let (wounded_tx, _) = watch::channel(false);
3841
Self {
3942
tx_id,
4043
role,
4144
coordinator_identity,
4245
wounded: AtomicBool::new(false),
46+
wounded_tx,
4347
state: Mutex::new(GlobalTxState::Running),
4448
prepare_id: Mutex::new(None),
4549
participants: Mutex::new(HashMap::new()),
@@ -51,7 +55,15 @@ impl GlobalTxSession {
5155
}
5256

5357
pub fn wound(&self) -> bool {
54-
!self.wounded.swap(true, Ordering::SeqCst)
58+
let was_fresh = !self.wounded.swap(true, Ordering::SeqCst);
59+
if was_fresh {
60+
let _ = self.wounded_tx.send_replace(true); // send() drops the value when no receiver is subscribed; send_replace() always stores it
61+
}
62+
was_fresh
63+
}
64+
65+
pub fn subscribe_wounded(&self) -> watch::Receiver<bool> {
66+
self.wounded_tx.subscribe()
5567
}
5668

5769
pub fn state(&self) -> GlobalTxState {
@@ -100,11 +112,39 @@ impl Default for LockState {
100112
}
101113
}
102114

103-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
104-
pub enum AcquireDisposition {
105-
Acquired,
106-
Wound(GlobalTxId),
107-
Wait,
115+
pub enum AcquireDisposition<'a> {
116+
Acquired(GlobalTxLockGuard<'a>),
117+
Cancelled,
118+
}
119+
120+
pub struct GlobalTxLockGuard<'a> {
121+
manager: &'a GlobalTxManager,
122+
tx_id: Option<GlobalTxId>,
123+
}
124+
125+
impl<'a> GlobalTxLockGuard<'a> {
126+
fn new(manager: &'a GlobalTxManager, tx_id: GlobalTxId) -> Self {
127+
Self {
128+
manager,
129+
tx_id: Some(tx_id),
130+
}
131+
}
132+
133+
pub fn tx_id(&self) -> GlobalTxId {
134+
self.tx_id.expect("lock guard must always have a tx_id before drop")
135+
}
136+
137+
pub fn disarm(mut self) {
138+
self.tx_id = None;
139+
}
140+
}
141+
142+
impl Drop for GlobalTxLockGuard<'_> {
143+
fn drop(&mut self) {
144+
if let Some(tx_id) = self.tx_id.take() {
145+
self.manager.release(&tx_id);
146+
}
147+
}
108148
}
109149

110150
#[derive(Default)]
@@ -174,6 +214,10 @@ impl GlobalTxManager {
174214
self.get_session(tx_id).map(|s| s.is_wounded()).unwrap_or(false)
175215
}
176216

217+
pub fn subscribe_wounded(&self, tx_id: &GlobalTxId) -> Option<watch::Receiver<bool>> {
218+
self.get_session(tx_id).map(|s| s.subscribe_wounded())
219+
}
220+
177221
pub fn wound(&self, tx_id: &GlobalTxId) -> Option<Arc<GlobalTxSession>> {
178222
let session = self.get_session(tx_id)?;
179223
let _ = session.wound();
@@ -183,30 +227,56 @@ impl GlobalTxManager {
183227
Some(session)
184228
}
185229

186-
pub async fn acquire(&self, tx_id: GlobalTxId) -> AcquireDisposition {
230+
pub async fn acquire<F, Fut>(&self, tx_id: GlobalTxId, mut on_wound: F) -> AcquireDisposition<'_>
231+
where
232+
F: FnMut(GlobalTxId) -> Fut,
233+
Fut: Future<Output = ()>,
234+
{
235+
let mut wounded_rx = match self.subscribe_wounded(&tx_id) {
236+
Some(rx) => rx,
237+
None => return AcquireDisposition::Cancelled,
238+
};
187239
loop {
188-
let waiter = {
240+
// Fast path: stop waiting if this transaction has already been wounded.
241+
if *wounded_rx.borrow() {
242+
self.remove_waiter(&tx_id);
243+
return AcquireDisposition::Cancelled;
244+
}
245+
246+
let (waiter, owner_to_wound) = {
189247
let mut state = self.lock_state.lock().unwrap();
190248
match state.owner {
191249
None => {
192250
state.owner = Some(tx_id);
193251
state.waiting.remove(&tx_id);
194-
return AcquireDisposition::Acquired;
252+
return AcquireDisposition::Acquired(GlobalTxLockGuard::new(self, tx_id));
195253
}
196254
Some(owner) if owner == tx_id => {
197255
state.waiting.remove(&tx_id);
198-
return AcquireDisposition::Acquired;
256+
return AcquireDisposition::Acquired(GlobalTxLockGuard::new(self, tx_id));
199257
}
200258
Some(owner) => {
201259
state.waiting.insert(tx_id);
202-
if tx_id < owner && state.wounded_owners.insert(owner) {
203-
return AcquireDisposition::Wound(owner);
204-
}
205-
self.lock_notify.notified()
260+
let owner_to_wound = (tx_id < owner && state.wounded_owners.insert(owner)).then_some(owner);
261+
(self.lock_notify.notified(), owner_to_wound)
206262
}
207263
}
208264
};
209-
waiter.await;
265+
266+
if let Some(owner) = owner_to_wound {
267+
let _ = self.wound(&owner);
268+
on_wound(owner).await;
269+
}
270+
271+
tokio::select! {
272+
changed = wounded_rx.changed() => { // no precondition: a wound that landed after the check at the top of the loop must still wake this branch
273+
if changed.is_err() || *wounded_rx.borrow() { // Err means the sender was dropped; cancel instead of spinning on a closed channel
274+
self.remove_waiter(&tx_id);
275+
return AcquireDisposition::Cancelled;
276+
}
277+
}
278+
_ = waiter => {}
279+
}
210280
}
211281
}
212282

@@ -219,6 +289,10 @@ impl GlobalTxManager {
219289
}
220290
state.waiting.remove(tx_id);
221291
}
292+
293+
fn remove_waiter(&self, tx_id: &GlobalTxId) {
294+
self.lock_state.lock().unwrap().waiting.remove(tx_id);
295+
}
222296
}
223297

224298
#[cfg(test)]
@@ -227,8 +301,8 @@ mod tests {
227301
use crate::identity::Identity;
228302
use spacetimedb_lib::{GlobalTxId, Timestamp};
229303
use std::sync::Arc;
230-
use tokio::runtime::Runtime;
231304
use std::time::Duration;
305+
use tokio::runtime::Runtime;
232306

233307
fn tx_id(ts: i64, db_byte: u8, nonce: u32) -> GlobalTxId {
234308
GlobalTxId::new(
@@ -240,54 +314,83 @@ mod tests {
240314

241315
#[test]
242316
fn older_requester_wounds_younger_owner() {
243-
let manager = GlobalTxManager::default();
317+
let manager = Arc::new(GlobalTxManager::default());
244318
let younger = tx_id(20, 2, 0);
245319
let older = tx_id(10, 1, 0);
246320
manager.ensure_session(
247321
younger,
248322
super::GlobalTxRole::Participant,
249323
younger.creator_db,
250324
);
325+
manager.ensure_session(older, super::GlobalTxRole::Participant, older.creator_db);
251326

252327
let rt = Runtime::new().unwrap();
253-
assert_eq!(rt.block_on(manager.acquire(younger)), AcquireDisposition::Acquired);
254-
assert_eq!(rt.block_on(manager.acquire(older)), AcquireDisposition::Wound(younger));
255-
assert!(manager.wound(&younger).is_some());
328+
let younger_guard = match rt.block_on(manager.acquire(younger, |_| async {})) {
329+
AcquireDisposition::Acquired(guard) => guard,
330+
AcquireDisposition::Cancelled => panic!("younger tx should acquire immediately"),
331+
};
332+
333+
let manager_for_task = manager.clone();
334+
let older_task = rt.spawn(async move {
335+
match manager_for_task.acquire(older, |_| async {}).await {
336+
AcquireDisposition::Acquired(_guard) => true,
337+
AcquireDisposition::Cancelled => false,
338+
}
339+
});
340+
std::thread::sleep(Duration::from_millis(10));
256341
assert!(manager.is_wounded(&younger));
342+
drop(younger_guard);
343+
assert!(matches!(
344+
rt.block_on(older_task).expect("task should complete"),
345+
true
346+
));
257347
}
258348

259349
#[test]
260350
fn younger_requester_waits_behind_older_owner() {
261351
let manager = GlobalTxManager::default();
262352
let older = tx_id(10, 1, 0);
263353
let younger = tx_id(20, 2, 0);
354+
manager.ensure_session(older, super::GlobalTxRole::Participant, older.creator_db);
355+
manager.ensure_session(younger, super::GlobalTxRole::Participant, younger.creator_db);
264356
let rt = Runtime::new().unwrap();
265357

266-
assert_eq!(rt.block_on(manager.acquire(older)), AcquireDisposition::Acquired);
358+
let older_guard = match rt.block_on(manager.acquire(older, |_| async {})) {
359+
AcquireDisposition::Acquired(guard) => guard,
360+
AcquireDisposition::Cancelled => panic!("older tx should acquire immediately"),
361+
};
267362
let wait = rt.block_on(async {
268-
tokio::time::timeout(Duration::from_millis(25), manager.acquire(younger)).await
363+
tokio::time::timeout(Duration::from_millis(25), manager.acquire(younger, |_| async {})).await
269364
});
270365
assert!(wait.is_err());
366+
drop(older_guard);
271367
}
272368

273369
#[test]
274370
fn waiter_acquires_after_release() {
275371
let manager = Arc::new(GlobalTxManager::default());
276372
let owner = tx_id(10, 1, 0);
277373
let waiter = tx_id(20, 2, 0);
374+
manager.ensure_session(owner, super::GlobalTxRole::Participant, owner.creator_db);
375+
manager.ensure_session(waiter, super::GlobalTxRole::Participant, waiter.creator_db);
278376
let rt = Runtime::new().unwrap();
279377

280-
assert_eq!(rt.block_on(manager.acquire(owner)), AcquireDisposition::Acquired);
378+
let owner_guard = match rt.block_on(manager.acquire(owner, |_| async {})) {
379+
AcquireDisposition::Acquired(guard) => guard,
380+
AcquireDisposition::Cancelled => panic!("owner should acquire immediately"),
381+
};
281382

282383
let manager_for_thread = manager.clone();
283384
let handle = std::thread::spawn(move || {
284385
let rt = Runtime::new().unwrap();
285-
assert_eq!(rt.block_on(manager_for_thread.acquire(waiter)), AcquireDisposition::Acquired);
286-
manager_for_thread.release(&waiter);
386+
match rt.block_on(manager_for_thread.acquire(waiter, |_| async {})) {
387+
AcquireDisposition::Acquired(_guard) => {}
388+
AcquireDisposition::Cancelled => panic!("waiter should acquire after release"),
389+
}
287390
});
288391

289392
std::thread::sleep(Duration::from_millis(25));
290-
manager.release(&owner);
393+
drop(owner_guard);
291394
handle.join().unwrap();
292395
}
293396

@@ -303,4 +406,61 @@ mod tests {
303406
assert!(manager.wound(&tx_id).is_some());
304407
assert!(session.is_wounded());
305408
}
409+
410+
#[test]
411+
fn wound_subscription_notifies_waiter() {
412+
let manager = GlobalTxManager::default();
413+
let tx_id = tx_id(10, 1, 0);
414+
let _session = manager.ensure_session(tx_id, super::GlobalTxRole::Coordinator, tx_id.creator_db);
415+
let mut wounded_rx = manager.subscribe_wounded(&tx_id).expect("session should exist");
416+
417+
let rt = Runtime::new().unwrap();
418+
rt.block_on(async {
419+
let notifier = async {
420+
if !*wounded_rx.borrow() {
421+
wounded_rx.changed().await.expect("sender should still exist");
422+
}
423+
*wounded_rx.borrow()
424+
};
425+
426+
let trigger = async {
427+
tokio::time::sleep(Duration::from_millis(10)).await;
428+
manager.wound(&tx_id).expect("session should still exist");
429+
};
430+
431+
let (wounded, ()) = tokio::join!(notifier, trigger);
432+
assert!(wounded);
433+
});
434+
}
435+
436+
#[test]
437+
fn wounded_waiter_is_cancelled() {
438+
let manager = Arc::new(GlobalTxManager::default());
439+
let owner = tx_id(10, 1, 0);
440+
let waiter = tx_id(20, 2, 0);
441+
manager.ensure_session(owner, super::GlobalTxRole::Participant, owner.creator_db);
442+
manager.ensure_session(waiter, super::GlobalTxRole::Participant, waiter.creator_db);
443+
444+
let rt = Runtime::new().unwrap();
445+
let owner_guard = match rt.block_on(manager.acquire(owner, |_| async {})) {
446+
AcquireDisposition::Acquired(guard) => guard,
447+
AcquireDisposition::Cancelled => panic!("owner should acquire immediately"),
448+
};
449+
450+
let manager_for_task = manager.clone();
451+
let waiter_task = rt.spawn(async move {
452+
matches!(
453+
manager_for_task.acquire(waiter, |_| async {}).await,
454+
AcquireDisposition::Cancelled
455+
)
456+
});
457+
std::thread::sleep(Duration::from_millis(10));
458+
manager.wound(&waiter).expect("waiter session should exist");
459+
drop(owner_guard);
460+
461+
assert!(matches!(
462+
rt.block_on(waiter_task).expect("task should complete"),
463+
true
464+
));
465+
}
306466
}

crates/core/src/host/instance_env.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,7 @@ impl InstanceEnv {
11591159
if let Some(tx_id) = tx_id {
11601160
req = req.header(TX_ID_HEADER, tx_id.to_string());
11611161
}
1162+
// TODO: This needs to select on subscribe_wounded as well, so we can stop waiting for the response when wounded.
11621163
let result = async {
11631164
let response = req.send().await.map_err(|e| NodesError::HttpError(e.to_string()))?;
11641165
let status = response.status().as_u16();

0 commit comments

Comments
 (0)