@@ -9,7 +9,15 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage};
99/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
1010pub struct Transaction < ' a > {
1111 connection : ConnectionRef < ' a > ,
12- transaction : tokio_postgres:: Transaction < ' a > ,
12+ transaction : Option < tokio_postgres:: Transaction < ' a > > ,
13+ }
14+
15+ impl < ' a > Drop for Transaction < ' a > {
16+ fn drop ( & mut self ) {
17+ if let Some ( transaction) = self . transaction . take ( ) {
18+ let _ = self . connection . block_on ( transaction. rollback ( ) ) ;
19+ }
20+ }
1321}
1422
1523impl < ' a > Transaction < ' a > {
@@ -19,31 +27,38 @@ impl<'a> Transaction<'a> {
1927 ) -> Transaction < ' a > {
2028 Transaction {
2129 connection,
22- transaction,
30+ transaction : Some ( transaction ) ,
2331 }
2432 }
2533
2634 /// Consumes the transaction, committing all changes made within it.
2735 pub fn commit ( mut self ) -> Result < ( ) , Error > {
28- self . connection . block_on ( self . transaction . commit ( ) )
36+ self . connection
37+ . block_on ( self . transaction . take ( ) . unwrap ( ) . commit ( ) )
2938 }
3039
3140 /// Rolls the transaction back, discarding all changes made within it.
3241 ///
3342 /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
3443 pub fn rollback ( mut self ) -> Result < ( ) , Error > {
35- self . connection . block_on ( self . transaction . rollback ( ) )
44+ self . connection
45+ . block_on ( self . transaction . take ( ) . unwrap ( ) . rollback ( ) )
3646 }
3747
3848 /// Like `Client::prepare`.
3949 pub fn prepare ( & mut self , query : & str ) -> Result < Statement , Error > {
40- self . connection . block_on ( self . transaction . prepare ( query) )
50+ self . connection
51+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . prepare ( query) )
4152 }
4253
4354 /// Like `Client::prepare_typed`.
4455 pub fn prepare_typed ( & mut self , query : & str , types : & [ Type ] ) -> Result < Statement , Error > {
45- self . connection
46- . block_on ( self . transaction . prepare_typed ( query, types) )
56+ self . connection . block_on (
57+ self . transaction
58+ . as_ref ( )
59+ . unwrap ( )
60+ . prepare_typed ( query, types) ,
61+ )
4762 }
4863
4964 /// Like `Client::execute`.
@@ -52,7 +67,7 @@ impl<'a> Transaction<'a> {
5267 T : ?Sized + ToStatement ,
5368 {
5469 self . connection
55- . block_on ( self . transaction . execute ( query, params) )
70+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . execute ( query, params) )
5671 }
5772
5873 /// Like `Client::query`.
@@ -61,7 +76,7 @@ impl<'a> Transaction<'a> {
6176 T : ?Sized + ToStatement ,
6277 {
6378 self . connection
64- . block_on ( self . transaction . query ( query, params) )
79+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . query ( query, params) )
6580 }
6681
6782 /// Like `Client::query_one`.
@@ -70,7 +85,7 @@ impl<'a> Transaction<'a> {
7085 T : ?Sized + ToStatement ,
7186 {
7287 self . connection
73- . block_on ( self . transaction . query_one ( query, params) )
88+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . query_one ( query, params) )
7489 }
7590
7691 /// Like `Client::query_opt`.
@@ -83,7 +98,7 @@ impl<'a> Transaction<'a> {
8398 T : ?Sized + ToStatement ,
8499 {
85100 self . connection
86- . block_on ( self . transaction . query_opt ( query, params) )
101+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . query_opt ( query, params) )
87102 }
88103
89104 /// Like `Client::query_raw`.
@@ -95,7 +110,7 @@ impl<'a> Transaction<'a> {
95110 {
96111 let stream = self
97112 . connection
98- . block_on ( self . transaction . query_raw ( query, params) ) ?;
113+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . query_raw ( query, params) ) ?;
99114 Ok ( RowIter :: new ( self . connection . as_ref ( ) , stream) )
100115 }
101116
@@ -114,16 +129,20 @@ impl<'a> Transaction<'a> {
114129 T : ?Sized + ToStatement ,
115130 {
116131 self . connection
117- . block_on ( self . transaction . bind ( query, params) )
132+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . bind ( query, params) )
118133 }
119134
120135 /// Continues execution of a portal, returning the next set of rows.
121136 ///
122137 /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
123138 /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
124139 pub fn query_portal ( & mut self , portal : & Portal , max_rows : i32 ) -> Result < Vec < Row > , Error > {
125- self . connection
126- . block_on ( self . transaction . query_portal ( portal, max_rows) )
140+ self . connection . block_on (
141+ self . transaction
142+ . as_ref ( )
143+ . unwrap ( )
144+ . query_portal ( portal, max_rows) ,
145+ )
127146 }
128147
129148 /// The maximally flexible version of `query_portal`.
@@ -132,9 +151,12 @@ impl<'a> Transaction<'a> {
132151 portal : & Portal ,
133152 max_rows : i32 ,
134153 ) -> Result < RowIter < ' _ > , Error > {
135- let stream = self
136- . connection
137- . block_on ( self . transaction . query_portal_raw ( portal, max_rows) ) ?;
154+ let stream = self . connection . block_on (
155+ self . transaction
156+ . as_ref ( )
157+ . unwrap ( )
158+ . query_portal_raw ( portal, max_rows) ,
159+ ) ?;
138160 Ok ( RowIter :: new ( self . connection . as_ref ( ) , stream) )
139161 }
140162
@@ -143,7 +165,9 @@ impl<'a> Transaction<'a> {
143165 where
144166 T : ?Sized + ToStatement ,
145167 {
146- let sink = self . connection . block_on ( self . transaction . copy_in ( query) ) ?;
168+ let sink = self
169+ . connection
170+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . copy_in ( query) ) ?;
147171 Ok ( CopyInWriter :: new ( self . connection . as_ref ( ) , sink) )
148172 }
149173
@@ -152,44 +176,45 @@ impl<'a> Transaction<'a> {
152176 where
153177 T : ?Sized + ToStatement ,
154178 {
155- let stream = self . connection . block_on ( self . transaction . copy_out ( query) ) ?;
179+ let stream = self
180+ . connection
181+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . copy_out ( query) ) ?;
156182 Ok ( CopyOutReader :: new ( self . connection . as_ref ( ) , stream) )
157183 }
158184
159185 /// Like `Client::simple_query`.
160186 pub fn simple_query ( & mut self , query : & str ) -> Result < Vec < SimpleQueryMessage > , Error > {
161187 self . connection
162- . block_on ( self . transaction . simple_query ( query) )
188+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . simple_query ( query) )
163189 }
164190
165191 /// Like `Client::batch_execute`.
166192 pub fn batch_execute ( & mut self , query : & str ) -> Result < ( ) , Error > {
167193 self . connection
168- . block_on ( self . transaction . batch_execute ( query) )
194+ . block_on ( self . transaction . as_ref ( ) . unwrap ( ) . batch_execute ( query) )
169195 }
170196
171197 /// Like `Client::cancel_token`.
172198 pub fn cancel_token ( & self ) -> CancelToken {
173- CancelToken :: new ( self . transaction . cancel_token ( ) )
199+ CancelToken :: new ( self . transaction . as_ref ( ) . unwrap ( ) . cancel_token ( ) )
174200 }
175201
176202 /// Like `Client::transaction`, but creates a nested transaction via a savepoint.
177203 pub fn transaction ( & mut self ) -> Result < Transaction < ' _ > , Error > {
178- let transaction = self . connection . block_on ( self . transaction . transaction ( ) ) ?;
179- Ok ( Transaction {
180- connection : self . connection . as_ref ( ) ,
181- transaction,
182- } )
204+ let transaction = self
205+ . connection
206+ . block_on ( self . transaction . as_mut ( ) . unwrap ( ) . transaction ( ) ) ?;
207+ Ok ( Transaction :: new ( self . connection . as_ref ( ) , transaction) )
183208 }
209+
184210 /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
185211 pub fn savepoint < I > ( & mut self , name : I ) -> Result < Transaction < ' _ > , Error >
186212 where
187213 I : Into < String > ,
188214 {
189- let transaction = self . connection . block_on ( self . transaction . savepoint ( name) ) ?;
190- Ok ( Transaction {
191- connection : self . connection . as_ref ( ) ,
192- transaction,
193- } )
215+ let transaction = self
216+ . connection
217+ . block_on ( self . transaction . as_mut ( ) . unwrap ( ) . savepoint ( name) ) ?;
218+ Ok ( Transaction :: new ( self . connection . as_ref ( ) , transaction) )
194219 }
195220}
0 commit comments