From c8a752f043d1a4b8bc0a43d7545d4edd2afc9dae Mon Sep 17 00:00:00 2001 From: bwqr Date: Fri, 23 May 2025 16:19:25 +0300 Subject: [PATCH 1/6] move SyncConnectionWrapper struct into `implementation` module Similar to `BlockOn` trait defined inside `src/async_connection_wrapper.rs`, a `SpawnBlocking` trait will be introduced. Before introducing the trait, move the defined structs into `implementation` module. --- src/sync_connection_wrapper/mod.rs | 659 +++++++++++++++-------------- 1 file changed, 332 insertions(+), 327 deletions(-) diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 9f28e5b..af0aadb 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -7,375 +7,380 @@ //! * using a sync Connection implementation in async context //! * using the same code base for async crates needing multiple backends -use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager}; -use diesel::backend::{Backend, DieselReserveSpecialization}; -use diesel::connection::{CacheSize, Instrumentation}; -use diesel::connection::{ - Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup, -}; -use diesel::query_builder::{ - AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId, -}; -use diesel::row::IntoOwnedRow; -use diesel::{ConnectionResult, QueryResult}; -use futures_util::future::BoxFuture; -use futures_util::stream::BoxStream; -use futures_util::{FutureExt, StreamExt, TryFutureExt}; -use std::marker::PhantomData; -use std::sync::{Arc, Mutex}; -use tokio::task::JoinError; - #[cfg(feature = "sqlite")] mod sqlite; -fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error { - diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::UnableToSendCommand, - Box::new(join_error.to_string()), - ) -} +pub use self::implementation::SyncConnectionWrapper; +pub use self::implementation::SyncTransactionManagerWrapper; -/// A wrapper of a [`diesel::connection::Connection`] usable in async context. -/// -/// It implements AsyncConnection if [`diesel::connection::Connection`] fullfils requirements: -/// * it's a [`diesel::connection::LoadConnection`] -/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`] -/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`] -/// -/// Internally this wrapper type will use `spawn_blocking` on tokio -/// to execute the request on the inner connection. This implies a -/// dependency on tokio and that the runtime is running. -/// -/// Note that only SQLite is supported at the moment. -/// -/// # Examples -/// -/// ```rust -/// # include!("../doctest_setup.rs"); -/// use diesel_async::RunQueryDsl; -/// use schema::users; -/// -/// async fn some_async_fn() { -/// # let database_url = database_url(); -/// use diesel_async::AsyncConnection; -/// use diesel::sqlite::SqliteConnection; -/// let mut conn = -/// SyncConnectionWrapper::::establish(&database_url).await.unwrap(); -/// # create_tables(&mut conn).await; -/// -/// let all_users = users::table.load::<(i32, String)>(&mut conn).await.unwrap(); -/// # assert_eq!(all_users.len(), 2); -/// } -/// -/// # #[cfg(feature = "sqlite")] -/// # #[tokio::main] -/// # async fn main() { -/// # some_async_fn().await; -/// # } -/// ``` -pub struct SyncConnectionWrapper { - inner: Arc>, -} +mod implementation { + use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager}; + use diesel::backend::{Backend, DieselReserveSpecialization}; + use diesel::connection::{CacheSize, Instrumentation}; + use diesel::connection::{ + Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup, + }; + use diesel::query_builder::{ + AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId, + }; + use diesel::row::IntoOwnedRow; + use diesel::{ConnectionResult, QueryResult}; + use futures_util::future::BoxFuture; + use futures_util::stream::BoxStream; + use futures_util::{FutureExt, StreamExt, TryFutureExt}; + use std::marker::PhantomData; + use std::sync::{Arc, Mutex}; + use tokio::task::JoinError; -impl SimpleAsyncConnection for SyncConnectionWrapper -where - C: diesel::connection::Connection + 'static, -{ - async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { - let query = query.to_string(); - self.spawn_blocking(move |inner| inner.batch_execute(query.as_str())) - .await + fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error { + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UnableToSendCommand, + Box::new(join_error.to_string()), + ) } -} -impl AsyncConnection for SyncConnectionWrapper -where - // Backend bounds - ::Backend: std::default::Default + DieselReserveSpecialization, - ::QueryBuilder: std::default::Default, - // Connection bounds - C: Connection + LoadConnection + WithMetadataLookup + 'static, - ::TransactionManager: Send, - // BindCollector bounds - MD: Send + 'static, - for<'a> ::BindCollector<'a>: - MoveableBindCollector + std::default::Default, - // Row bounds - O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, - for<'conn, 'query> ::Row<'conn, 'query>: - IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, -{ - type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; - type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; - type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; - type Row<'conn, 'query> = O; - type Backend = ::Backend; - type TransactionManager = SyncTransactionManagerWrapper<::TransactionManager>; - - async fn establish(database_url: &str) -> ConnectionResult { - let database_url = database_url.to_string(); - tokio::task::spawn_blocking(move || C::establish(&database_url)) - .await - .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) - .map(|c| SyncConnectionWrapper::new(c)) + /// A wrapper of a [`diesel::connection::Connection`] usable in async context. + /// + /// It implements AsyncConnection if [`diesel::connection::Connection`] fullfils requirements: + /// * it's a [`diesel::connection::LoadConnection`] + /// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`] + /// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`] + /// + /// Internally this wrapper type will use `spawn_blocking` on tokio + /// to execute the request on the inner connection. This implies a + /// dependency on tokio and that the runtime is running. + /// + /// Note that only SQLite is supported at the moment. + /// + /// # Examples + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// use diesel_async::RunQueryDsl; + /// use schema::users; + /// + /// async fn some_async_fn() { + /// # let database_url = database_url(); + /// use diesel_async::AsyncConnection; + /// use diesel::sqlite::SqliteConnection; + /// let mut conn = + /// SyncConnectionWrapper::::establish(&database_url).await.unwrap(); + /// # create_tables(&mut conn).await; + /// + /// let all_users = users::table.load::<(i32, String)>(&mut conn).await.unwrap(); + /// # assert_eq!(all_users.len(), 2); + /// } + /// + /// # #[cfg(feature = "sqlite")] + /// # #[tokio::main] + /// # async fn main() { + /// # some_async_fn().await; + /// # } + /// ``` + pub struct SyncConnectionWrapper { + inner: Arc>, } - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + impl SimpleAsyncConnection for SyncConnectionWrapper where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query, + C: diesel::connection::Connection + 'static, { - self.execute_with_prepared_query(source.as_query(), |conn, query| { - use diesel::row::IntoOwnedRow; - let mut cache = <<::Row<'_, '_> as IntoOwnedRow< - ::Backend, - >>::Cache as Default>::default(); - let cursor = conn.load(&query)?; - - let size_hint = cursor.size_hint(); - let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0)); - // we use an explicit loop here to easily propagate possible errors - // as early as possible - for row in cursor { - out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache))); - } - - Ok(out) - }) - .map_ok(|rows| futures_util::stream::iter(rows).boxed()) - .boxed() + async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { + let query = query.to_string(); + self.spawn_blocking(move |inner| inner.batch_execute(query.as_str())) + .await + } } - fn execute_returning_count<'query, T>(&mut self, source: T) -> Self::ExecuteFuture<'_, 'query> + impl AsyncConnection for SyncConnectionWrapper where - T: QueryFragment + QueryId, + // Backend bounds + ::Backend: std::default::Default + DieselReserveSpecialization, + ::QueryBuilder: std::default::Default, + // Connection bounds + C: Connection + LoadConnection + WithMetadataLookup + 'static, + ::TransactionManager: Send, + // BindCollector bounds + MD: Send + 'static, + for<'a> ::BindCollector<'a>: + MoveableBindCollector + std::default::Default, + // Row bounds + O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, + for<'conn, 'query> ::Row<'conn, 'query>: + IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, { - self.execute_with_prepared_query(source, |conn, query| conn.execute_returning_count(&query)) - } - - fn transaction_state( - &mut self, - ) -> &mut >::TransactionStateData { - self.exclusive_connection().transaction_state() - } + type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; + type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; + type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; + type Row<'conn, 'query> = O; + type Backend = ::Backend; + type TransactionManager = SyncTransactionManagerWrapper<::TransactionManager>; - fn instrumentation(&mut self) -> &mut dyn Instrumentation { - // there should be no other pending future when this is called - // that means there is only one instance of this arc and - // we can simply access the inner data - if let Some(inner) = Arc::get_mut(&mut self.inner) { - inner - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .instrumentation() - } else { - panic!("Cannot access shared instrumentation") + async fn establish(database_url: &str) -> ConnectionResult { + let database_url = database_url.to_string(); + tokio::task::spawn_blocking(move || C::establish(&database_url)) + .await + .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) + .map(|c| SyncConnectionWrapper::new(c)) } - } - fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { - // there should be no other pending future when this is called - // that means there is only one instance of this arc and - // we can simply access the inner data - if let Some(inner) = Arc::get_mut(&mut self.inner) { - inner - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .set_instrumentation(instrumentation) - } else { - panic!("Cannot access shared instrumentation") - } - } + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query, + { + self.execute_with_prepared_query(source.as_query(), |conn, query| { + use diesel::row::IntoOwnedRow; + let mut cache = <<::Row<'_, '_> as IntoOwnedRow< + ::Backend, + >>::Cache as Default>::default(); + let cursor = conn.load(&query)?; - fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { - // there should be no other pending future when this is called - // that means there is only one instance of this arc and - // we can simply access the inner data - if let Some(inner) = Arc::get_mut(&mut self.inner) { - inner - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .set_prepared_statement_cache_size(size) - } else { - panic!("Cannot access shared cache") + let size_hint = cursor.size_hint(); + let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0)); + // we use an explicit loop here to easily propagate possible errors + // as early as possible + for row in cursor { + out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache))); + } + + Ok(out) + }) + .map_ok(|rows| futures_util::stream::iter(rows).boxed()) + .boxed() } - } -} -/// A wrapper of a diesel transaction manager usable in async context. -pub struct SyncTransactionManagerWrapper(PhantomData); + fn execute_returning_count<'query, T>(&mut self, source: T) -> Self::ExecuteFuture<'_, 'query> + where + T: QueryFragment + QueryId, + { + self.execute_with_prepared_query(source, |conn, query| conn.execute_returning_count(&query)) + } -impl TransactionManager> for SyncTransactionManagerWrapper -where - SyncConnectionWrapper: AsyncConnection, - C: Connection + 'static, - T: diesel::connection::TransactionManager + Send, -{ - type TransactionStateData = T::TransactionStateData; + fn transaction_state( + &mut self, + ) -> &mut >::TransactionStateData { + self.exclusive_connection().transaction_state() + } - async fn begin_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { - conn.spawn_blocking(move |inner| T::begin_transaction(inner)) - .await - } + fn instrumentation(&mut self) -> &mut dyn Instrumentation { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .instrumentation() + } else { + panic!("Cannot access shared instrumentation") + } + } - async fn commit_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { - conn.spawn_blocking(move |inner| T::commit_transaction(inner)) - .await - } + fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .set_instrumentation(instrumentation) + } else { + panic!("Cannot access shared instrumentation") + } + } - async fn rollback_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { - conn.spawn_blocking(move |inner| T::rollback_transaction(inner)) - .await + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .set_prepared_statement_cache_size(size) + } else { + panic!("Cannot access shared cache") + } + } } - fn transaction_manager_status_mut( - conn: &mut SyncConnectionWrapper, - ) -> &mut TransactionManagerStatus { - T::transaction_manager_status_mut(conn.exclusive_connection()) - } -} + /// A wrapper of a diesel transaction manager usable in async context. + pub struct SyncTransactionManagerWrapper(PhantomData); -impl SyncConnectionWrapper { - /// Builds a wrapper with this underlying sync connection - pub fn new(connection: C) -> Self + impl TransactionManager> for SyncTransactionManagerWrapper where - C: Connection, + SyncConnectionWrapper: AsyncConnection, + C: Connection + 'static, + T: diesel::connection::TransactionManager + Send, { - SyncConnectionWrapper { - inner: Arc::new(Mutex::new(connection)), + type TransactionStateData = T::TransactionStateData; + + async fn begin_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + conn.spawn_blocking(move |inner| T::begin_transaction(inner)) + .await } - } - /// Run a operation directly with the inner connection - /// - /// This function is usful to register custom functions - /// and collection for Sqlite for example - /// - /// # Example - /// - /// ```rust - /// # include!("../doctest_setup.rs"); - /// # #[tokio::main] - /// # async fn main() { - /// # run_test().await.unwrap(); - /// # } - /// # - /// # async fn run_test() -> QueryResult<()> { - /// # let mut conn = establish_connection().await; - /// conn.spawn_blocking(|conn| { - /// // sqlite.rs sqlite NOCASE only works for ASCII characters, - /// // this collation allows handling UTF-8 (barring locale differences) - /// conn.register_collation("RUSTNOCASE", |rhs, lhs| { - /// rhs.to_lowercase().cmp(&lhs.to_lowercase()) - /// }) - /// }).await - /// - /// # } - /// ``` - pub fn spawn_blocking<'a, R>( - &mut self, - task: impl FnOnce(&mut C) -> QueryResult + Send + 'static, - ) -> BoxFuture<'a, QueryResult> - where - C: Connection + 'static, - R: Send + 'static, - { - let inner = self.inner.clone(); - tokio::task::spawn_blocking(move || { - let mut inner = inner.lock().unwrap_or_else(|poison| { - // try to be resilient by providing the guard - inner.clear_poison(); - poison.into_inner() - }); - task(&mut inner) - }) - .unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err))) - .boxed() + async fn commit_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + conn.spawn_blocking(move |inner| T::commit_transaction(inner)) + .await + } + + async fn rollback_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + conn.spawn_blocking(move |inner| T::rollback_transaction(inner)) + .await + } + + fn transaction_manager_status_mut( + conn: &mut SyncConnectionWrapper, + ) -> &mut TransactionManagerStatus { + T::transaction_manager_status_mut(conn.exclusive_connection()) + } } - fn execute_with_prepared_query<'a, MD, Q, R>( - &mut self, - query: Q, - callback: impl FnOnce(&mut C, &CollectedQuery) -> QueryResult + Send + 'static, - ) -> BoxFuture<'a, QueryResult> - where - // Backend bounds - ::Backend: std::default::Default + DieselReserveSpecialization, - ::QueryBuilder: std::default::Default, - // Connection bounds - C: Connection + LoadConnection + WithMetadataLookup + 'static, - ::TransactionManager: Send, - // BindCollector bounds - MD: Send + 'static, - for<'b> ::BindCollector<'b>: - MoveableBindCollector + std::default::Default, - // Arguments/Return bounds - Q: QueryFragment + QueryId, - R: Send + 'static, - { - let backend = C::Backend::default(); + impl SyncConnectionWrapper { + /// Builds a wrapper with this underlying sync connection + pub fn new(connection: C) -> Self + where + C: Connection, + { + SyncConnectionWrapper { + inner: Arc::new(Mutex::new(connection)), + } + } + + /// Run a operation directly with the inner connection + /// + /// This function is usful to register custom functions + /// and collection for Sqlite for example + /// + /// # Example + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # #[tokio::main] + /// # async fn main() { + /// # run_test().await.unwrap(); + /// # } + /// # + /// # async fn run_test() -> QueryResult<()> { + /// # let mut conn = establish_connection().await; + /// conn.spawn_blocking(|conn| { + /// // sqlite.rs sqlite NOCASE only works for ASCII characters, + /// // this collation allows handling UTF-8 (barring locale differences) + /// conn.register_collation("RUSTNOCASE", |rhs, lhs| { + /// rhs.to_lowercase().cmp(&lhs.to_lowercase()) + /// }) + /// }).await + /// + /// # } + /// ``` + pub fn spawn_blocking<'a, R>( + &mut self, + task: impl FnOnce(&mut C) -> QueryResult + Send + 'static, + ) -> BoxFuture<'a, QueryResult> + where + C: Connection + 'static, + R: Send + 'static, + { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut inner = inner.lock().unwrap_or_else(|poison| { + // try to be resilient by providing the guard + inner.clear_poison(); + poison.into_inner() + }); + task(&mut inner) + }) + .unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err))) + .boxed() + } - let (collect_bind_result, collector_data) = { - let exclusive = self.inner.clone(); - let mut inner = exclusive.lock().unwrap_or_else(|poison| { - // try to be resilient by providing the guard - exclusive.clear_poison(); - poison.into_inner() - }); - let mut bind_collector = - <::BindCollector<'_> as Default>::default(); - let metadata_lookup = inner.metadata_lookup(); - let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend); - let collector_data = bind_collector.moveable(); + fn execute_with_prepared_query<'a, MD, Q, R>( + &mut self, + query: Q, + callback: impl FnOnce(&mut C, &CollectedQuery) -> QueryResult + Send + 'static, + ) -> BoxFuture<'a, QueryResult> + where + // Backend bounds + ::Backend: std::default::Default + DieselReserveSpecialization, + ::QueryBuilder: std::default::Default, + // Connection bounds + C: Connection + LoadConnection + WithMetadataLookup + 'static, + ::TransactionManager: Send, + // BindCollector bounds + MD: Send + 'static, + for<'b> ::BindCollector<'b>: + MoveableBindCollector + std::default::Default, + // Arguments/Return bounds + Q: QueryFragment + QueryId, + R: Send + 'static, + { + let backend = C::Backend::default(); - (result, collector_data) - }; + let (collect_bind_result, collector_data) = { + let exclusive = self.inner.clone(); + let mut inner = exclusive.lock().unwrap_or_else(|poison| { + // try to be resilient by providing the guard + exclusive.clear_poison(); + poison.into_inner() + }); + let mut bind_collector = + <::BindCollector<'_> as Default>::default(); + let metadata_lookup = inner.metadata_lookup(); + let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend); + let collector_data = bind_collector.moveable(); - let mut query_builder = <::QueryBuilder as Default>::default(); - let sql = query - .to_sql(&mut query_builder, &backend) - .map(|_| query_builder.finish()); - let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend); + (result, collector_data) + }; + + let mut query_builder = <::QueryBuilder as Default>::default(); + let sql = query + .to_sql(&mut query_builder, &backend) + .map(|_| query_builder.finish()); + let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend); + + self.spawn_blocking(|inner| { + collect_bind_result?; + let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data); + callback(inner, &query) + }) + } - self.spawn_blocking(|inner| { - collect_bind_result?; - let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data); - callback(inner, &query) - }) + /// Gets an exclusive access to the underlying diesel Connection + /// + /// It panics in case of shared access. + /// This is typically used only used during transaction. + pub(self) fn exclusive_connection(&mut self) -> &mut C + where + C: Connection, + { + // there should be no other pending future when this is called + // that means there is only one instance of this Arc and + // we can simply access the inner data + if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) { + conn_mutex + .get_mut() + .expect("Mutex is poisoned, a thread must have panicked holding it.") + } else { + panic!("Cannot access shared transaction state") + } + } } - /// Gets an exclusive access to the underlying diesel Connection - /// - /// It panics in case of shared access. - /// This is typically used only used during transaction. - pub(self) fn exclusive_connection(&mut self) -> &mut C + #[cfg(any( + feature = "deadpool", + feature = "bb8", + feature = "mobc", + feature = "r2d2" + ))] + impl crate::pooled_connection::PoolableConnection for SyncConnectionWrapper where - C: Connection, + Self: AsyncConnection, { - // there should be no other pending future when this is called - // that means there is only one instance of this Arc and - // we can simply access the inner data - if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) { - conn_mutex - .get_mut() - .expect("Mutex is poisoned, a thread must have panicked holding it.") - } else { - panic!("Cannot access shared transaction state") + fn is_broken(&mut self) -> bool { + Self::TransactionManager::is_broken_transaction_manager(self) } } } - -#[cfg(any( - feature = "deadpool", - feature = "bb8", - feature = "mobc", - feature = "r2d2" -))] -impl crate::pooled_connection::PoolableConnection for SyncConnectionWrapper -where - Self: AsyncConnection, -{ - fn is_broken(&mut self) -> bool { - Self::TransactionManager::is_broken_transaction_manager(self) - } -} From 48a41a1a717a5727418288766f6d65021523ae6f Mon Sep 17 00:00:00 2001 From: bwqr Date: Fri, 23 May 2025 19:14:58 +0300 Subject: [PATCH 2/6] define `SpawnBlocking` trait to customize runtime used for spawning blocking tasks. Previously, `SyncConnectionWrapper` was using tokio as spawning and running blocking tasks. This had prevented using Sqlite backend on wasm32-unknown-unknown target since futures generally run on top of JavaScript promises with the help of wasm_bindgen_futures crate. It is now possible for users to provide their own runtime to spawn blocking tasks inside the `SyncConnectionWrapper`. --- src/sync_connection_wrapper/mod.rs | 127 +++++++++++++++++++++++++---- 1 file changed, 109 insertions(+), 18 deletions(-) diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index af0aadb..1462196 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -6,11 +6,38 @@ //! //! * using a sync Connection implementation in async context //! * using the same code base for async crates needing multiple backends +use std::error::Error; +use futures_util::future::BoxFuture; #[cfg(feature = "sqlite")] mod sqlite; +/// This is a helper trait that allows to customize the +/// spawning blocking tasks as part of the +/// [`SyncConnectionWrapper`] type. By default a +/// tokio runtime and its spawn_blocking function is used. +pub trait SpawnBlocking { + /// This function should allow to execute a + /// given blocking task without blocking the caller + /// to get the result + fn spawn_blocking<'a, R>( + &mut self, + task: impl FnOnce() -> R + Send + 'static, + ) -> BoxFuture<'a, Result>> + where + R: Send + 'static; + + /// This function should be used to construct + /// a new runtime instance + fn get_runtime() -> Self; +} + +#[cfg(feature = "tokio")] +pub type SyncConnectionWrapper = self::implementation::SyncConnectionWrapper; + +#[cfg(not(feature = "tokio"))] pub use self::implementation::SyncConnectionWrapper; + pub use self::implementation::SyncTransactionManagerWrapper; mod implementation { @@ -25,17 +52,17 @@ mod implementation { }; use diesel::row::IntoOwnedRow; use diesel::{ConnectionResult, QueryResult}; - use futures_util::future::BoxFuture; use futures_util::stream::BoxStream; use futures_util::{FutureExt, StreamExt, TryFutureExt}; use std::marker::PhantomData; use std::sync::{Arc, Mutex}; - use tokio::task::JoinError; - fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error { + use super::*; + + fn from_spawn_blocking_error(error: Box) -> diesel::result::Error { diesel::result::Error::DatabaseError( diesel::result::DatabaseErrorKind::UnableToSendCommand, - Box::new(join_error.to_string()), + Box::new(error.to_string()), ) } @@ -77,13 +104,15 @@ mod implementation { /// # some_async_fn().await; /// # } /// ``` - pub struct SyncConnectionWrapper { + pub struct SyncConnectionWrapper { inner: Arc>, + runtime: S, } - impl SimpleAsyncConnection for SyncConnectionWrapper + impl SimpleAsyncConnection for SyncConnectionWrapper where C: diesel::connection::Connection + 'static, + S: SpawnBlocking + Send, { async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { let query = query.to_string(); @@ -92,7 +121,7 @@ mod implementation { } } - impl AsyncConnection for SyncConnectionWrapper + impl AsyncConnection for SyncConnectionWrapper where // Backend bounds ::Backend: std::default::Default + DieselReserveSpecialization, @@ -108,6 +137,8 @@ mod implementation { O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, for<'conn, 'query> ::Row<'conn, 'query>: IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, + // SpawnBlocking bounds + S: SpawnBlocking + Send, { type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; @@ -118,10 +149,12 @@ mod implementation { async fn establish(database_url: &str) -> ConnectionResult { let database_url = database_url.to_string(); - tokio::task::spawn_blocking(move || C::establish(&database_url)) + let mut runtime = S::get_runtime(); + + runtime.spawn_blocking(move || C::establish(&database_url)) .await .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) - .map(|c| SyncConnectionWrapper::new(c)) + .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime)) } fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> @@ -209,44 +242,60 @@ mod implementation { /// A wrapper of a diesel transaction manager usable in async context. pub struct SyncTransactionManagerWrapper(PhantomData); - impl TransactionManager> for SyncTransactionManagerWrapper + impl TransactionManager> for SyncTransactionManagerWrapper where - SyncConnectionWrapper: AsyncConnection, + SyncConnectionWrapper: AsyncConnection, C: Connection + 'static, + S: SpawnBlocking, T: diesel::connection::TransactionManager + Send, { type TransactionStateData = T::TransactionStateData; - async fn begin_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + async fn begin_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { conn.spawn_blocking(move |inner| T::begin_transaction(inner)) .await } - async fn commit_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + async fn commit_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { conn.spawn_blocking(move |inner| T::commit_transaction(inner)) .await } - async fn rollback_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + async fn rollback_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { conn.spawn_blocking(move |inner| T::rollback_transaction(inner)) .await } fn transaction_manager_status_mut( - conn: &mut SyncConnectionWrapper, + conn: &mut SyncConnectionWrapper, ) -> &mut TransactionManagerStatus { T::transaction_manager_status_mut(conn.exclusive_connection()) } } - impl SyncConnectionWrapper { + impl SyncConnectionWrapper { /// Builds a wrapper with this underlying sync connection pub fn new(connection: C) -> Self where C: Connection, + S: SpawnBlocking, + { + SyncConnectionWrapper { + inner: Arc::new(Mutex::new(connection)), + runtime: S::get_runtime(), + } + } + + /// Builds a wrapper with this underlying sync connection + /// and runtime for spawning blocking tasks + pub fn with_runtime(connection: C, runtime: S) -> Self + where + C: Connection, + S: SpawnBlocking, { SyncConnectionWrapper { inner: Arc::new(Mutex::new(connection)), + runtime, } } @@ -283,9 +332,10 @@ mod implementation { where C: Connection + 'static, R: Send + 'static, + S: SpawnBlocking, { let inner = self.inner.clone(); - tokio::task::spawn_blocking(move || { + self.runtime.spawn_blocking(move || { let mut inner = inner.lock().unwrap_or_else(|poison| { // try to be resilient by providing the guard inner.clear_poison(); @@ -293,7 +343,7 @@ mod implementation { }); task(&mut inner) }) - .unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err))) + .unwrap_or_else(|err| QueryResult::Err(from_spawn_blocking_error(err))) .boxed() } @@ -316,6 +366,8 @@ mod implementation { // Arguments/Return bounds Q: QueryFragment + QueryId, R: Send + 'static, + // SpawnBlocking bounds + S: SpawnBlocking, { let backend = C::Backend::default(); @@ -383,4 +435,43 @@ mod implementation { Self::TransactionManager::is_broken_transaction_manager(self) } } + + #[cfg(feature = "tokio")] + pub enum Tokio { + Handle(tokio::runtime::Handle), + Runtime(tokio::runtime::Runtime) + } + + #[cfg(feature = "tokio")] + impl SpawnBlocking for Tokio { + fn spawn_blocking<'a, R>( + &mut self, + task: impl FnOnce() -> R + Send + 'static, + ) -> BoxFuture<'a, Result>> + where + R: Send + 'static, + { + let fut = match self { + Tokio::Handle(handle) => handle.spawn_blocking(task), + Tokio::Runtime(runtime) => runtime.spawn_blocking(task) + }; + + fut + .map_err(|err| Box::from(err)) + .boxed() + } + + fn get_runtime() -> Self { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + Tokio::Handle(handle) + } else { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); + + Tokio::Runtime(runtime) + } + } + } } From 4206ed92b666ef1a4a590125e18efbca9de2f2aa Mon Sep 17 00:00:00 2001 From: bwqr Date: Fri, 23 May 2025 19:24:53 +0300 Subject: [PATCH 3/6] move documentation of `SyncConnectionWrapper` to where it is made public --- src/sync_connection_wrapper/mod.rs | 85 +++++++++++++++++------------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 1462196..545465e 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -32,9 +32,56 @@ pub trait SpawnBlocking { fn get_runtime() -> Self; } +/// A wrapper of a [`diesel::connection::Connection`] usable in async context. +/// +/// It implements AsyncConnection if [`diesel::connection::Connection`] fullfils requirements: +/// * it's a [`diesel::connection::LoadConnection`] +/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`] +/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`] +/// +/// Internally this wrapper type will use `spawn_blocking` on tokio +/// to execute the request on the inner connection. This implies a +/// dependency on tokio and that the runtime is running. +/// +/// Note that only SQLite is supported at the moment. +/// +/// # Examples +/// +/// ```rust +/// # include!("../doctest_setup.rs"); +/// use diesel_async::RunQueryDsl; +/// use schema::users; +/// +/// async fn some_async_fn() { +/// # let database_url = database_url(); +/// use diesel_async::AsyncConnection; +/// use diesel::sqlite::SqliteConnection; +/// let mut conn = +/// SyncConnectionWrapper::::establish(&database_url).await.unwrap(); +/// # create_tables(&mut conn).await; +/// +/// let all_users = users::table.load::<(i32, String)>(&mut conn).await.unwrap(); +/// # assert_eq!(all_users.len(), 2); +/// } +/// +/// # #[cfg(feature = "sqlite")] +/// # #[tokio::main] +/// # async fn main() { +/// # some_async_fn().await; +/// # } +/// ``` #[cfg(feature = "tokio")] pub type SyncConnectionWrapper = self::implementation::SyncConnectionWrapper; +/// A wrapper of a [`diesel::connection::Connection`] usable in async context. +/// +/// It implements AsyncConnection if [`diesel::connection::Connection`] fullfils requirements: +/// * it's a [`diesel::connection::LoadConnection`] +/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`] +/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`] +/// +/// Internally this wrapper type will use `spawn_blocking` on given type implementing [`SpawnBlocking`] trait +/// to execute the request on the inner connection. #[cfg(not(feature = "tokio"))] pub use self::implementation::SyncConnectionWrapper; @@ -66,44 +113,6 @@ mod implementation { ) } - /// A wrapper of a [`diesel::connection::Connection`] usable in async context. - /// - /// It implements AsyncConnection if [`diesel::connection::Connection`] fullfils requirements: - /// * it's a [`diesel::connection::LoadConnection`] - /// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`] - /// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`] - /// - /// Internally this wrapper type will use `spawn_blocking` on tokio - /// to execute the request on the inner connection. This implies a - /// dependency on tokio and that the runtime is running. - /// - /// Note that only SQLite is supported at the moment. - /// - /// # Examples - /// - /// ```rust - /// # include!("../doctest_setup.rs"); - /// use diesel_async::RunQueryDsl; - /// use schema::users; - /// - /// async fn some_async_fn() { - /// # let database_url = database_url(); - /// use diesel_async::AsyncConnection; - /// use diesel::sqlite::SqliteConnection; - /// let mut conn = - /// SyncConnectionWrapper::::establish(&database_url).await.unwrap(); - /// # create_tables(&mut conn).await; - /// - /// let all_users = users::table.load::<(i32, String)>(&mut conn).await.unwrap(); - /// # assert_eq!(all_users.len(), 2); - /// } - /// - /// # #[cfg(feature = "sqlite")] - /// # #[tokio::main] - /// # async fn main() { - /// # some_async_fn().await; - /// # } - /// ``` pub struct SyncConnectionWrapper { inner: Arc>, runtime: S, From 1260dc2110590ada8b8d519112101a70136cb43e Mon Sep 17 00:00:00 2001 From: bwqr Date: Fri, 23 May 2025 19:34:25 +0300 Subject: [PATCH 4/6] add missing generic argument for `SyncConnectionWrapper struct while implementing PoolableConnection --- src/sync_connection_wrapper/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 545465e..beb44d5 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -436,7 +436,7 @@ mod implementation { feature = "mobc", feature = "r2d2" ))] - impl crate::pooled_connection::PoolableConnection for SyncConnectionWrapper + impl crate::pooled_connection::PoolableConnection for SyncConnectionWrapper where Self: AsyncConnection, { From 5074ca57cf17aa2b7c56c68ed58c865357ee23c2 Mon Sep 17 00:00:00 2001 From: bwqr Date: Fri, 23 May 2025 20:19:04 +0300 Subject: [PATCH 5/6] do not enable io on default tokio runtime for `SyncConnectionWrapper` --- src/sync_connection_wrapper/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index beb44d5..7361286 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -475,7 +475,6 @@ mod implementation { Tokio::Handle(handle) } else { let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() .build() .unwrap(); From 6074be6f33b75fa28893ebb9c522990e984dd717 Mon Sep 17 00:00:00 2001 From: bwqr Date: Fri, 23 May 2025 20:19:42 +0300 Subject: [PATCH 6/6] run rustfmt on `src/sync_connection_wrapper/mod.rs` --- src/sync_connection_wrapper/mod.rs | 59 ++++++++++++++++++------------ 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 7361286..75a2122 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -6,8 +6,8 @@ //! //! * using a sync Connection implementation in async context //! * using the same code base for async crates needing multiple backends -use std::error::Error; use futures_util::future::BoxFuture; +use std::error::Error; #[cfg(feature = "sqlite")] mod sqlite; @@ -71,7 +71,8 @@ pub trait SpawnBlocking { /// # } /// ``` #[cfg(feature = "tokio")] -pub type SyncConnectionWrapper = self::implementation::SyncConnectionWrapper; +pub type SyncConnectionWrapper = + self::implementation::SyncConnectionWrapper; /// A wrapper of a [`diesel::connection::Connection`] usable in async context. /// @@ -106,7 +107,9 @@ mod implementation { use super::*; - fn from_spawn_blocking_error(error: Box) -> diesel::result::Error { + fn from_spawn_blocking_error( + error: Box, + ) -> diesel::result::Error { diesel::result::Error::DatabaseError( diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(error.to_string()), @@ -149,18 +152,21 @@ mod implementation { // SpawnBlocking bounds S: SpawnBlocking + Send, { - type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; + type LoadFuture<'conn, 'query> = + BoxFuture<'query, QueryResult>>; type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; type Row<'conn, 'query> = O; type Backend = ::Backend; - type TransactionManager = SyncTransactionManagerWrapper<::TransactionManager>; + type TransactionManager = + SyncTransactionManagerWrapper<::TransactionManager>; async fn establish(database_url: &str) -> ConnectionResult { let database_url = database_url.to_string(); let mut runtime = S::get_runtime(); - runtime.spawn_blocking(move || C::establish(&database_url)) + runtime + .spawn_blocking(move || C::establish(&database_url)) .await .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime)) @@ -192,16 +198,22 @@ mod implementation { .boxed() } - fn execute_returning_count<'query, T>(&mut self, source: T) -> Self::ExecuteFuture<'_, 'query> + fn execute_returning_count<'query, T>( + &mut self, + source: T, + ) -> Self::ExecuteFuture<'_, 'query> where T: QueryFragment + QueryId, { - self.execute_with_prepared_query(source, |conn, query| conn.execute_returning_count(&query)) + self.execute_with_prepared_query(source, |conn, query| { + conn.execute_returning_count(&query) + }) } fn transaction_state( &mut self, - ) -> &mut >::TransactionStateData { + ) -> &mut >::TransactionStateData + { self.exclusive_connection().transaction_state() } @@ -344,16 +356,17 @@ mod implementation { S: SpawnBlocking, { let inner = self.inner.clone(); - self.runtime.spawn_blocking(move || { - let mut inner = inner.lock().unwrap_or_else(|poison| { - // try to be resilient by providing the guard - inner.clear_poison(); - poison.into_inner() - }); - task(&mut inner) - }) - .unwrap_or_else(|err| QueryResult::Err(from_spawn_blocking_error(err))) - .boxed() + self.runtime + .spawn_blocking(move || { + let mut inner = inner.lock().unwrap_or_else(|poison| { + // try to be resilient by providing the guard + inner.clear_poison(); + poison.into_inner() + }); + task(&mut inner) + }) + .unwrap_or_else(|err| QueryResult::Err(from_spawn_blocking_error(err))) + .boxed() } fn execute_with_prepared_query<'a, MD, Q, R>( @@ -448,7 +461,7 @@ mod implementation { #[cfg(feature = "tokio")] pub enum Tokio { Handle(tokio::runtime::Handle), - Runtime(tokio::runtime::Runtime) + Runtime(tokio::runtime::Runtime), } #[cfg(feature = "tokio")] @@ -462,12 +475,10 @@ mod implementation { { let fut = match self { Tokio::Handle(handle) => handle.spawn_blocking(task), - Tokio::Runtime(runtime) => runtime.spawn_blocking(task) + Tokio::Runtime(runtime) => runtime.spawn_blocking(task), }; - fut - .map_err(|err| Box::from(err)) - .boxed() + fut.map_err(|err| Box::from(err)).boxed() } fn get_runtime() -> Self {