@@ -102,6 +102,7 @@ pub struct AsyncPgConnection {
102
102
stmt_cache : Arc < Mutex < StmtCache < diesel:: pg:: Pg , Statement > > > ,
103
103
transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
104
104
metadata_cache : Arc < Mutex < PgMetadataCache > > ,
105
+ error_receiver : tokio:: sync:: oneshot:: Receiver < diesel:: result:: Error > ,
105
106
}
106
107
107
108
#[ async_trait:: async_trait]
@@ -124,12 +125,19 @@ impl AsyncConnection for AsyncPgConnection {
124
125
let ( client, connection) = tokio_postgres:: connect ( database_url, tokio_postgres:: NoTls )
125
126
. await
126
127
. map_err ( ErrorHelper ) ?;
127
- tokio:: spawn ( async move {
128
- if let Err ( e) = connection. await {
129
- eprintln ! ( "connection error: {e}" ) ;
128
+ // If there is a connection error, we capture it in this channel and make when
129
+ // the user next calls one of the functions on the connection in this trait, we
130
+ // return the error instead of the inner result.
131
+ let ( sender, receiver) = tokio:: sync:: oneshot:: channel ( ) ;
132
+ tokio:: spawn ( async {
133
+ if let Err ( connection_error) = connection. await {
134
+ let connection_error = diesel:: result:: Error :: from ( ErrorHelper ( connection_error) ) ;
135
+ if let Err ( send_error) = sender. send ( connection_error) {
136
+ eprintln ! ( "Failed to send connection error through channel, connection must have been dropped: {}" , send_error) ;
137
+ }
130
138
}
131
139
} ) ;
132
- Self :: try_from ( client) . await
140
+ Self :: try_from_with_error_receiver ( client, receiver ) . await
133
141
}
134
142
135
143
fn load < ' conn , ' query , T > ( & ' conn mut self , source : T ) -> Self :: LoadFuture < ' conn , ' query >
@@ -138,15 +146,19 @@ impl AsyncConnection for AsyncPgConnection {
138
146
T :: Query : QueryFragment < Self :: Backend > + QueryId + ' query ,
139
147
{
140
148
let query = source. as_query ( ) ;
141
- self . with_prepared_statement ( query, |conn, stmt, binds| async move {
149
+ let f = self . with_prepared_statement ( query, |conn, stmt, binds| async move {
142
150
let res = conn. query_raw ( & stmt, binds) . await . map_err ( ErrorHelper ) ?;
143
151
144
152
Ok ( res
145
153
. map_err ( |e| diesel:: result:: Error :: from ( ErrorHelper ( e) ) )
146
154
. map_ok ( PgRow :: new)
147
155
. boxed ( ) )
148
- } )
149
- . boxed ( )
156
+ } ) ;
157
+
158
+ match self . error_receiver . try_recv ( ) {
159
+ Ok ( e) => Box :: pin ( async move { Err ( e) } ) ,
160
+ Err ( _) => f,
161
+ }
150
162
}
151
163
152
164
fn execute_returning_count < ' conn , ' query , T > (
@@ -156,7 +168,7 @@ impl AsyncConnection for AsyncPgConnection {
156
168
where
157
169
T : QueryFragment < Self :: Backend > + QueryId + ' query ,
158
170
{
159
- self . with_prepared_statement ( source, |conn, stmt, binds| async move {
171
+ let f = self . with_prepared_statement ( source, |conn, stmt, binds| async move {
160
172
let binds = binds
161
173
. iter ( )
162
174
. map ( |b| b as & ( dyn ToSql + Sync ) )
@@ -166,8 +178,12 @@ impl AsyncConnection for AsyncPgConnection {
166
178
. await
167
179
. map_err ( ErrorHelper ) ?;
168
180
Ok ( res as usize )
169
- } )
170
- . boxed ( )
181
+ } ) ;
182
+
183
+ match self . error_receiver . try_recv ( ) {
184
+ Ok ( e) => Box :: pin ( async move { Err ( e) } ) ,
185
+ Err ( _) => f,
186
+ }
171
187
}
172
188
173
189
fn transaction_state ( & mut self ) -> & mut AnsiTransactionManager {
@@ -270,11 +286,24 @@ impl AsyncPgConnection {
270
286
271
287
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
272
288
pub async fn try_from ( conn : tokio_postgres:: Client ) -> ConnectionResult < Self > {
289
+ // We create a dummy receiver here. If the user is calling this, they have
290
+ // created their own client and connection and are handling any error in
291
+ // the latter themselves.
292
+ Self :: try_from_with_error_receiver ( conn, tokio:: sync:: oneshot:: channel ( ) . 1 ) . await
293
+ }
294
+
295
+ /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
296
+ /// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
297
+ async fn try_from_with_error_receiver (
298
+ conn : tokio_postgres:: Client ,
299
+ error_receiver : tokio:: sync:: oneshot:: Receiver < diesel:: result:: Error > ,
300
+ ) -> ConnectionResult < Self > {
273
301
let mut conn = Self {
274
302
conn : Arc :: new ( conn) ,
275
303
stmt_cache : Arc :: new ( Mutex :: new ( StmtCache :: new ( ) ) ) ,
276
304
transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
277
305
metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
306
+ error_receiver,
278
307
} ;
279
308
conn. set_config_options ( )
280
309
. await
@@ -340,7 +369,7 @@ impl AsyncPgConnection {
340
369
let metadata_cache = self . metadata_cache . clone ( ) ;
341
370
let tm = self . transaction_state . clone ( ) ;
342
371
343
- async move {
372
+ let f = async move {
344
373
let sql = sql?;
345
374
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
346
375
collect_bind_result?;
@@ -411,8 +440,12 @@ impl AsyncPgConnection {
411
440
let res = callback ( raw_connection, stmt. clone ( ) , binds) . await ;
412
441
let mut tm = tm. lock ( ) . await ;
413
442
update_transaction_manager_status ( res, & mut tm)
443
+ } ;
444
+
445
+ match self . error_receiver . try_recv ( ) {
446
+ Ok ( e) => Box :: pin ( async move { Err ( e) } ) ,
447
+ Err ( _) => f. boxed ( ) ,
414
448
}
415
- . boxed ( )
416
449
}
417
450
}
418
451
0 commit comments