Skip to content

Commit 64a6d8f

Browse files
Async terminal connections with tokio
1 parent 7066517 commit 64a6d8f

4 files changed

Lines changed: 268 additions & 150 deletions

File tree

tools/tpcc-runner/src/client.rs

Lines changed: 109 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::time::Duration;
77
use crate::config::ConnectionConfig;
88
use crate::module_bindings::*;
99
use spacetimedb_sdk::{DbContext, Identity, Table as _};
10+
use tokio::sync::oneshot;
1011

1112
pub struct ModuleClient {
1213
conn: DbConnection,
@@ -60,6 +61,55 @@ impl ModuleClient {
6061
})
6162
}
6263

64+
pub async fn connect_async(config: &ConnectionConfig, database_identity: Identity) -> Result<Self> {
65+
let (ready_tx, ready_rx) = oneshot::channel();
66+
let ready_tx = Arc::new(Mutex::new(Some(ready_tx)));
67+
let success_tx = Arc::clone(&ready_tx);
68+
let error_tx = Arc::clone(&ready_tx);
69+
let disconnect_error = Arc::new(Mutex::new(None));
70+
let disconnect_error_callback = Arc::clone(&disconnect_error);
71+
let mut builder = DbConnection::builder()
72+
.with_uri(config.uri.clone())
73+
.with_database_name(database_identity.to_string())
74+
.with_confirmed_reads(config.confirmed_reads)
75+
.on_connect(move |_, _, _| {
76+
if let Some(tx) = success_tx.lock().expect("ready mutex poisoned").take() {
77+
let _ = tx.send(Ok::<(), anyhow::Error>(()));
78+
}
79+
})
80+
.on_connect_error(move |_, error| {
81+
if let Some(tx) = error_tx.lock().expect("ready mutex poisoned").take() {
82+
let _ = tx.send(Err(anyhow!("connection failed: {error}")));
83+
}
84+
})
85+
.on_disconnect(move |_, error| {
86+
let message = match error {
87+
Some(error) => format!("connection closed: {error}"),
88+
None => "connection closed".to_string(),
89+
};
90+
*disconnect_error_callback.lock().expect("disconnect mutex poisoned") = Some(message);
91+
});
92+
93+
if let Some(token) = &config.token {
94+
builder = builder.with_token(Some(token.clone()));
95+
}
96+
97+
let conn = builder.build().context("failed to build database connection")?;
98+
let thread = conn.run_threaded();
99+
tokio::time::timeout(Duration::from_secs(config.timeout_secs), ready_rx)
100+
.await
101+
.context("timed out waiting for connection")?
102+
.map_err(|_| anyhow!("connection readiness callback dropped"))??;
103+
104+
Ok(Self {
105+
conn,
106+
thread: Some(thread),
107+
timeout: Duration::from_secs(config.timeout_secs),
108+
disconnect_error,
109+
load_state_subscription: None,
110+
})
111+
}
112+
63113
pub fn subscribe_load_state(&mut self) -> Result<()> {
64114
if self.load_state_subscription.is_some() {
65115
return Ok(());
@@ -340,28 +390,27 @@ impl ModuleClient {
340390
}
341391
}
342392

343-
pub fn new_order(
393+
pub async fn new_order_async(
344394
&self,
345395
w_id: u32,
346396
d_id: u8,
347397
c_id: u32,
348398
order_lines: Vec<NewOrderLineInput>,
349399
) -> Result<Result<NewOrderResult, String>> {
350-
let (tx, rx) = sync_channel(1);
400+
let (tx, rx) = oneshot::channel();
351401
self.conn
352402
.reducers
353403
.new_order_then(w_id, d_id, c_id, order_lines, move |_, res| {
354404
log::debug!("Got response from `new_order`: {res:?}");
355405
let _ = tx.send(res);
356406
})?;
357-
match rx.recv_timeout(self.timeout) {
358-
Ok(Ok(value)) => Ok(value),
359-
Ok(Err(err)) => Err(anyhow!("new_order internal error: {}", err)),
360-
Err(_) => bail!("timed out waiting for new_order"),
407+
match self.await_callback("new_order", rx).await? {
408+
Ok(value) => Ok(value),
409+
Err(err) => Err(anyhow!("new_order internal error: {}", err)),
361410
}
362411
}
363412

364-
pub fn payment(
413+
pub async fn payment_async(
365414
&self,
366415
w_id: u32,
367416
d_id: u8,
@@ -370,7 +419,7 @@ impl ModuleClient {
370419
customer: CustomerSelector,
371420
payment_amount_cents: i64,
372421
) -> Result<Result<PaymentResult, String>> {
373-
let (tx, rx) = sync_channel(1);
422+
let (tx, rx) = oneshot::channel();
374423
self.conn.reducers.payment_then(
375424
w_id,
376425
d_id,
@@ -383,49 +432,51 @@ impl ModuleClient {
383432
let _ = tx.send(res);
384433
},
385434
)?;
386-
match rx.recv_timeout(self.timeout) {
387-
Ok(Ok(value)) => Ok(value),
388-
Ok(Err(err)) => Err(anyhow!("payment internal error: {}", err)),
389-
Err(_) => bail!("timed out waiting for payment"),
435+
match self.await_callback("payment", rx).await? {
436+
Ok(value) => Ok(value),
437+
Err(err) => Err(anyhow!("payment internal error: {}", err)),
390438
}
391439
}
392440

393-
pub fn order_status(
441+
pub async fn order_status_async(
394442
&self,
395443
w_id: u32,
396444
d_id: u8,
397445
customer: CustomerSelector,
398446
) -> Result<Result<OrderStatusResult, String>> {
399-
let (tx, rx) = sync_channel(1);
447+
let (tx, rx) = oneshot::channel();
400448
self.conn
401449
.reducers
402450
.order_status_then(w_id, d_id, customer, move |_, res| {
403451
log::debug!("Got response from `order_status`: {res:?}");
404452
let _ = tx.send(res);
405453
})?;
406-
match rx.recv_timeout(self.timeout) {
407-
Ok(Ok(value)) => Ok(value),
408-
Ok(Err(err)) => Err(anyhow!("order_status internal error: {}", err)),
409-
Err(_) => bail!("timed out waiting for order_status"),
454+
match self.await_callback("order_status", rx).await? {
455+
Ok(value) => Ok(value),
456+
Err(err) => Err(anyhow!("order_status internal error: {}", err)),
410457
}
411458
}
412459

413-
pub fn stock_level(&self, w_id: u32, d_id: u8, threshold: i32) -> Result<Result<StockLevelResult, String>> {
414-
let (tx, rx) = sync_channel(1);
460+
pub async fn stock_level_async(
461+
&self,
462+
w_id: u32,
463+
d_id: u8,
464+
threshold: i32,
465+
) -> Result<Result<StockLevelResult, String>> {
466+
let (tx, rx) = oneshot::channel();
415467
self.conn
416468
.reducers
417469
.stock_level_then(w_id, d_id, threshold, move |_, res| {
418470
log::debug!("Got response from `stock_level`: {res:?}");
419471
let _ = tx.send(res);
420472
})?;
421-
match rx.recv_timeout(self.timeout) {
422-
Ok(Ok(value)) => Ok(value),
423-
Ok(Err(err)) => Err(anyhow!("stock_level internal error: {}", err)),
424-
Err(_) => bail!("timed out waiting for stock_level"),
473+
match self.await_callback("stock_level", rx).await? {
474+
Ok(value) => Ok(value),
475+
Err(err) => Err(anyhow!("stock_level internal error: {}", err)),
425476
}
426477
}
427478

428-
pub fn queue_delivery(
479+
pub async fn queue_delivery_async(
429480
&self,
430481
run_id: String,
431482
driver_id: String,
@@ -434,7 +485,7 @@ impl ModuleClient {
434485
w_id: u32,
435486
carrier_id: u8,
436487
) -> Result<Result<DeliveryQueueAck, String>> {
437-
let (tx, rx) = sync_channel(1);
488+
let (tx, rx) = oneshot::channel();
438489
self.conn.reducers.queue_delivery_then(
439490
run_id,
440491
driver_id,
@@ -447,50 +498,62 @@ impl ModuleClient {
447498
let _ = tx.send(res);
448499
},
449500
)?;
450-
match rx.recv_timeout(self.timeout) {
451-
Ok(Ok(value)) => Ok(value),
452-
Ok(Err(err)) => Err(anyhow!("queue_delivery internal error: {}", err)),
453-
Err(_) => bail!("timed out waiting for queue_delivery"),
501+
match self.await_callback("queue_delivery", rx).await? {
502+
Ok(value) => Ok(value),
503+
Err(err) => Err(anyhow!("queue_delivery internal error: {}", err)),
454504
}
455505
}
456506

457-
pub fn delivery_progress(&self, run_id: String) -> Result<Result<DeliveryProgress, String>> {
458-
let (tx, rx) = sync_channel(1);
507+
pub async fn delivery_progress_async(&self, run_id: String) -> Result<Result<DeliveryProgress, String>> {
508+
let (tx, rx) = oneshot::channel();
459509
self.conn.reducers.delivery_progress_then(run_id, move |_, res| {
460510
log::debug!("Got response from `delivery_progress`: {res:?}");
461511
let _ = tx.send(res);
462512
})?;
463-
match rx.recv_timeout(self.timeout) {
464-
Ok(Ok(value)) => Ok(value),
465-
Ok(Err(err)) => Err(anyhow!("delivery_progress internal error: {}", err)),
466-
Err(_) => bail!("timed out waiting for delivery_progress"),
513+
match self.await_callback("delivery_progress", rx).await? {
514+
Ok(value) => Ok(value),
515+
Err(err) => Err(anyhow!("delivery_progress internal error: {}", err)),
467516
}
468517
}
469518

470-
pub fn fetch_delivery_completions(
519+
pub fn shutdown(mut self) {
520+
let _ = self.conn.disconnect();
521+
if let Some(thread) = self.thread.take() {
522+
let _ = thread.join();
523+
}
524+
}
525+
526+
pub async fn fetch_delivery_completions_async(
471527
&self,
472528
run_id: String,
473529
after_completion_id: u64,
474530
limit: u32,
475531
) -> Result<Result<Vec<DeliveryCompletionView>, String>> {
476-
let (tx, rx) = sync_channel(1);
532+
let (tx, rx) = oneshot::channel();
477533
self.conn
478534
.reducers
479535
.fetch_delivery_completions_then(run_id, after_completion_id, limit, move |_, res| {
480536
log::debug!("Got response from `fetch_delivery_completions`: {res:?}");
481537
let _ = tx.send(res);
482538
})?;
483-
match rx.recv_timeout(self.timeout) {
484-
Ok(Ok(value)) => Ok(value),
485-
Ok(Err(err)) => Err(anyhow!("fetch_delivery_completions internal error: {}", err)),
486-
Err(_) => bail!("timed out waiting for fetch_delivery_completions"),
539+
match self.await_callback("fetch_delivery_completions", rx).await? {
540+
Ok(value) => Ok(value),
541+
Err(err) => Err(anyhow!("fetch_delivery_completions internal error: {}", err)),
487542
}
488543
}
489544

490-
pub fn shutdown(mut self) {
491-
let _ = self.conn.disconnect();
492-
if let Some(thread) = self.thread.take() {
493-
let _ = thread.join();
545+
pub async fn shutdown_async(self) {
546+
let _ = tokio::task::spawn_blocking(move || self.shutdown()).await;
547+
}
548+
549+
async fn await_callback<T>(&self, operation: &str, rx: oneshot::Receiver<T>) -> Result<T> {
550+
match tokio::time::timeout(self.timeout, rx).await {
551+
Ok(Ok(value)) => Ok(value),
552+
Ok(Err(_)) => Err(anyhow!("{operation} callback dropped")),
553+
Err(_) => {
554+
self.ensure_connected()?;
555+
bail!("timed out waiting for {operation}")
556+
}
494557
}
495558
}
496559
}

0 commit comments

Comments
 (0)