11use anyhow:: { anyhow, bail, Context , Result } ;
22use rand:: { rngs:: StdRng , Rng , SeedableRng } ;
3+ use std:: collections:: BTreeMap ;
34use std:: fs;
45use std:: path:: PathBuf ;
56use std:: sync:: atomic:: { AtomicBool , AtomicU64 , Ordering } ;
@@ -10,6 +11,7 @@ use tokio::task::JoinSet;
1011use crate :: client:: { expect_ok, ModuleClient } ;
1112use crate :: config:: { default_run_id, DriverConfig } ;
1213use crate :: metrics_module_bindings:: register_completed_order;
14+ use crate :: metrics_module_bindings:: DbConnection as MetricsDbConnection ;
1315use crate :: metrics_module_client:: connect_metrics_module_async;
1416use crate :: module_bindings:: * ;
1517use crate :: protocol:: {
@@ -23,10 +25,11 @@ use crate::tpcc::*;
2325
2426struct TerminalRuntime {
2527 config : DriverConfig ,
28+ client : Arc < ModuleClient > ,
2629 metrics : SharedMetrics ,
30+ metrics_client : Arc < MetricsDbConnection > ,
2731 abort : Arc < AtomicBool > ,
2832 start_logged : Arc < AtomicBool > ,
29- startup_progress : Arc < StartupProgress > ,
3033 request_ids : Arc < AtomicU64 > ,
3134 schedule : RunSchedule ,
3235 run_constants : RunConstants ,
@@ -35,24 +38,6 @@ struct TerminalRuntime {
3538 seed : u64 ,
3639}
3740
38- struct StartupProgress {
39- primary_connect_started : AtomicU64 ,
40- primary_connect_ready : AtomicU64 ,
41- metrics_connect_started : AtomicU64 ,
42- metrics_connect_ready : AtomicU64 ,
43- }
44-
45- impl StartupProgress {
46- fn new ( ) -> Self {
47- Self {
48- primary_connect_started : AtomicU64 :: new ( 0 ) ,
49- primary_connect_ready : AtomicU64 :: new ( 0 ) ,
50- metrics_connect_started : AtomicU64 :: new ( 0 ) ,
51- metrics_connect_ready : AtomicU64 :: new ( 0 ) ,
52- }
53- }
54- }
55-
5641struct TransactionContext < ' a > {
5742 client : & ' a ModuleClient ,
5843 config : & ' a DriverConfig ,
@@ -83,10 +68,10 @@ pub async fn run(config: DriverConfig) -> Result<()> {
8368
8469 let abort = Arc :: new ( AtomicBool :: new ( false ) ) ;
8570 let start_logged = Arc :: new ( AtomicBool :: new ( false ) ) ;
86- let startup_progress = Arc :: new ( StartupProgress :: new ( ) ) ;
8771 let request_ids = Arc :: new ( AtomicU64 :: new ( 1 ) ) ;
8872 let mut tasks = JoinSet :: new ( ) ;
89- let terminal_count = u64:: from ( config. terminals ( ) ) ;
73+ let metrics_client = Arc :: new ( connect_metrics_module_async ( & config. connection ) . await ?) ;
74+ let shared_database_clients = connect_shared_database_clients ( & config, & topology, & used_database_numbers) . await ?;
9075
9176 log:: info!(
9277 "driver {} ready for run {}: warehouses {}..={} terminals={} warmup_start_ms={} measure_start_ms={} measure_end_ms={}" ,
@@ -100,57 +85,21 @@ pub async fn run(config: DriverConfig) -> Result<()> {
10085 schedule. measure_end_ms
10186 ) ;
10287 log:: info!(
103- "driver {} launching {} terminal task(s); this run will attempt {} primary database connections and {} metrics connections " ,
88+ "driver {} shared metrics connection ready; launching {} terminal task(s) across {} shared database connection(s) " ,
10489 config. driver_id,
105- terminal_count,
106- terminal_count,
107- terminal_count
90+ config. terminals( ) ,
91+ shared_database_clients. len( )
10892 ) ;
10993
110- {
111- let reporter_driver_id = config. driver_id . clone ( ) ;
112- let reporter_run_id = run_id. clone ( ) ;
113- let reporter_start_logged = start_logged. clone ( ) ;
114- let reporter_abort = abort. clone ( ) ;
115- let reporter_progress = startup_progress. clone ( ) ;
116- tokio:: spawn ( async move {
117- loop {
118- if reporter_abort. load ( Ordering :: Relaxed ) || reporter_start_logged. load ( Ordering :: Relaxed ) {
119- break ;
120- }
121-
122- tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
123- if reporter_abort. load ( Ordering :: Relaxed ) || reporter_start_logged. load ( Ordering :: Relaxed ) {
124- break ;
125- }
126-
127- let primary_started = reporter_progress. primary_connect_started . load ( Ordering :: Relaxed ) ;
128- let primary_ready = reporter_progress. primary_connect_ready . load ( Ordering :: Relaxed ) ;
129- let metrics_started = reporter_progress. metrics_connect_started . load ( Ordering :: Relaxed ) ;
130- let metrics_ready = reporter_progress. metrics_connect_ready . load ( Ordering :: Relaxed ) ;
131- log:: info!(
132- "driver {} startup progress for run {}: primary_connect_started={}/{} primary_connected={}/{} metrics_connect_started={}/{} metrics_connected={}/{}" ,
133- reporter_driver_id,
134- reporter_run_id,
135- primary_started,
136- terminal_count,
137- primary_ready,
138- terminal_count,
139- metrics_started,
140- terminal_count,
141- metrics_ready,
142- terminal_count
143- ) ;
144-
145- if metrics_ready >= terminal_count {
146- break ;
147- }
148- }
149- } ) ;
150- }
151-
15294 for warehouse_id in config. warehouse_start ..=config. warehouse_end ( ) {
95+ let database_number = topology. database_number_for_warehouse ( warehouse_id) ?;
15396 let database_identity = topology. identity_for_warehouse ( warehouse_id) ?;
97+ let client = shared_database_clients. get ( & database_number) . cloned ( ) . ok_or_else ( || {
98+ anyhow ! (
99+ "missing shared database client for {}" ,
100+ topology. database_name( database_number)
101+ )
102+ } ) ?;
154103 for district_id in 1 ..=DISTRICTS_PER_WAREHOUSE {
155104 let assignment = TerminalAssignment {
156105 terminal_id : terminal_id ( warehouse_id, district_id) ,
@@ -167,10 +116,11 @@ pub async fn run(config: DriverConfig) -> Result<()> {
167116 let terminal_request_ids = request_ids. clone ( ) ;
168117 let runtime = TerminalRuntime {
169118 config : terminal_config,
119+ client : client. clone ( ) ,
170120 metrics : terminal_metrics,
121+ metrics_client : metrics_client. clone ( ) ,
171122 abort : terminal_abort,
172123 start_logged : terminal_start_logged,
173- startup_progress : startup_progress. clone ( ) ,
174124 request_ids : terminal_request_ids,
175125 schedule : terminal_schedule,
176126 run_constants : terminal_constants,
@@ -214,10 +164,12 @@ pub async fn run(config: DriverConfig) -> Result<()> {
214164 }
215165 }
216166 if let Some ( err) = first_error {
167+ shutdown_shared_database_clients ( shared_database_clients) . await ;
217168 return Err ( err) ;
218169 }
219170
220- harvest_delivery_completions ( & config, & schedule, & metrics, & topology, & used_database_numbers) . await ?;
171+ harvest_delivery_completions ( & config, & schedule, & metrics, & shared_database_clients) . await ?;
172+ shutdown_shared_database_clients ( shared_database_clients) . await ;
221173
222174 let summary = metrics. finalize ( DriverSummaryMeta {
223175 run_id : run_id. clone ( ) ,
@@ -245,23 +197,18 @@ pub async fn run(config: DriverConfig) -> Result<()> {
245197async fn run_terminal ( runtime : TerminalRuntime ) -> Result < ( ) > {
246198 let TerminalRuntime {
247199 config,
200+ client,
248201 metrics,
202+ metrics_client,
249203 abort,
250204 start_logged,
251- startup_progress,
252205 request_ids,
253206 schedule,
254207 run_constants,
255208 assignment,
256209 database_identity,
257210 seed,
258211 } = runtime;
259- startup_progress. primary_connect_started . fetch_add ( 1 , Ordering :: Relaxed ) ;
260- let client = ModuleClient :: connect_async ( & config. connection , database_identity) . await ?;
261- startup_progress. primary_connect_ready . fetch_add ( 1 , Ordering :: Relaxed ) ;
262- startup_progress. metrics_connect_started . fetch_add ( 1 , Ordering :: Relaxed ) ;
263- let metrics_client = connect_metrics_module_async ( & config. connection ) . await ?;
264- startup_progress. metrics_connect_ready . fetch_add ( 1 , Ordering :: Relaxed ) ;
265212 log:: info!(
266213 "driver {} terminal {} connected to {} for warehouse {} district {}" ,
267214 config. driver_id,
@@ -289,7 +236,7 @@ async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
289236 let kind = choose_transaction ( & mut rng) ;
290237 let started_ms = crate :: summary:: now_millis ( ) ;
291238 let context = TransactionContext {
292- client : & client,
239+ client : client. as_ref ( ) ,
293240 config : & config,
294241 run_id : & schedule. run_id ,
295242 driver_id : & config. driver_id ,
@@ -313,7 +260,6 @@ async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
313260 }
314261 Err ( err) => {
315262 abort. store ( true , Ordering :: Relaxed ) ;
316- client. shutdown_async ( ) . await ;
317263 return Err ( err) ;
318264 }
319265 }
@@ -323,8 +269,6 @@ async fn run_terminal(runtime: TerminalRuntime) -> Result<()> {
323269 tokio:: time:: sleep ( delay) . await ;
324270 }
325271 }
326-
327- client. shutdown_async ( ) . await ;
328272 Ok ( ( ) )
329273}
330274
@@ -718,26 +662,13 @@ async fn harvest_delivery_completions(
718662 config : & DriverConfig ,
719663 schedule : & RunSchedule ,
720664 metrics : & SharedMetrics ,
721- topology : & DatabaseTopology ,
722- used_database_numbers : & [ u32 ] ,
665+ shared_database_clients : & BTreeMap < u32 , Arc < ModuleClient > > ,
723666) -> Result < ( ) > {
724667 let expected = metrics. delivery_queued ( ) ;
725668 if expected == 0 {
726669 return Ok ( ( ) ) ;
727670 }
728- let mut harvest_clients = Vec :: with_capacity ( used_database_numbers. len ( ) ) ;
729- for database_number in used_database_numbers {
730- let database_identity = topology. identity_for_database_number ( * database_number) ?;
731- let client = ModuleClient :: connect_async ( & config. connection , database_identity)
732- . await
733- . with_context ( || {
734- format ! (
735- "failed to connect delivery harvester to {}" ,
736- topology. database_name( * database_number)
737- )
738- } ) ?;
739- harvest_clients. push ( ( * database_number, client) ) ;
740- }
671+ let harvest_clients: Vec < _ > = shared_database_clients. iter ( ) . collect ( ) ;
741672
742673 let mut pending_jobs = 0u64 ;
743674 let mut completed_jobs = 0u64 ;
@@ -849,3 +780,54 @@ async fn sleep_until_ms_async(target_ms: u64) {
849780 tokio:: time:: sleep ( Duration :: from_millis ( target_ms - now_ms) ) . await ;
850781 }
851782}
783+
784+ async fn connect_shared_database_clients (
785+ config : & DriverConfig ,
786+ topology : & DatabaseTopology ,
787+ used_database_numbers : & [ u32 ] ,
788+ ) -> Result < BTreeMap < u32 , Arc < ModuleClient > > > {
789+ let mut connect_tasks = JoinSet :: new ( ) ;
790+ for database_number in used_database_numbers {
791+ let database_number = * database_number;
792+ let database_identity = topology. identity_for_database_number ( database_number) ?;
793+ let database_name = topology. database_name ( database_number) ;
794+ let connection = config. connection . clone ( ) ;
795+ connect_tasks. spawn ( async move {
796+ let client = ModuleClient :: connect_async ( & connection, database_identity)
797+ . await
798+ . with_context ( || format ! ( "failed to connect shared client to {database_name}" ) ) ?;
799+ Ok :: < _ , anyhow:: Error > ( ( database_number, database_name, Arc :: new ( client) ) )
800+ } ) ;
801+ }
802+
803+ let mut shared_clients = BTreeMap :: new ( ) ;
804+ while let Some ( result) = connect_tasks. join_next ( ) . await {
805+ match result {
806+ Ok ( Ok ( ( database_number, database_name, client) ) ) => {
807+ log:: info!(
808+ "driver {} shared database client connected to {}" ,
809+ config. driver_id,
810+ database_name
811+ ) ;
812+ shared_clients. insert ( database_number, client) ;
813+ }
814+ Ok ( Err ( err) ) => {
815+ connect_tasks. abort_all ( ) ;
816+ return Err ( err) ;
817+ }
818+ Err ( err) => {
819+ connect_tasks. abort_all ( ) ;
820+ return Err ( anyhow ! ( "shared database connection task failed: {}" , err) ) ;
821+ }
822+ }
823+ }
824+ Ok ( shared_clients)
825+ }
826+
827+ async fn shutdown_shared_database_clients ( shared_database_clients : BTreeMap < u32 , Arc < ModuleClient > > ) {
828+ for ( _, client) in shared_database_clients {
829+ if let Some ( client) = Arc :: into_inner ( client) {
830+ client. shutdown_async ( ) . await ;
831+ }
832+ }
833+ }
0 commit comments