From 516893f1f73830798e978ffc055dd78dd6049c7e Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Tue, 27 Sep 2022 05:03:23 +0200 Subject: [PATCH] server: Fix query building of chained ands/ors Fixes #303 --- server/src/domain/sql_backend_handler.rs | 114 +++++++++++++---------- 1 file changed, 64 insertions(+), 50 deletions(-) diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index 566a13d..843315c 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -2,7 +2,7 @@ use super::{error::*, handler::*, sql_tables::*}; use crate::infra::configuration::Configuration; use async_trait::async_trait; use futures_util::StreamExt; -use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr}; +use sea_query::{Alias, Cond, Expr, Iden, Order, Query}; use sea_query_binder::SqlxBinder; use sqlx::{query_as_with, query_with, FromRow, Row}; use std::collections::HashSet; @@ -23,90 +23,89 @@ impl SqlBackendHandler { struct RequiresGroup(bool); // Returns the condition for the SQL query, and whether it requires joining with the groups table. -fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr) { +fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) { + use sea_query::IntoCondition; use UserRequestFilter::*; - fn get_repeated_filter( - fs: Vec, - field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, - ) -> (RequiresGroup, SimpleExpr) { + fn get_repeated_filter(fs: Vec, condition: Cond) -> (RequiresGroup, Cond) { let mut requires_group = false; - let mut it = fs.into_iter(); - let first_expr = match it.next() { - None => return (RequiresGroup(false), Expr::value(true)), - Some(f) => { - let (group, filter) = get_user_filter_expr(f); - requires_group |= group.0; - filter - } - }; - let filter = it.fold(first_expr, |e, f| { + let filter = fs.into_iter().fold(condition, |c, f| { let (group, filters) = get_user_filter_expr(f); requires_group |= group.0; - field(e, filters) + c.add(filters) }); (RequiresGroup(requires_group), filter) } match filter { - And(fs) => get_repeated_filter(fs, &SimpleExpr::and), - Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), + And(fs) => get_repeated_filter(fs, Cond::all()), + Or(fs) => get_repeated_filter(fs, Cond::any()), Not(f) => { let (requires_group, filters) = get_user_filter_expr(*f); - (requires_group, Expr::not(Expr::expr(filters))) + (requires_group, filters.not()) } UserId(user_id) => ( RequiresGroup(false), - Expr::col((Users::Table, Users::UserId)).eq(user_id), + Expr::col((Users::Table, Users::UserId)) + .eq(user_id) + .into_condition(), ), Equality(s1, s2) => ( RequiresGroup(false), if s1 == Users::DisplayName.to_string() { - Expr::col((Users::Table, Users::DisplayName)).eq(s2) + Expr::col((Users::Table, Users::DisplayName)) + .eq(s2) + .into_condition() } else if s1 == Users::UserId.to_string() { panic!("User id should be wrapped") } else { - Expr::expr(Expr::cust(&s1)).eq(s2) + Expr::expr(Expr::cust(&s1)).eq(s2).into_condition() }, ), MemberOf(group) => ( RequiresGroup(true), - Expr::col((Groups::Table, Groups::DisplayName)).eq(group), + Expr::col((Groups::Table, Groups::DisplayName)) + .eq(group) + .into_condition(), ), MemberOfId(group_id) => ( RequiresGroup(true), - Expr::col((Groups::Table, Groups::GroupId)).eq(group_id), + Expr::col((Groups::Table, Groups::GroupId)) + .eq(group_id) + .into_condition(), ), } } // Returns the condition for the SQL query, and whether it requires joining with the groups table. -fn get_group_filter_expr(filter: GroupRequestFilter) -> SimpleExpr { +fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { + use sea_query::IntoCondition; use GroupRequestFilter::*; - fn get_repeated_filter( - fs: Vec, - field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, - ) -> SimpleExpr { - let mut it = fs.into_iter(); - let first_expr = match it.next() { - None => return Expr::value(true), - Some(f) => get_group_filter_expr(f), - }; - it.fold(first_expr, |e, f| field(e, get_group_filter_expr(f))) - } match filter { - And(fs) => get_repeated_filter(fs, &SimpleExpr::and), - Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), - Not(f) => Expr::not(Expr::expr(get_group_filter_expr(*f))), - DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)).eq(name), - GroupId(id) => Expr::col((Groups::Table, Groups::GroupId)).eq(id.0), - Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid)).eq(uuid.to_string()), + And(fs) => fs + .into_iter() + .fold(Cond::all(), |c, f| c.add(get_group_filter_expr(f))), + Or(fs) => fs + .into_iter() + .fold(Cond::any(), |c, f| c.add(get_group_filter_expr(f))), + Not(f) => get_group_filter_expr(*f).not(), + DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)) + .eq(name) + .into_condition(), + GroupId(id) => Expr::col((Groups::Table, Groups::GroupId)) + .eq(id.0) + .into_condition(), + Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid)) + .eq(uuid.to_string()) + .into_condition(), // WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user)) - Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)).in_subquery( - Query::select() - .column(Memberships::GroupId) - .from(Memberships::Table) - .cond_where(Expr::col(Memberships::UserId).eq(user)) - .take(), - ), + Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)) + .in_subquery( + Query::select() + .column(Memberships::GroupId) + .from(Memberships::Table) + .cond_where(Expr::col(Memberships::UserId).eq(user)) + .take(), + ) + .into_condition(), } } @@ -697,6 +696,21 @@ mod tests { .await; assert_eq!(users, vec!["bob", "john"]); } + { + let users = get_user_names( + &handler, + Some(UserRequestFilter::And(vec![ + UserRequestFilter::Or(vec![]), + UserRequestFilter::Or(vec![ + UserRequestFilter::UserId(UserId::new("bob")), + UserRequestFilter::UserId(UserId::new("John")), + UserRequestFilter::UserId(UserId::new("random")), + ]), + ])), + ) + .await; + assert_eq!(users, vec!["bob", "john"]); + } { let users = get_user_names( &handler,