Skip to content

Expose underlying connection errors to user of AsyncPgConnection #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ diesel_migrations = "2.1.0"
[features]
default = []
mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel", "tokio"]
postgres = ["diesel/postgres_backend", "tokio-postgres", "tokio", "tokio/rt"]
postgres = ["diesel/postgres_backend", "tokio-postgres", "tokio", "tokio/macros", "tokio/rt"]
async-connection-wrapper = []
r2d2 = ["diesel/r2d2"]

Expand Down
44 changes: 39 additions & 5 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use futures_util::future::BoxFuture;
use futures_util::stream::{BoxStream, TryStreamExt};
use futures_util::{Future, FutureExt, StreamExt};
use std::borrow::Cow;
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_postgres::types::ToSql;
Expand Down Expand Up @@ -102,6 +103,7 @@ pub struct AsyncPgConnection {
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
error_receiver: Arc<Mutex<tokio::sync::oneshot::Receiver<diesel::result::Error>>>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if using a mutex here is a good idea, because that one would be locked while executing the query (compared to while starting to prepare as for the other mutexes). That would basically restrict the ability to do pipe lining at all. I need to think about a better solution, probably using a mpmc channel or something like that.

Copy link
Author

@banool banool Oct 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the approach is this:

  1. Make a new async closure.
  2. In that closure, take the lock.
  3. Manually poll the future. If the future is pending we return future not ready for the closure.

I'm not 100% sure if this works but the overall idea is to make sure we only take the lock to check if the channel has an error in it and if it doesn't, let go of the lock and move on instead of blocking (so tokio::select can instead drive the other branch).

I notice that some other tokio channels maybe have an alternative form of what we need, this poll_recv function: https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Receiver.html#method.poll_recv.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've finally found the time to have another look at this. #132 is what I've settled on.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neato, looks good to me! I'll close this one out.

}

#[async_trait::async_trait]
Expand All @@ -124,12 +126,19 @@ impl AsyncConnection for AsyncPgConnection {
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
.await
.map_err(ErrorHelper)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {e}");
// If there is a connection error, we capture it in this channel and make when
// the user next calls one of the functions on the connection in this trait, we
// return the error instead of the inner result.
let (sender, receiver) = tokio::sync::oneshot::channel();
tokio::spawn(async {
if let Err(connection_error) = connection.await {
let connection_error = diesel::result::Error::from(ErrorHelper(connection_error));
if let Err(send_error) = sender.send(connection_error) {
eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error);
}
}
});
Self::try_from(client).await
Self::try_from_with_error_receiver(client, receiver).await
}

fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
Expand Down Expand Up @@ -270,11 +279,24 @@ impl AsyncPgConnection {

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
// We create a dummy receiver here. If the user is calling this, they have
// created their own client and connection and are handling any error in
// the latter themselves.
Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await
}

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
/// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
async fn try_from_with_error_receiver(
conn: tokio_postgres::Client,
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
) -> ConnectionResult<Self> {
let mut conn = Self {
conn: Arc::new(conn),
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
error_receiver: Arc::new(Mutex::new(error_receiver)),
};
conn.set_config_options()
.await
Expand Down Expand Up @@ -340,7 +362,7 @@ impl AsyncPgConnection {
let metadata_cache = self.metadata_cache.clone();
let tm = self.transaction_state.clone();

async move {
let f = async move {
let sql = sql?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
collect_bind_result?;
Expand Down Expand Up @@ -411,6 +433,18 @@ impl AsyncPgConnection {
let res = callback(raw_connection, stmt.clone(), binds).await;
let mut tm = tm.lock().await;
update_transaction_manager_status(res, &mut tm)
};

let er = self.error_receiver.clone();
async move {
let mut error_receiver = er.lock().await;
// While the future (f) is running, at any await point tokio::select will
// check if there is an error in the channel from the connection. If there
// is, we will return that instead and f will get aborted.
tokio::select! {
error = error_receiver.deref_mut() => Err(error.unwrap()),
res = f => res,
}
}
.boxed()
}
Expand Down