Skip to content

Commit 4438e23

Browse files
committed
Expose underlying connection errors to user of AsyncPgConnection
1 parent 1e18b37 commit 4438e23

File tree

1 file changed

+45
-12
lines changed

1 file changed

+45
-12
lines changed

src/pg/mod.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ pub struct AsyncPgConnection {
102102
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
103103
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
104104
metadata_cache: Arc<Mutex<PgMetadataCache>>,
105+
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
105106
}
106107

107108
#[async_trait::async_trait]
@@ -124,12 +125,19 @@ impl AsyncConnection for AsyncPgConnection {
124125
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
125126
.await
126127
.map_err(ErrorHelper)?;
127-
tokio::spawn(async move {
128-
if let Err(e) = connection.await {
129-
eprintln!("connection error: {e}");
128+
// If there is a connection error, we capture it in this channel and make when
129+
// the user next calls one of the functions on the connection in this trait, we
130+
// return the error instead of the inner result.
131+
let (sender, receiver) = tokio::sync::oneshot::channel();
132+
tokio::spawn(async {
133+
if let Err(connection_error) = connection.await {
134+
let connection_error = diesel::result::Error::from(ErrorHelper(connection_error));
135+
if let Err(send_error) = sender.send(connection_error) {
136+
eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error);
137+
}
130138
}
131139
});
132-
Self::try_from(client).await
140+
Self::try_from_with_error_receiver(client, receiver).await
133141
}
134142

135143
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
@@ -138,15 +146,19 @@ impl AsyncConnection for AsyncPgConnection {
138146
T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
139147
{
140148
let query = source.as_query();
141-
self.with_prepared_statement(query, |conn, stmt, binds| async move {
149+
let f = self.with_prepared_statement(query, |conn, stmt, binds| async move {
142150
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
143151

144152
Ok(res
145153
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
146154
.map_ok(PgRow::new)
147155
.boxed())
148-
})
149-
.boxed()
156+
});
157+
158+
match self.error_receiver.try_recv() {
159+
Ok(e) => Box::pin(async move { Err(e) }),
160+
Err(_) => f,
161+
}
150162
}
151163

152164
fn execute_returning_count<'conn, 'query, T>(
@@ -156,7 +168,7 @@ impl AsyncConnection for AsyncPgConnection {
156168
where
157169
T: QueryFragment<Self::Backend> + QueryId + 'query,
158170
{
159-
self.with_prepared_statement(source, |conn, stmt, binds| async move {
171+
let f = self.with_prepared_statement(source, |conn, stmt, binds| async move {
160172
let binds = binds
161173
.iter()
162174
.map(|b| b as &(dyn ToSql + Sync))
@@ -166,8 +178,12 @@ impl AsyncConnection for AsyncPgConnection {
166178
.await
167179
.map_err(ErrorHelper)?;
168180
Ok(res as usize)
169-
})
170-
.boxed()
181+
});
182+
183+
match self.error_receiver.try_recv() {
184+
Ok(e) => Box::pin(async move { Err(e) }),
185+
Err(_) => f,
186+
}
171187
}
172188

173189
fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
@@ -270,11 +286,24 @@ impl AsyncPgConnection {
270286

271287
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
272288
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
289+
// We create a dummy receiver here. If the user is calling this, they have
290+
// created their own client and connection and are handling any error in
291+
// the latter themselves.
292+
Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await
293+
}
294+
295+
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
296+
/// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
297+
async fn try_from_with_error_receiver(
298+
conn: tokio_postgres::Client,
299+
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
300+
) -> ConnectionResult<Self> {
273301
let mut conn = Self {
274302
conn: Arc::new(conn),
275303
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
276304
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
277305
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
306+
error_receiver,
278307
};
279308
conn.set_config_options()
280309
.await
@@ -340,7 +369,7 @@ impl AsyncPgConnection {
340369
let metadata_cache = self.metadata_cache.clone();
341370
let tm = self.transaction_state.clone();
342371

343-
async move {
372+
let f = async move {
344373
let sql = sql?;
345374
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
346375
collect_bind_result?;
@@ -411,8 +440,12 @@ impl AsyncPgConnection {
411440
let res = callback(raw_connection, stmt.clone(), binds).await;
412441
let mut tm = tm.lock().await;
413442
update_transaction_manager_status(res, &mut tm)
443+
};
444+
445+
match self.error_receiver.try_recv() {
446+
Ok(e) => Box::pin(async move { Err(e) }),
447+
Err(_) => f.boxed(),
414448
}
415-
.boxed()
416449
}
417450
}
418451

0 commit comments

Comments
 (0)