Skip to content

Commit 7df4044

Browse files
committed
Share db connections among terminals
1 parent 8b89b81 commit 7df4044

1 file changed

Lines changed: 77 additions & 95 deletions

File tree

tools/tpcc-runner/src/driver.rs

Lines changed: 77 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use anyhow::{anyhow, bail, Context, Result};
22
use rand::{rngs::StdRng, Rng, SeedableRng};
3+
use std::collections::BTreeMap;
34
use std::fs;
45
use std::path::PathBuf;
56
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
@@ -10,6 +11,7 @@ use tokio::task::JoinSet;
1011
use crate::client::{expect_ok, ModuleClient};
1112
use crate::config::{default_run_id, DriverConfig};
1213
use crate::metrics_module_bindings::register_completed_order;
14+
use crate::metrics_module_bindings::DbConnection as MetricsDbConnection;
1315
use crate::metrics_module_client::connect_metrics_module_async;
1416
use crate::module_bindings::*;
1517
use crate::protocol::{
@@ -23,10 +25,11 @@ use crate::tpcc::*;
2325

2426
struct TerminalRuntime {
2527
config: DriverConfig,
28+
client: Arc<ModuleClient>,
2629
metrics: SharedMetrics,
30+
metrics_client: Arc<MetricsDbConnection>,
2731
abort: Arc<AtomicBool>,
2832
start_logged: Arc<AtomicBool>,
29-
startup_progress: Arc<StartupProgress>,
3033
request_ids: Arc<AtomicU64>,
3134
schedule: RunSchedule,
3235
run_constants: RunConstants,
@@ -35,24 +38,6 @@ struct TerminalRuntime {
3538
seed: u64,
3639
}
3740

38-
struct StartupProgress {
39-
primary_connect_started: AtomicU64,
40-
primary_connect_ready: AtomicU64,
41-
metrics_connect_started: AtomicU64,
42-
metrics_connect_ready: AtomicU64,
43-
}
44-
45-
impl StartupProgress {
46-
fn new() -> Self {
47-
Self {
48-
primary_connect_started: AtomicU64::new(0),
49-
primary_connect_ready: AtomicU64::new(0),
50-
metrics_connect_started: AtomicU64::new(0),
51-
metrics_connect_ready: AtomicU64::new(0),
52-
}
53-
}
54-
}
55-
5641
struct TransactionContext<'a> {
5742
client: &'a ModuleClient,
5843
config: &'a DriverConfig,
@@ -83,10 +68,10 @@ pub async fn run(config: DriverConfig) -> Result<()> {
8368

8469
let abort = Arc::new(AtomicBool::new(false));
8570
let start_logged = Arc::new(AtomicBool::new(false));
86-
let startup_progress = Arc::new(StartupProgress::new());
8771
let request_ids = Arc::new(AtomicU64::new(1));
8872
let mut tasks = JoinSet::new();
89-
let terminal_count = u64::from(config.terminals());
73+
let metrics_client = Arc::new(connect_metrics_module_async(&config.connection).await?);
74+
let shared_database_clients = connect_shared_database_clients(&config, &topology, &used_database_numbers).await?;
9075

9176
log::info!(
9277
"driver {} ready for run {}: warehouses {}..={} terminals={} warmup_start_ms={} measure_start_ms={} measure_end_ms={}",
@@ -100,57 +85,21 @@ pub async fn run(config: DriverConfig) -> Result<()> {
10085
schedule.measure_end_ms
10186
);
10287
log::info!(
103-
"driver {} launching {} terminal task(s); this run will attempt {} primary database connections and {} metrics connections",
88+
"driver {} shared metrics connection ready; launching {} terminal task(s) across {} shared database connection(s)",
10489
config.driver_id,
105-
terminal_count,
106-
terminal_count,
107-
terminal_count
90+
config.terminals(),
91+
shared_database_clients.len()
10892
);
10993

110-
{
111-
let reporter_driver_id = config.driver_id.clone();
112-
let reporter_run_id = run_id.clone();
113-
let reporter_start_logged = start_logged.clone();
114-
let reporter_abort = abort.clone();
115-
let reporter_progress = startup_progress.clone();
116-
tokio::spawn(async move {
117-
loop {
118-
if reporter_abort.load(Ordering::Relaxed) || reporter_start_logged.load(Ordering::Relaxed) {
119-
break;
120-
}
121-
122-
tokio::time::sleep(Duration::from_secs(1)).await;
123-
if reporter_abort.load(Ordering::Relaxed) || reporter_start_logged.load(Ordering::Relaxed) {
124-
break;
125-
}
126-
127-
let primary_started = reporter_progress.primary_connect_started.load(Ordering::Relaxed);
128-
let primary_ready = reporter_progress.primary_connect_ready.load(Ordering::Relaxed);
129-
let metrics_started = reporter_progress.metrics_connect_started.load(Ordering::Relaxed);
130-
let metrics_ready = reporter_progress.metrics_connect_ready.load(Ordering::Relaxed);
131-
log::info!(
132-
"driver {} startup progress for run {}: primary_connect_started={}/{} primary_connected={}/{} metrics_connect_started={}/{} metrics_connected={}/{}",
133-
reporter_driver_id,
134-
reporter_run_id,
135-
primary_started,
136-
terminal_count,
137-
primary_ready,
138-
terminal_count,
139-
metrics_started,
140-
terminal_count,
141-
metrics_ready,
142-
terminal_count
143-
);
144-
145-
if metrics_ready >= terminal_count {
146-
break;
147-
}
148-
}
149-
});
150-
}
151-
15294
for warehouse_id in config.warehouse_start..=config.warehouse_end() {
95+
let database_number = topology.database_number_for_warehouse(warehouse_id)?;
15396
let database_identity = topology.identity_for_warehouse(warehouse_id)?;
97+
let client = shared_database_clients.get(&database_number).cloned().ok_or_else(|| {
98+
anyhow!(
99+
"missing shared database client for {}",
100+
topology.database_name(database_number)
101+
)
102+
})?;
154103
for district_id in 1..=DISTRICTS_PER_WAREHOUSE {
155104
let assignment = TerminalAssignment {
156105
terminal_id: terminal_id(warehouse_id, district_id),
@@ -167,10 +116,11 @@ pub async fn run(config: DriverConfig) -> Result<()> {
167116
let terminal_request_ids = request_ids.clone();
168117
let runtime = TerminalRuntime {
169118
config: terminal_config,
119+
client: client.clone(),
170120
metrics: terminal_metrics,
121+
metrics_client: metrics_client.clone(),
171122
abort: terminal_abort,
172123
start_logged: terminal_start_logged,
173-
startup_progress: startup_progress.clone(),
174124
request_ids: terminal_request_ids,
175125
schedule: terminal_schedule,
176126
run_constants: terminal_constants,
@@ -214,10 +164,12 @@ pub async fn run(config: DriverConfig) -> Result<()> {
214164
}
215165
}
216166
if let Some(err) = first_error {
167+
shutdown_shared_database_clients(shared_database_clients).await;
217168
return Err(err);
218169
}
219170

220-
harvest_delivery_completions(&config, &schedule, &metrics, &topology, &used_database_numbers).await?;
171+
harvest_delivery_completions(&config, &schedule, &metrics, &shared_database_clients).await?;
172+
shutdown_shared_database_clients(shared_database_clients).await;
221173

222174
let summary = metrics.finalize(DriverSummaryMeta {
223175
run_id: run_id.clone(),
@@ -245,23 +197,18 @@ pub async fn run(config: DriverConfig) -> Result<()> {
245197
async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
246198
let TerminalRuntime {
247199
config,
200+
client,
248201
metrics,
202+
metrics_client,
249203
abort,
250204
start_logged,
251-
startup_progress,
252205
request_ids,
253206
schedule,
254207
run_constants,
255208
assignment,
256209
database_identity,
257210
seed,
258211
} = runtime;
259-
startup_progress.primary_connect_started.fetch_add(1, Ordering::Relaxed);
260-
let client = ModuleClient::connect_async(&config.connection, database_identity).await?;
261-
startup_progress.primary_connect_ready.fetch_add(1, Ordering::Relaxed);
262-
startup_progress.metrics_connect_started.fetch_add(1, Ordering::Relaxed);
263-
let metrics_client = connect_metrics_module_async(&config.connection).await?;
264-
startup_progress.metrics_connect_ready.fetch_add(1, Ordering::Relaxed);
265212
log::info!(
266213
"driver {} terminal {} connected to {} for warehouse {} district {}",
267214
config.driver_id,
@@ -289,7 +236,7 @@ async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
289236
let kind = choose_transaction(&mut rng);
290237
let started_ms = crate::summary::now_millis();
291238
let context = TransactionContext {
292-
client: &client,
239+
client: client.as_ref(),
293240
config: &config,
294241
run_id: &schedule.run_id,
295242
driver_id: &config.driver_id,
@@ -313,7 +260,6 @@ async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
313260
}
314261
Err(err) => {
315262
abort.store(true, Ordering::Relaxed);
316-
client.shutdown_async().await;
317263
return Err(err);
318264
}
319265
}
@@ -323,8 +269,6 @@ async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
323269
tokio::time::sleep(delay).await;
324270
}
325271
}
326-
327-
client.shutdown_async().await;
328272
Ok(())
329273
}
330274

@@ -718,26 +662,13 @@ async fn harvest_delivery_completions(
718662
config: &DriverConfig,
719663
schedule: &RunSchedule,
720664
metrics: &SharedMetrics,
721-
topology: &DatabaseTopology,
722-
used_database_numbers: &[u32],
665+
shared_database_clients: &BTreeMap<u32, Arc<ModuleClient>>,
723666
) -> Result<()> {
724667
let expected = metrics.delivery_queued();
725668
if expected == 0 {
726669
return Ok(());
727670
}
728-
let mut harvest_clients = Vec::with_capacity(used_database_numbers.len());
729-
for database_number in used_database_numbers {
730-
let database_identity = topology.identity_for_database_number(*database_number)?;
731-
let client = ModuleClient::connect_async(&config.connection, database_identity)
732-
.await
733-
.with_context(|| {
734-
format!(
735-
"failed to connect delivery harvester to {}",
736-
topology.database_name(*database_number)
737-
)
738-
})?;
739-
harvest_clients.push((*database_number, client));
740-
}
671+
let harvest_clients: Vec<_> = shared_database_clients.iter().collect();
741672

742673
let mut pending_jobs = 0u64;
743674
let mut completed_jobs = 0u64;
@@ -849,3 +780,54 @@ async fn sleep_until_ms_async(target_ms: u64) {
849780
tokio::time::sleep(Duration::from_millis(target_ms - now_ms)).await;
850781
}
851782
}
783+
784+
async fn connect_shared_database_clients(
785+
config: &DriverConfig,
786+
topology: &DatabaseTopology,
787+
used_database_numbers: &[u32],
788+
) -> Result<BTreeMap<u32, Arc<ModuleClient>>> {
789+
let mut connect_tasks = JoinSet::new();
790+
for database_number in used_database_numbers {
791+
let database_number = *database_number;
792+
let database_identity = topology.identity_for_database_number(database_number)?;
793+
let database_name = topology.database_name(database_number);
794+
let connection = config.connection.clone();
795+
connect_tasks.spawn(async move {
796+
let client = ModuleClient::connect_async(&connection, database_identity)
797+
.await
798+
.with_context(|| format!("failed to connect shared client to {database_name}"))?;
799+
Ok::<_, anyhow::Error>((database_number, database_name, Arc::new(client)))
800+
});
801+
}
802+
803+
let mut shared_clients = BTreeMap::new();
804+
while let Some(result) = connect_tasks.join_next().await {
805+
match result {
806+
Ok(Ok((database_number, database_name, client))) => {
807+
log::info!(
808+
"driver {} shared database client connected to {}",
809+
config.driver_id,
810+
database_name
811+
);
812+
shared_clients.insert(database_number, client);
813+
}
814+
Ok(Err(err)) => {
815+
connect_tasks.abort_all();
816+
return Err(err);
817+
}
818+
Err(err) => {
819+
connect_tasks.abort_all();
820+
return Err(anyhow!("shared database connection task failed: {}", err));
821+
}
822+
}
823+
}
824+
Ok(shared_clients)
825+
}
826+
827+
async fn shutdown_shared_database_clients(shared_database_clients: BTreeMap<u32, Arc<ModuleClient>>) {
828+
for (_, client) in shared_database_clients {
829+
if let Some(client) = Arc::into_inner(client) {
830+
client.shutdown_async().await;
831+
}
832+
}
833+
}

0 commit comments

Comments
 (0)