Skip to content

Refactor paginating queries to a trait #751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/keyword.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use pg::rows::Row;

use {Model, Crate};
use db::RequestTransaction;
use pagination::Paginate;
use schema::*;
use util::{RequestUtils, CargoResult};

Expand Down Expand Up @@ -138,28 +139,22 @@ impl Model for Keyword {

/// Handles the `GET /keywords` route.
pub fn index(req: &mut Request) -> CargoResult<Response> {
use diesel::expression::dsl::sql;
use diesel::types::BigInt;
use schema::keywords;

let conn = req.db_conn()?;
let (offset, limit) = req.pagination(10, 100)?;
let query = req.query();
let sort = query.get("sort").map(|s| &s[..]).unwrap_or("alpha");

let mut query = keywords::table
.select((keywords::all_columns, sql::<BigInt>("COUNT(*) OVER ()")))
.limit(limit)
.offset(offset)
.into_boxed();
let mut query = keywords::table.into_boxed();

if sort == "crates" {
query = query.order(keywords::crates_cnt.desc());
} else {
query = query.order(keywords::keyword.asc());
}

let data = query.load::<(Keyword, i64)>(&*conn)?;
let data = query.paginate(limit, offset).load::<(Keyword, i64)>(&*conn)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little unfortunate that selecting the total number of results is now gone from this location, but we have to still remember to include the i64 type in the load type here. Not sure what to do differently though, maybe if we made a paginated_load that only took the types this code knows about and inserts the i64?

Not a huge deal, this PR is still a huge improvement for factoring out duplication and raw SQL everywhere, just a thought for a possible next step and/or if pagination becomes a more general diesel thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt I'd add it to Diesel, but this could definitely be extracted to a crate. I have considered adding a method that lets you specify the type you're going to deserialize to outside of load. (The primary motivation is that Version::belonging_to(&crate).load::<Version> bugs me) If that were present, we could definitely have this be .paginate::<Keyword> and have the paginate method call that with (T, i64). I haven't explored how much work it'd be to add that method though

let total = data.get(0).map(|&(_, t)| t).unwrap_or(0);
let kws = data.into_iter()
.map(|(k, _)| k.encodable())
Expand Down
29 changes: 11 additions & 18 deletions src/krate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use download::{VersionDownload, EncodableVersionDownload};
use git;
use keyword::{EncodableKeyword, CrateKeyword};
use owner::{EncodableOwner, Owner, Rights, OwnerKind, Team, rights, CrateOwner};
use pagination::Paginate;
use schema::*;
use upload;
use user::RequestUser;
Expand Down Expand Up @@ -781,22 +782,16 @@ impl Model for Crate {
/// Handles the `GET /crates` route.
#[allow(trivial_casts)]
pub fn index(req: &mut Request) -> CargoResult<Response> {
use diesel::expression::dsl::sql;
use diesel::types::{BigInt, Bool};
use diesel::expression::AsExpression;
use diesel::types::Bool;

let conn = req.db_conn()?;
let (offset, limit) = req.pagination(10, 100)?;
let params = req.query();
let sort = params.get("sort").map(|s| &**s).unwrap_or("alpha");

let mut query = crates::table
.select((
ALL_COLUMNS,
sql::<BigInt>("COUNT(*) OVER ()"),
sql::<Bool>("false"),
))
.limit(limit)
.offset(offset)
.select((ALL_COLUMNS, AsExpression::<Bool>::as_expression(false)))
.into_boxed();

if sort == "downloads" {
Expand All @@ -813,11 +808,7 @@ pub fn index(req: &mut Request) -> CargoResult<Response> {
),
));

query = query.select((
ALL_COLUMNS,
sql::<BigInt>("COUNT(*) OVER()"),
crates::name.eq(q_string),
));
query = query.select((ALL_COLUMNS, crates::name.eq(q_string)));
let perfect_match = crates::name.eq(q_string).desc();
if sort == "downloads" {
query = query.order((perfect_match, crates::downloads.desc()));
Expand Down Expand Up @@ -887,12 +878,14 @@ pub fn index(req: &mut Request) -> CargoResult<Response> {
));
}

let data = query.load::<(Crate, i64, bool)>(&*conn)?;
let total = data.get(0).map(|&(_, t, _)| t).unwrap_or(0);
let data = query.paginate(limit, offset).load::<((Crate, bool), i64)>(
&*conn,
)?;
let total = data.first().map(|&(_, t)| t).unwrap_or(0);
let crates = data.iter()
.map(|&(ref c, _, _)| c.clone())
.map(|&((ref c, _), _)| c.clone())
.collect::<Vec<_>>();
let perfect_matches = data.into_iter().map(|(_, _, b)| b).collect::<Vec<_>>();
let perfect_matches = data.into_iter().map(|((_, b), _)| b).collect::<Vec<_>>();

let versions = Version::belonging_to(&crates)
.load::<Version>(&*conn)?
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub mod user;
pub mod util;
pub mod version;

mod pagination;

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum Env {
Development,
Expand Down
43 changes: 43 additions & 0 deletions src/pagination.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use diesel::prelude::*;
use diesel::query_builder::*;
use diesel::types::BigInt;
use diesel::pg::Pg;

pub struct Paginated<T> {
query: T,
limit: i64,
offset: i64,
}

pub trait Paginate: AsQuery + Sized {
fn paginate(self, limit: i64, offset: i64) -> Paginated<Self::Query> {
Paginated {
query: self.as_query(),
limit,
offset,
}
}
}

impl<T: AsQuery> Paginate for T {}

impl<T: Query> Query for Paginated<T> {
type SqlType = (T::SqlType, BigInt);
}

impl<T> QueryFragment<Pg> for Paginated<T>
where
T: QueryFragment<Pg>,
{
fn walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()> {
out.push_sql("SELECT *, COUNT(*) OVER () FROM (");
self.query.walk_ast(out.reborrow())?;
out.push_sql(") t LIMIT ");
out.push_bind_param::<BigInt, _>(&self.limit)?;
out.push_sql(" OFFSET ");
out.push_bind_param::<BigInt, _>(&self.offset)?;
Ok(())
}
}

impl_query_id!(Paginated<T>);
19 changes: 7 additions & 12 deletions src/user/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::borrow::Cow;
use app::RequestApp;
use db::RequestTransaction;
use krate::Follow;
use pagination::Paginate;
use schema::*;
use util::errors::NotFound;
use util::{RequestUtils, CargoResult, internal, ChainError, human};
Expand Down Expand Up @@ -377,8 +378,7 @@ pub fn show_team(req: &mut Request) -> CargoResult<Response> {

/// Handles the `GET /me/updates` route.
pub fn updates(req: &mut Request) -> CargoResult<Response> {
use diesel::expression::dsl::{any, sql};
use diesel::types::BigInt;
use diesel::expression::dsl::any;

let user = req.user()?;
let (offset, limit) = req.pagination(10, 100)?;
Expand All @@ -389,21 +389,16 @@ pub fn updates(req: &mut Request) -> CargoResult<Response> {
.inner_join(crates::table)
.filter(crates::id.eq(any(followed_crates)))
.order(versions::created_at.desc())
.limit(limit)
.offset(offset)
.select((
versions::all_columns,
crates::name,
sql::<BigInt>("COUNT(*) OVER ()"),
))
.load::<(Version, String, i64)>(&*conn)?;
.select((versions::all_columns, crates::name))
.paginate(limit, offset)
.load::<((Version, String), i64)>(&*conn)?;

let more = data.get(0)
.map(|&(_, _, count)| count > offset + limit)
.map(|&(_, count)| count > offset + limit)
.unwrap_or(false);

let versions = data.into_iter()
.map(|(version, crate_name, _)| version.encodable(&crate_name))
.map(|((version, crate_name), _)| version.encodable(&crate_name))
.collect();

#[derive(RustcEncodable)]
Expand Down