diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 4f7b902..d7a5b68 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -272,26 +272,33 @@ pub struct UserAndGroups { } #[async_trait] -pub trait BackendHandler: Clone + Send { +pub trait GroupBackendHandler { + async fn list_groups(&self, filters: Option) -> Result>; + async fn get_group_details(&self, group_id: GroupId) -> Result; + async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; + async fn create_group(&self, group_name: &str) -> Result; + async fn delete_group(&self, group_id: GroupId) -> Result<()>; +} + +#[async_trait] +pub trait UserBackendHandler { async fn list_users( &self, filters: Option, get_groups: bool, ) -> Result>; - async fn list_groups(&self, filters: Option) -> Result>; async fn get_user_details(&self, user_id: &UserId) -> Result; - async fn get_group_details(&self, group_id: GroupId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; - async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn delete_user(&self, user_id: &UserId) -> Result<()>; - async fn create_group(&self, group_name: &str) -> Result; - async fn delete_group(&self, group_id: GroupId) -> Result<()>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; async fn get_user_groups(&self, user_id: &UserId) -> Result>; } +#[async_trait] +pub trait BackendHandler: Clone + Send + GroupBackendHandler + UserBackendHandler {} + #[cfg(test)] mockall::mock! { pub TestBackendHandler{} @@ -299,22 +306,27 @@ mockall::mock! { fn clone(&self) -> Self; } #[async_trait] - impl BackendHandler for TestBackendHandler { - async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + impl GroupBackendHandler for TestBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; - async fn create_user(&self, request: CreateUserRequest) -> Result<()>; - async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; + } + #[async_trait] + impl UserBackendHandler for TestBackendHandler { + async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + async fn get_user_details(&self, user_id: &UserId) -> Result; + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; + async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn get_user_groups(&self, user_id: &UserId) -> Result>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; } #[async_trait] + impl BackendHandler for TestBackendHandler {} + #[async_trait] impl LoginHandler for TestBackendHandler { async fn bind(&self, request: BindRequest) -> Result<()>; } diff --git a/server/src/domain/mod.rs b/server/src/domain/mod.rs index b533c9f..b139c59 100644 --- a/server/src/domain/mod.rs +++ b/server/src/domain/mod.rs @@ -3,6 +3,8 @@ pub mod handler; pub mod ldap; pub mod opaque_handler; pub mod sql_backend_handler; +pub mod sql_group_backend_handler; pub mod sql_migrations; pub mod sql_opaque_handler; pub mod sql_tables; +pub mod sql_user_backend_handler; diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index fde2b13..0b8478d 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -1,14 +1,8 @@ -use super::{error::*, handler::*, sql_tables::*}; +use super::{handler::*, sql_tables::*}; use crate::infra::configuration::Configuration; use async_trait::async_trait; -use futures_util::StreamExt; -use sea_query::{Alias, Cond, Expr, Iden, Order, Query}; -use sea_query_binder::SqlxBinder; -use sqlx::{query_as_with, query_with, FromRow, Row}; -use std::collections::HashSet; -use tracing::{debug, instrument}; -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SqlBackendHandler { pub(crate) config: Configuration, pub(crate) sql_pool: Pool, @@ -20,536 +14,31 @@ impl SqlBackendHandler { } } -struct RequiresGroup(bool); - -// Returns the condition for the SQL query, and whether it requires joining with the groups table. -fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) { - use sea_query::IntoCondition; - use UserRequestFilter::*; - fn get_repeated_filter(fs: Vec, condition: Cond) -> (RequiresGroup, Cond) { - let mut requires_group = false; - let filter = fs.into_iter().fold(condition, |c, f| { - let (group, filters) = get_user_filter_expr(f); - requires_group |= group.0; - c.add(filters) - }); - (RequiresGroup(requires_group), filter) - } - match filter { - And(fs) => get_repeated_filter(fs, Cond::all()), - Or(fs) => get_repeated_filter(fs, Cond::any()), - Not(f) => { - let (requires_group, filters) = get_user_filter_expr(*f); - (requires_group, filters.not()) - } - UserId(user_id) => ( - RequiresGroup(false), - Expr::col((Users::Table, Users::UserId)) - .eq(user_id) - .into_condition(), - ), - Equality(s1, s2) => ( - RequiresGroup(false), - if s1 == Users::DisplayName.to_string() { - Expr::col((Users::Table, Users::DisplayName)) - .eq(s2) - .into_condition() - } else if s1 == Users::UserId.to_string() { - panic!("User id should be wrapped") - } else { - Expr::expr(Expr::cust(&s1)).eq(s2).into_condition() - }, - ), - MemberOf(group) => ( - RequiresGroup(true), - Expr::col((Groups::Table, Groups::DisplayName)) - .eq(group) - .into_condition(), - ), - MemberOfId(group_id) => ( - RequiresGroup(true), - Expr::col((Groups::Table, Groups::GroupId)) - .eq(group_id) - .into_condition(), - ), - } -} - -// Returns the condition for the SQL query, and whether it requires joining with the groups table. -fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { - use sea_query::IntoCondition; - use GroupRequestFilter::*; - match filter { - And(fs) => fs - .into_iter() - .fold(Cond::all(), |c, f| c.add(get_group_filter_expr(f))), - Or(fs) => fs - .into_iter() - .fold(Cond::any(), |c, f| c.add(get_group_filter_expr(f))), - Not(f) => get_group_filter_expr(*f).not(), - DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)) - .eq(name) - .into_condition(), - GroupId(id) => Expr::col((Groups::Table, Groups::GroupId)) - .eq(id.0) - .into_condition(), - Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid)) - .eq(uuid.to_string()) - .into_condition(), - // WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user)) - Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)) - .in_subquery( - Query::select() - .column(Memberships::GroupId) - .from(Memberships::Table) - .cond_where(Expr::col(Memberships::UserId).eq(user)) - .take(), - ) - .into_condition(), - } -} - #[async_trait] -impl BackendHandler for SqlBackendHandler { - #[instrument(skip_all, level = "debug", ret, err)] - async fn list_users( - &self, - filters: Option, - get_groups: bool, - ) -> Result> { - debug!(?filters, get_groups); - let (query, values) = { - let mut query_builder = Query::select() - .column((Users::Table, Users::UserId)) - .column(Users::Email) - .column((Users::Table, Users::DisplayName)) - .column(Users::FirstName) - .column(Users::LastName) - .column(Users::Avatar) - .column((Users::Table, Users::CreationDate)) - .column((Users::Table, Users::Uuid)) - .from(Users::Table) - .order_by((Users::Table, Users::UserId), Order::Asc) - .to_owned(); - let add_join_group_tables = |builder: &mut sea_query::SelectStatement| { - builder - .left_join( - Memberships::Table, - Expr::tbl(Users::Table, Users::UserId) - .equals(Memberships::Table, Memberships::UserId), - ) - .left_join( - Groups::Table, - Expr::tbl(Memberships::Table, Memberships::GroupId) - .equals(Groups::Table, Groups::GroupId), - ); - }; - if get_groups { - add_join_group_tables(&mut query_builder); - query_builder - .column((Groups::Table, Groups::GroupId)) - .expr_as( - Expr::col((Groups::Table, Groups::DisplayName)), - Alias::new("group_display_name"), - ) - .expr_as( - Expr::col((Groups::Table, Groups::CreationDate)), - sea_query::Alias::new("group_creation_date"), - ) - .expr_as( - Expr::col((Groups::Table, Groups::Uuid)), - sea_query::Alias::new("group_uuid"), - ) - .order_by(Alias::new("group_display_name"), Order::Asc); - } - if let Some(filter) = filters { - if filter == UserRequestFilter::Not(Box::new(UserRequestFilter::And(Vec::new()))) { - return Ok(Vec::new()); - } - if filter != UserRequestFilter::And(Vec::new()) - && filter != UserRequestFilter::Or(Vec::new()) - { - let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter); - query_builder.cond_where(condition); - if requires_group && !get_groups { - add_join_group_tables(&mut query_builder); - } - } - } - - query_builder.build_sqlx(DbQueryBuilder {}) - }; - - debug!(%query); - - // For group_by. - use itertools::Itertools; - let mut users = Vec::new(); - // The rows are returned sorted by user_id. We group them by - // this key which gives us one element (`rows`) per group. - for (_, rows) in &query_with(&query, values) - .fetch_all(&self.sql_pool) - .await? - .into_iter() - .group_by(|row| row.get::(&*Users::UserId.to_string())) - { - let mut rows = rows.peekable(); - users.push(UserAndGroups { - user: User::from_row(rows.peek().unwrap()).unwrap(), - groups: if get_groups { - Some( - rows.filter_map(|row| { - let display_name = row.get::("group_display_name"); - if display_name.is_empty() { - None - } else { - Some(GroupDetails { - group_id: row.get::(&*Groups::GroupId.to_string()), - display_name, - creation_date: row.get::, _>( - "group_creation_date", - ), - uuid: row.get::("group_uuid"), - }) - } - }) - .collect(), - ) - } else { - None - }, - }); - } - Ok(users) - } - - #[instrument(skip_all, level = "debug", ret, err)] - async fn list_groups(&self, filters: Option) -> Result> { - debug!(?filters); - let (query, values) = { - let mut query_builder = Query::select() - .column((Groups::Table, Groups::GroupId)) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .column(Groups::Uuid) - .column(Memberships::UserId) - .from(Groups::Table) - .left_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .order_by(Groups::DisplayName, Order::Asc) - .order_by(Memberships::UserId, Order::Asc) - .to_owned(); - - if let Some(filter) = filters { - if filter == GroupRequestFilter::Not(Box::new(GroupRequestFilter::And(Vec::new()))) - { - return Ok(Vec::new()); - } - if filter != GroupRequestFilter::And(Vec::new()) - && filter != GroupRequestFilter::Or(Vec::new()) - { - query_builder.cond_where(get_group_filter_expr(filter)); - } - } - - query_builder.build_sqlx(DbQueryBuilder {}) - }; - debug!(%query); - - // For group_by. - use itertools::Itertools; - let mut groups = Vec::new(); - // The rows are returned sorted by display_name, equivalent to group_id. We group them by - // this key which gives us one element (`rows`) per group. - for (group_details, rows) in &query_with(&query, values) - .fetch_all(&self.sql_pool) - .await? - .into_iter() - .group_by(|row| GroupDetails::from_row(row).unwrap()) - { - groups.push(Group { - id: group_details.group_id, - display_name: group_details.display_name, - creation_date: group_details.creation_date, - uuid: group_details.uuid, - users: rows - .map(|row| row.get::(&*Memberships::UserId.to_string())) - // If a group has no users, an empty string is returned because of the left - // join. - .filter(|s| !s.as_str().is_empty()) - .collect(), - }); - } - Ok(groups) - } - - #[instrument(skip_all, level = "debug", ret)] - async fn get_user_details(&self, user_id: &UserId) -> Result { - debug!(?user_id); - let (query, values) = Query::select() - .column(Users::UserId) - .column(Users::Email) - .column(Users::DisplayName) - .column(Users::FirstName) - .column(Users::LastName) - .column(Users::Avatar) - .column(Users::CreationDate) - .column(Users::Uuid) - .from(Users::Table) - .cond_where(Expr::col(Users::UserId).eq(user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - - Ok(query_as_with::<_, User, _>(query.as_str(), values) - .fetch_one(&self.sql_pool) - .await?) - } - - #[instrument(skip_all, level = "debug", ret, err)] - async fn get_group_details(&self, group_id: GroupId) -> Result { - debug!(?group_id); - let (query, values) = Query::select() - .column(Groups::GroupId) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .column(Groups::Uuid) - .from(Groups::Table) - .cond_where(Expr::col(Groups::GroupId).eq(group_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - - Ok(query_as_with::<_, GroupDetails, _>(&query, values) - .fetch_one(&self.sql_pool) - .await?) - } - - #[instrument(skip_all, level = "debug", ret, err)] - async fn get_user_groups(&self, user_id: &UserId) -> Result> { - debug!(?user_id); - let (query, values) = Query::select() - .column((Groups::Table, Groups::GroupId)) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .column(Groups::Uuid) - .from(Groups::Table) - .inner_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .cond_where(Expr::col(Memberships::UserId).eq(user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - - query_as_with::<_, GroupDetails, _>(&query, values) - .fetch(&self.sql_pool) - // Collect the vector of rows, each potentially an error. - .collect::>>() - .await - .into_iter() - // Transform it into a single result (the first error if any), and group the group_ids - // into a HashSet. - .collect::>>() - // Map the sqlx::Error into a DomainError. - .map_err(DomainError::DatabaseError) - } - - #[instrument(skip_all, level = "debug", err)] - async fn create_user(&self, request: CreateUserRequest) -> Result<()> { - debug!(user_id = ?request.user_id); - let columns = vec![ - Users::UserId, - Users::Email, - Users::DisplayName, - Users::FirstName, - Users::LastName, - Users::Avatar, - Users::CreationDate, - Users::Uuid, - ]; - let now = chrono::Utc::now(); - let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now); - let values = vec![ - request.user_id.into(), - request.email.into(), - request.display_name.unwrap_or_default().into(), - request.first_name.unwrap_or_default().into(), - request.last_name.unwrap_or_default().into(), - request.avatar.unwrap_or_default().into(), - now.naive_utc().into(), - uuid.into(), - ]; - let (query, values) = Query::insert() - .into_table(Users::Table) - .columns(columns) - .values_panic(values) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } - - #[instrument(skip_all, level = "debug", err)] - async fn update_user(&self, request: UpdateUserRequest) -> Result<()> { - debug!(user_id = ?request.user_id); - let mut values = Vec::new(); - if let Some(email) = request.email { - values.push((Users::Email, email.into())); - } - if let Some(display_name) = request.display_name { - values.push((Users::DisplayName, display_name.into())); - } - if let Some(first_name) = request.first_name { - values.push((Users::FirstName, first_name.into())); - } - if let Some(last_name) = request.last_name { - values.push((Users::LastName, last_name.into())); - } - if let Some(avatar) = request.avatar { - values.push((Users::Avatar, avatar.into())); - } - if values.is_empty() { - return Ok(()); - } - let (query, values) = Query::update() - .table(Users::Table) - .values(values) - .cond_where(Expr::col(Users::UserId).eq(request.user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } - - #[instrument(skip_all, level = "debug", err)] - async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> { - debug!(?request.group_id); - let mut values = Vec::new(); - if let Some(display_name) = request.display_name { - values.push((Groups::DisplayName, display_name.into())); - } - if values.is_empty() { - return Ok(()); - } - let (query, values) = Query::update() - .table(Groups::Table) - .values(values) - .cond_where(Expr::col(Groups::GroupId).eq(request.group_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } - - #[instrument(skip_all, level = "debug", err)] - async fn delete_user(&self, user_id: &UserId) -> Result<()> { - debug!(?user_id); - let (query, values) = Query::delete() - .from_table(Users::Table) - .cond_where(Expr::col(Users::UserId).eq(user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } - - #[instrument(skip_all, level = "debug", ret, err)] - async fn create_group(&self, group_name: &str) -> Result { - debug!(?group_name); - crate::domain::sql_tables::create_group(group_name, &self.sql_pool).await?; - let (query, values) = Query::select() - .column(Groups::GroupId) - .from(Groups::Table) - .cond_where(Expr::col(Groups::DisplayName).eq(group_name)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - let row = query_with(query.as_str(), values) - .fetch_one(&self.sql_pool) - .await?; - Ok(GroupId(row.get::(&*Groups::GroupId.to_string()))) - } - - #[instrument(skip_all, level = "debug", err)] - async fn delete_group(&self, group_id: GroupId) -> Result<()> { - debug!(?group_id); - let (query, values) = Query::delete() - .from_table(Groups::Table) - .cond_where(Expr::col(Groups::GroupId).eq(group_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } - - #[instrument(skip_all, level = "debug", err)] - async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { - debug!(?user_id, ?group_id); - let (query, values) = Query::insert() - .into_table(Memberships::Table) - .columns(vec![Memberships::UserId, Memberships::GroupId]) - .values_panic(vec![user_id.into(), group_id.into()]) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } - - #[instrument(skip_all, level = "debug", err)] - async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { - debug!(?user_id, ?group_id); - let (query, values) = Query::delete() - .from_table(Memberships::Table) - .cond_where( - Cond::all() - .add(Expr::col(Memberships::GroupId).eq(group_id)) - .add(Expr::col(Memberships::UserId).eq(user_id)), - ) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; - Ok(()) - } -} +impl BackendHandler for SqlBackendHandler {} #[cfg(test)] -mod tests { +pub mod tests { use super::*; use crate::domain::sql_tables::init_table; use crate::infra::configuration::ConfigurationBuilder; use lldap_auth::{opaque, registration}; - fn get_default_config() -> Configuration { + pub fn get_default_config() -> Configuration { ConfigurationBuilder::for_tests() } - async fn get_in_memory_db() -> Pool { + pub async fn get_in_memory_db() -> Pool { PoolOptions::new().connect("sqlite::memory:").await.unwrap() } - async fn get_initialized_db() -> Pool { + pub async fn get_initialized_db() -> Pool { let sql_pool = get_in_memory_db().await; init_table(&sql_pool).await.unwrap(); sql_pool } - async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) { + pub async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) { use crate::domain::opaque_handler::OpaqueHandler; insert_user_no_password(handler, name).await; let mut rng = rand::rngs::OsRng; @@ -577,29 +66,32 @@ mod tests { .unwrap(); } - async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { + pub async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { handler .create_user(CreateUserRequest { user_id: UserId::new(name), email: "bob@bob.bob".to_string(), + display_name: Some("display ".to_string() + name), + first_name: Some("first ".to_string() + name), + last_name: Some("last ".to_string() + name), ..Default::default() }) .await .unwrap(); } - async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId { + pub async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId { handler.create_group(name).await.unwrap() } - async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) { + pub async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) { handler .add_user_to_group(&UserId::new(user_id), group_id) .await .unwrap(); } - async fn get_user_names( + pub async fn get_user_names( handler: &SqlBackendHandler, filters: Option, ) -> Vec { @@ -612,299 +104,30 @@ mod tests { .collect::>() } - #[tokio::test] - async fn test_bind_user() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool.clone()); - insert_user(&handler, "bob", "bob00").await; - - handler - .bind(BindRequest { - name: UserId::new("bob"), - password: "bob00".to_string(), - }) - .await - .unwrap(); - handler - .bind(BindRequest { - name: UserId::new("andrew"), - password: "bob00".to_string(), - }) - .await - .unwrap_err(); - handler - .bind(BindRequest { - name: UserId::new("bob"), - password: "wrong_password".to_string(), - }) - .await - .unwrap_err(); + pub struct TestFixture { + pub handler: SqlBackendHandler, + pub groups: Vec, } - #[tokio::test] - async fn test_user_no_password() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool.clone()); - insert_user_no_password(&handler, "bob").await; - - handler - .bind(BindRequest { - name: UserId::new("bob"), - password: "bob00".to_string(), - }) - .await - .unwrap_err(); - } - - #[tokio::test] - async fn test_list_users() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool); - insert_user(&handler, "bob", "bob00").await; - insert_user(&handler, "patrick", "pass").await; - insert_user(&handler, "John", "Pa33w0rd!").await; - insert_user(&handler, "NoGroup", "Pa33w0rd!").await; - let group_1 = insert_group(&handler, "Best Group").await; - let group_2 = insert_group(&handler, "Worst Group").await; - insert_membership(&handler, group_1, "bob").await; - insert_membership(&handler, group_1, "patrick").await; - insert_membership(&handler, group_2, "patrick").await; - insert_membership(&handler, group_2, "John").await; - { - let users = get_user_names(&handler, None).await; - assert_eq!(users, vec!["bob", "john", "nogroup", "patrick"]); + impl TestFixture { + pub async fn new() -> Self { + let sql_pool = get_initialized_db().await; + let config = get_default_config(); + let handler = SqlBackendHandler::new(config, sql_pool); + insert_user_no_password(&handler, "bob").await; + insert_user_no_password(&handler, "patrick").await; + insert_user_no_password(&handler, "John").await; + insert_user_no_password(&handler, "NoGroup").await; + let mut groups = vec![]; + groups.push(insert_group(&handler, "Best Group").await); + groups.push(insert_group(&handler, "Worst Group").await); + groups.push(insert_group(&handler, "Empty Group").await); + insert_membership(&handler, groups[0], "bob").await; + insert_membership(&handler, groups[0], "patrick").await; + insert_membership(&handler, groups[1], "patrick").await; + insert_membership(&handler, groups[1], "John").await; + Self { handler, groups } } - { - let users = get_user_names( - &handler, - Some(UserRequestFilter::UserId(UserId::new("bob"))), - ) - .await; - assert_eq!(users, vec!["bob"]); - } - { - let users = get_user_names( - &handler, - Some(UserRequestFilter::Or(vec![ - UserRequestFilter::UserId(UserId::new("bob")), - UserRequestFilter::UserId(UserId::new("John")), - ])), - ) - .await; - assert_eq!(users, vec!["bob", "john"]); - } - { - let users = get_user_names( - &handler, - Some(UserRequestFilter::And(vec![ - UserRequestFilter::Or(vec![]), - UserRequestFilter::Or(vec![ - UserRequestFilter::UserId(UserId::new("bob")), - UserRequestFilter::UserId(UserId::new("John")), - UserRequestFilter::UserId(UserId::new("random")), - ]), - ])), - ) - .await; - assert_eq!(users, vec!["bob", "john"]); - } - { - let users = get_user_names( - &handler, - Some(UserRequestFilter::Not(Box::new(UserRequestFilter::UserId( - UserId::new("bob"), - )))), - ) - .await; - assert_eq!(users, vec!["john", "nogroup", "patrick"]); - } - { - let users = handler - .list_users(None, true) - .await - .unwrap() - .into_iter() - .map(|u| { - ( - u.user.user_id.to_string(), - u.user.display_name.to_string(), - u.groups - .unwrap() - .into_iter() - .map(|g| g.group_id) - .collect::>(), - ) - }) - .collect::>(); - assert_eq!( - users, - vec![ - ("bob".to_string(), String::new(), vec![group_1]), - ("john".to_string(), String::new(), vec![group_2]), - ("nogroup".to_string(), String::new(), vec![]), - ("patrick".to_string(), String::new(), vec![group_1, group_2]), - ] - ); - } - { - let users = handler - .list_users(None, true) - .await - .unwrap() - .into_iter() - .map(|u| { - ( - u.user.creation_date, - u.groups - .unwrap() - .into_iter() - .map(|g| g.creation_date) - .collect::>(), - ) - }) - .collect::>(); - for (user_date, groups) in users { - for group_date in groups { - assert_ne!(user_date, group_date); - } - } - } - } - - #[tokio::test] - async fn test_list_groups() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool.clone()); - insert_user(&handler, "bob", "bob00").await; - insert_user(&handler, "patrick", "pass").await; - insert_user(&handler, "John", "Pa33w0rd!").await; - let group_1 = insert_group(&handler, "Best Group").await; - let group_2 = insert_group(&handler, "Worst Group").await; - let group_3 = insert_group(&handler, "Empty Group").await; - insert_membership(&handler, group_1, "bob").await; - insert_membership(&handler, group_1, "patrick").await; - insert_membership(&handler, group_2, "patrick").await; - insert_membership(&handler, group_2, "John").await; - let get_group_ids = |filter| async { - handler - .list_groups(filter) - .await - .unwrap() - .into_iter() - .map(|g| g.id) - .collect::>() - }; - assert_eq!(get_group_ids(None).await, vec![group_1, group_3, group_2]); - assert_eq!( - get_group_ids(Some(GroupRequestFilter::Or(vec![ - GroupRequestFilter::DisplayName("Empty Group".to_string()), - GroupRequestFilter::Member(UserId::new("bob")), - ]))) - .await, - vec![group_1, group_3] - ); - assert_eq!( - get_group_ids(Some(GroupRequestFilter::And(vec![ - GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( - "value".to_string() - ))), - GroupRequestFilter::GroupId(group_1), - ]))) - .await, - vec![group_1] - ); - } - - #[tokio::test] - async fn test_get_user_details() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool); - insert_user(&handler, "bob", "bob00").await; - { - let user = handler.get_user_details(&UserId::new("bob")).await.unwrap(); - assert_eq!(user.user_id.as_str(), "bob"); - } - { - handler - .get_user_details(&UserId::new("John")) - .await - .unwrap_err(); - } - } - - #[tokio::test] - async fn test_user_lowercase() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool); - insert_user(&handler, "Bob", "bob00").await; - { - let user = handler.get_user_details(&UserId::new("bOb")).await.unwrap(); - assert_eq!(user.user_id.as_str(), "bob"); - } - { - handler - .get_user_details(&UserId::new("John")) - .await - .unwrap_err(); - } - } - - #[tokio::test] - async fn test_get_user_groups() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool.clone()); - insert_user(&handler, "bob", "bob00").await; - insert_user(&handler, "patrick", "pass").await; - insert_user(&handler, "John", "Pa33w0rd!").await; - let group_1 = insert_group(&handler, "Group1").await; - let group_2 = insert_group(&handler, "Group2").await; - insert_membership(&handler, group_1, "bob").await; - insert_membership(&handler, group_1, "patrick").await; - insert_membership(&handler, group_2, "patrick").await; - let get_group_ids = |user: &'static str| async { - let mut groups = handler - .get_user_groups(&UserId::new(user)) - .await - .unwrap() - .into_iter() - .map(|g| g.group_id) - .collect::>(); - groups.sort_by(|g1, g2| g1.0.cmp(&g2.0)); - groups - }; - assert_eq!(get_group_ids("bob").await, vec![group_1]); - assert_eq!(get_group_ids("patrick").await, vec![group_1, group_2]); - assert_eq!(get_group_ids("John").await, vec![]); - } - - #[tokio::test] - async fn test_delete_user() { - let sql_pool = get_initialized_db().await; - let config = get_default_config(); - let handler = SqlBackendHandler::new(config, sql_pool.clone()); - - insert_user(&handler, "val", "s3np4i").await; - insert_user(&handler, "Hector", "Be$t").await; - insert_user(&handler, "Jennz", "boupBoup").await; - - // Remove a user - handler.delete_user(&UserId::new("Jennz")).await.unwrap(); - - assert_eq!(get_user_names(&handler, None).await, vec!["hector", "val"]); - - // Insert new user and remove two - insert_user(&handler, "NewBoi", "Joni").await; - handler.delete_user(&UserId::new("Hector")).await.unwrap(); - handler.delete_user(&UserId::new("NewBoi")).await.unwrap(); - - assert_eq!(get_user_names(&handler, None).await, vec!["val"]); } #[tokio::test] @@ -913,7 +136,7 @@ mod tests { let config = get_default_config(); let handler = SqlBackendHandler::new(config, sql_pool); let user_name = UserId::new(r#"bob"e"i'o;aĆ¼"#); - insert_user(&handler, user_name.as_str(), "bob00").await; + insert_user_no_password(&handler, user_name.as_str()).await; { let users = handler .list_users(None, false) diff --git a/server/src/domain/sql_group_backend_handler.rs b/server/src/domain/sql_group_backend_handler.rs new file mode 100644 index 0000000..cec1b90 --- /dev/null +++ b/server/src/domain/sql_group_backend_handler.rs @@ -0,0 +1,301 @@ +use super::{ + error::Result, + handler::{ + Group, GroupBackendHandler, GroupDetails, GroupId, GroupRequestFilter, UpdateGroupRequest, + UserId, + }, + sql_backend_handler::SqlBackendHandler, + sql_tables::{DbQueryBuilder, Groups, Memberships}, +}; +use async_trait::async_trait; +use sea_query::{Cond, Expr, Iden, Order, Query, SimpleExpr}; +use sea_query_binder::SqlxBinder; +use sqlx::{query_as_with, query_with, FromRow, Row}; +use tracing::{debug, instrument}; + +// Returns the condition for the SQL query, and whether it requires joining with the groups table. +fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { + use sea_query::IntoCondition; + use GroupRequestFilter::*; + match filter { + And(fs) => { + if fs.is_empty() { + SimpleExpr::Value(true.into()).into_condition() + } else { + fs.into_iter() + .fold(Cond::all(), |c, f| c.add(get_group_filter_expr(f))) + } + } + Or(fs) => { + if fs.is_empty() { + SimpleExpr::Value(false.into()).into_condition() + } else { + fs.into_iter() + .fold(Cond::any(), |c, f| c.add(get_group_filter_expr(f))) + } + } + Not(f) => get_group_filter_expr(*f).not(), + DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)) + .eq(name) + .into_condition(), + GroupId(id) => Expr::col((Groups::Table, Groups::GroupId)) + .eq(id.0) + .into_condition(), + Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid)) + .eq(uuid.to_string()) + .into_condition(), + // WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user)) + Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)) + .in_subquery( + Query::select() + .column(Memberships::GroupId) + .from(Memberships::Table) + .cond_where(Expr::col(Memberships::UserId).eq(user)) + .take(), + ) + .into_condition(), + } +} + +#[async_trait] +impl GroupBackendHandler for SqlBackendHandler { + #[instrument(skip_all, level = "debug", ret, err)] + async fn list_groups(&self, filters: Option) -> Result> { + debug!(?filters); + let (query, values) = { + let mut query_builder = Query::select() + .column((Groups::Table, Groups::GroupId)) + .column(Groups::DisplayName) + .column(Groups::CreationDate) + .column(Groups::Uuid) + .column(Memberships::UserId) + .from(Groups::Table) + .left_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .order_by(Groups::DisplayName, Order::Asc) + .order_by(Memberships::UserId, Order::Asc) + .to_owned(); + + if let Some(filter) = filters { + query_builder.cond_where(get_group_filter_expr(filter)); + } + + query_builder.build_sqlx(DbQueryBuilder {}) + }; + debug!(%query); + + // For group_by. + use itertools::Itertools; + let mut groups = Vec::new(); + // The rows are returned sorted by display_name, equivalent to group_id. We group them by + // this key which gives us one element (`rows`) per group. + for (group_details, rows) in &query_with(&query, values) + .fetch_all(&self.sql_pool) + .await? + .into_iter() + .group_by(|row| GroupDetails::from_row(row).unwrap()) + { + groups.push(Group { + id: group_details.group_id, + display_name: group_details.display_name, + creation_date: group_details.creation_date, + uuid: group_details.uuid, + users: rows + .map(|row| row.get::(&*Memberships::UserId.to_string())) + // If a group has no users, an empty string is returned because of the left + // join. + .filter(|s| !s.as_str().is_empty()) + .collect(), + }); + } + Ok(groups) + } + + #[instrument(skip_all, level = "debug", ret, err)] + async fn get_group_details(&self, group_id: GroupId) -> Result { + debug!(?group_id); + let (query, values) = Query::select() + .column(Groups::GroupId) + .column(Groups::DisplayName) + .column(Groups::CreationDate) + .column(Groups::Uuid) + .from(Groups::Table) + .cond_where(Expr::col(Groups::GroupId).eq(group_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + + Ok(query_as_with::<_, GroupDetails, _>(&query, values) + .fetch_one(&self.sql_pool) + .await?) + } + + #[instrument(skip_all, level = "debug", err)] + async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> { + debug!(?request.group_id); + let mut values = Vec::new(); + if let Some(display_name) = request.display_name { + values.push((Groups::DisplayName, display_name.into())); + } + if values.is_empty() { + return Ok(()); + } + let (query, values) = Query::update() + .table(Groups::Table) + .values(values) + .cond_where(Expr::col(Groups::GroupId).eq(request.group_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } + + #[instrument(skip_all, level = "debug", ret, err)] + async fn create_group(&self, group_name: &str) -> Result { + debug!(?group_name); + crate::domain::sql_tables::create_group(group_name, &self.sql_pool).await?; + let (query, values) = Query::select() + .column(Groups::GroupId) + .from(Groups::Table) + .cond_where(Expr::col(Groups::DisplayName).eq(group_name)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + let row = query_with(query.as_str(), values) + .fetch_one(&self.sql_pool) + .await?; + Ok(GroupId(row.get::(&*Groups::GroupId.to_string()))) + } + + #[instrument(skip_all, level = "debug", err)] + async fn delete_group(&self, group_id: GroupId) -> Result<()> { + debug!(?group_id); + let (query, values) = Query::delete() + .from_table(Groups::Table) + .cond_where(Expr::col(Groups::GroupId).eq(group_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::sql_backend_handler::tests::*; + + async fn get_group_ids( + handler: &SqlBackendHandler, + filters: Option, + ) -> Vec { + handler + .list_groups(filters) + .await + .unwrap() + .into_iter() + .map(|g| g.id) + .collect::>() + } + + #[tokio::test] + async fn test_list_groups_no_filter() { + let fixture = TestFixture::new().await; + assert_eq!( + get_group_ids(&fixture.handler, None).await, + vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]] + ); + } + + #[tokio::test] + async fn test_list_groups_simple_filter() { + let fixture = TestFixture::new().await; + assert_eq!( + get_group_ids( + &fixture.handler, + Some(GroupRequestFilter::Or(vec![ + GroupRequestFilter::DisplayName("Empty Group".to_string()), + GroupRequestFilter::Member(UserId::new("bob")), + ])) + ) + .await, + vec![fixture.groups[0], fixture.groups[2]] + ); + } + + #[tokio::test] + async fn test_list_groups_negation() { + let fixture = TestFixture::new().await; + assert_eq!( + get_group_ids( + &fixture.handler, + Some(GroupRequestFilter::And(vec![ + GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( + "value".to_string() + ))), + GroupRequestFilter::GroupId(fixture.groups[0]), + ])) + ) + .await, + vec![fixture.groups[0]] + ); + } + + #[tokio::test] + async fn test_get_group_details() { + let fixture = TestFixture::new().await; + let details = fixture + .handler + .get_group_details(fixture.groups[0]) + .await + .unwrap(); + assert_eq!(details.group_id, fixture.groups[0]); + assert_eq!(details.display_name, "Best Group"); + assert_eq!( + get_group_ids( + &fixture.handler, + Some(GroupRequestFilter::Uuid(details.uuid)) + ) + .await, + vec![fixture.groups[0]] + ); + } + + #[tokio::test] + async fn test_update_group() { + let fixture = TestFixture::new().await; + fixture + .handler + .update_group(UpdateGroupRequest { + group_id: fixture.groups[0], + display_name: Some("Awesomest Group".to_string()), + }) + .await + .unwrap(); + let details = fixture + .handler + .get_group_details(fixture.groups[0]) + .await + .unwrap(); + assert_eq!(details.display_name, "Awesomest Group"); + } + + #[tokio::test] + async fn test_delete_group() { + let fixture = TestFixture::new().await; + fixture + .handler + .delete_group(fixture.groups[0]) + .await + .unwrap(); + assert_eq!( + get_group_ids(&fixture.handler, None).await, + vec![fixture.groups[2], fixture.groups[1]] + ); + } +} diff --git a/server/src/domain/sql_opaque_handler.rs b/server/src/domain/sql_opaque_handler.rs index 6b199aa..3df4aed 100644 --- a/server/src/domain/sql_opaque_handler.rs +++ b/server/src/domain/sql_opaque_handler.rs @@ -1,9 +1,9 @@ use super::{ - error::*, + error::{DomainError, Result}, handler::{BindRequest, LoginHandler, UserId}, - opaque_handler::*, + opaque_handler::{login, registration, OpaqueHandler}, sql_backend_handler::SqlBackendHandler, - sql_tables::*, + sql_tables::{DbQueryBuilder, Users}, }; use async_trait::async_trait; use lldap_auth::opaque; @@ -258,42 +258,7 @@ pub(crate) async fn register_password( #[cfg(test)] mod tests { use super::*; - use crate::{ - domain::{ - handler::{BackendHandler, CreateUserRequest}, - sql_backend_handler::SqlBackendHandler, - sql_tables::init_table, - }, - infra::configuration::{Configuration, ConfigurationBuilder}, - }; - - fn get_default_config() -> Configuration { - ConfigurationBuilder::default() - .verbose(true) - .build() - .unwrap() - } - - async fn get_in_memory_db() -> Pool { - PoolOptions::new().connect("sqlite::memory:").await.unwrap() - } - - async fn get_initialized_db() -> Pool { - let sql_pool = get_in_memory_db().await; - init_table(&sql_pool).await.unwrap(); - sql_pool - } - - async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { - handler - .create_user(CreateUserRequest { - user_id: UserId::new(name), - email: "bob@bob.bob".to_string(), - ..Default::default() - }) - .await - .unwrap(); - } + use crate::domain::sql_backend_handler::tests::*; async fn attempt_login( opaque_handler: &SqlOpaqueHandler, @@ -323,7 +288,7 @@ mod tests { } #[tokio::test] - async fn test_flow() -> Result<()> { + async fn test_opaque_flow() -> Result<()> { let sql_pool = get_initialized_db().await; let config = get_default_config(); let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone()); @@ -344,4 +309,50 @@ mod tests { attempt_login(&opaque_handler, "bob", "bob00").await?; Ok(()) } + + #[tokio::test] + async fn test_bind_user() { + let sql_pool = get_initialized_db().await; + let config = get_default_config(); + let handler = SqlOpaqueHandler::new(config, sql_pool.clone()); + insert_user(&handler, "bob", "bob00").await; + + handler + .bind(BindRequest { + name: UserId::new("bob"), + password: "bob00".to_string(), + }) + .await + .unwrap(); + handler + .bind(BindRequest { + name: UserId::new("andrew"), + password: "bob00".to_string(), + }) + .await + .unwrap_err(); + handler + .bind(BindRequest { + name: UserId::new("bob"), + password: "wrong_password".to_string(), + }) + .await + .unwrap_err(); + } + + #[tokio::test] + async fn test_user_no_password() { + let sql_pool = get_initialized_db().await; + let config = get_default_config(); + let handler = SqlBackendHandler::new(config, sql_pool.clone()); + insert_user_no_password(&handler, "bob").await; + + handler + .bind(BindRequest { + name: UserId::new("bob"), + password: "bob00".to_string(), + }) + .await + .unwrap_err(); + } } diff --git a/server/src/domain/sql_user_backend_handler.rs b/server/src/domain/sql_user_backend_handler.rs new file mode 100644 index 0000000..92e4fa3 --- /dev/null +++ b/server/src/domain/sql_user_backend_handler.rs @@ -0,0 +1,737 @@ +use super::{ + error::Result, + handler::{ + CreateUserRequest, GroupDetails, GroupId, UpdateUserRequest, User, UserAndGroups, + UserBackendHandler, UserId, UserRequestFilter, Uuid, + }, + sql_backend_handler::SqlBackendHandler, + sql_tables::{DbQueryBuilder, Groups, Memberships, Users}, +}; +use async_trait::async_trait; +use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr}; +use sea_query_binder::{SqlxBinder, SqlxValues}; +use sqlx::{query_as_with, query_with, FromRow, Row}; +use std::collections::HashSet; +use tracing::{debug, instrument}; + +struct RequiresGroup(bool); + +// Returns the condition for the SQL query, and whether it requires joining with the groups table. +fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) { + use sea_query::IntoCondition; + use UserRequestFilter::*; + fn get_repeated_filter( + fs: Vec, + condition: Cond, + default_value: bool, + ) -> (RequiresGroup, Cond) { + if fs.is_empty() { + return ( + RequiresGroup(false), + SimpleExpr::Value(default_value.into()).into_condition(), + ); + } + let mut requires_group = false; + let filter = fs.into_iter().fold(condition, |c, f| { + let (group, filters) = get_user_filter_expr(f); + requires_group |= group.0; + c.add(filters) + }); + (RequiresGroup(requires_group), filter) + } + match filter { + And(fs) => get_repeated_filter(fs, Cond::all(), true), + Or(fs) => get_repeated_filter(fs, Cond::any(), false), + Not(f) => { + let (requires_group, filters) = get_user_filter_expr(*f); + (requires_group, filters.not()) + } + UserId(user_id) => ( + RequiresGroup(false), + Expr::col((Users::Table, Users::UserId)) + .eq(user_id) + .into_condition(), + ), + Equality(s1, s2) => ( + RequiresGroup(false), + if s1 == Users::DisplayName.to_string() { + Expr::col((Users::Table, Users::DisplayName)) + .eq(s2) + .into_condition() + } else if s1 == Users::UserId.to_string() { + panic!("User id should be wrapped") + } else { + Expr::expr(Expr::cust(&s1)).eq(s2).into_condition() + }, + ), + MemberOf(group) => ( + RequiresGroup(true), + Expr::col((Groups::Table, Groups::DisplayName)) + .eq(group) + .into_condition(), + ), + MemberOfId(group_id) => ( + RequiresGroup(true), + Expr::col((Groups::Table, Groups::GroupId)) + .eq(group_id) + .into_condition(), + ), + } +} + +fn get_list_users_query( + filters: Option, + get_groups: bool, +) -> (String, SqlxValues) { + let mut query_builder = Query::select() + .column((Users::Table, Users::UserId)) + .column(Users::Email) + .column((Users::Table, Users::DisplayName)) + .column(Users::FirstName) + .column(Users::LastName) + .column(Users::Avatar) + .column((Users::Table, Users::CreationDate)) + .column((Users::Table, Users::Uuid)) + .from(Users::Table) + .order_by((Users::Table, Users::UserId), Order::Asc) + .to_owned(); + let add_join_group_tables = |builder: &mut sea_query::SelectStatement| { + builder + .left_join( + Memberships::Table, + Expr::tbl(Users::Table, Users::UserId) + .equals(Memberships::Table, Memberships::UserId), + ) + .left_join( + Groups::Table, + Expr::tbl(Memberships::Table, Memberships::GroupId) + .equals(Groups::Table, Groups::GroupId), + ); + }; + if get_groups { + add_join_group_tables(&mut query_builder); + query_builder + .column((Groups::Table, Groups::GroupId)) + .expr_as( + Expr::col((Groups::Table, Groups::DisplayName)), + Alias::new("group_display_name"), + ) + .expr_as( + Expr::col((Groups::Table, Groups::CreationDate)), + sea_query::Alias::new("group_creation_date"), + ) + .expr_as( + Expr::col((Groups::Table, Groups::Uuid)), + sea_query::Alias::new("group_uuid"), + ) + .order_by(Alias::new("group_display_name"), Order::Asc); + } + if let Some(filter) = filters { + let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter); + query_builder.cond_where(condition); + if requires_group && !get_groups { + add_join_group_tables(&mut query_builder); + } + } + + query_builder.build_sqlx(DbQueryBuilder {}) +} + +#[async_trait] +impl UserBackendHandler for SqlBackendHandler { + #[instrument(skip_all, level = "debug", ret, err)] + async fn list_users( + &self, + filters: Option, + get_groups: bool, + ) -> Result> { + debug!(?filters, get_groups); + let (query, values) = get_list_users_query(filters, get_groups); + + debug!(%query); + + // For group_by. + use itertools::Itertools; + let mut users = Vec::new(); + // The rows are returned sorted by user_id. We group them by + // this key which gives us one element (`rows`) per group. + for (_, rows) in &query_with(&query, values) + .fetch_all(&self.sql_pool) + .await? + .into_iter() + .group_by(|row| row.get::(&*Users::UserId.to_string())) + { + let mut rows = rows.peekable(); + users.push(UserAndGroups { + user: User::from_row(rows.peek().unwrap()).unwrap(), + groups: if get_groups { + Some( + rows.filter_map(|row| { + let display_name = row.get::("group_display_name"); + if display_name.is_empty() { + None + } else { + Some(GroupDetails { + group_id: row.get::(&*Groups::GroupId.to_string()), + display_name, + creation_date: row.get::, _>( + "group_creation_date", + ), + uuid: row.get::("group_uuid"), + }) + } + }) + .collect(), + ) + } else { + None + }, + }); + } + Ok(users) + } + + #[instrument(skip_all, level = "debug", ret)] + async fn get_user_details(&self, user_id: &UserId) -> Result { + debug!(?user_id); + let (query, values) = Query::select() + .column(Users::UserId) + .column(Users::Email) + .column(Users::DisplayName) + .column(Users::FirstName) + .column(Users::LastName) + .column(Users::Avatar) + .column(Users::CreationDate) + .column(Users::Uuid) + .from(Users::Table) + .cond_where(Expr::col(Users::UserId).eq(user_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + + Ok(query_as_with::<_, User, _>(query.as_str(), values) + .fetch_one(&self.sql_pool) + .await?) + } + + #[instrument(skip_all, level = "debug", ret, err)] + async fn get_user_groups(&self, user_id: &UserId) -> Result> { + debug!(?user_id); + let (query, values) = Query::select() + .column((Groups::Table, Groups::GroupId)) + .column(Groups::DisplayName) + .column(Groups::CreationDate) + .column(Groups::Uuid) + .from(Groups::Table) + .inner_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .cond_where(Expr::col(Memberships::UserId).eq(user_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + + Ok(HashSet::from_iter( + query_as_with::<_, GroupDetails, _>(&query, values) + .fetch_all(&self.sql_pool) + .await?, + )) + } + + #[instrument(skip_all, level = "debug", err)] + async fn create_user(&self, request: CreateUserRequest) -> Result<()> { + debug!(user_id = ?request.user_id); + let columns = vec![ + Users::UserId, + Users::Email, + Users::DisplayName, + Users::FirstName, + Users::LastName, + Users::Avatar, + Users::CreationDate, + Users::Uuid, + ]; + let now = chrono::Utc::now(); + let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now); + let values = vec![ + request.user_id.into(), + request.email.into(), + request.display_name.unwrap_or_default().into(), + request.first_name.unwrap_or_default().into(), + request.last_name.unwrap_or_default().into(), + request.avatar.unwrap_or_default().into(), + now.naive_utc().into(), + uuid.into(), + ]; + let (query, values) = Query::insert() + .into_table(Users::Table) + .columns(columns) + .values_panic(values) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } + + #[instrument(skip_all, level = "debug", err)] + async fn update_user(&self, request: UpdateUserRequest) -> Result<()> { + debug!(user_id = ?request.user_id); + let mut values = Vec::new(); + if let Some(email) = request.email { + values.push((Users::Email, email.into())); + } + if let Some(display_name) = request.display_name { + values.push((Users::DisplayName, display_name.into())); + } + if let Some(first_name) = request.first_name { + values.push((Users::FirstName, first_name.into())); + } + if let Some(last_name) = request.last_name { + values.push((Users::LastName, last_name.into())); + } + if let Some(avatar) = request.avatar { + values.push((Users::Avatar, avatar.into())); + } + if values.is_empty() { + return Ok(()); + } + let (query, values) = Query::update() + .table(Users::Table) + .values(values) + .cond_where(Expr::col(Users::UserId).eq(request.user_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } + + #[instrument(skip_all, level = "debug", err)] + async fn delete_user(&self, user_id: &UserId) -> Result<()> { + debug!(?user_id); + let (query, values) = Query::delete() + .from_table(Users::Table) + .cond_where(Expr::col(Users::UserId).eq(user_id)) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } + + #[instrument(skip_all, level = "debug", err)] + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { + debug!(?user_id, ?group_id); + let (query, values) = Query::insert() + .into_table(Memberships::Table) + .columns(vec![Memberships::UserId, Memberships::GroupId]) + .values_panic(vec![user_id.into(), group_id.into()]) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } + + #[instrument(skip_all, level = "debug", err)] + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { + debug!(?user_id, ?group_id); + let (query, values) = Query::delete() + .from_table(Memberships::Table) + .cond_where( + Cond::all() + .add(Expr::col(Memberships::GroupId).eq(group_id)) + .add(Expr::col(Memberships::UserId).eq(user_id)), + ) + .build_sqlx(DbQueryBuilder {}); + debug!(%query); + query_with(query.as_str(), values) + .execute(&self.sql_pool) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::{handler::JpegPhoto, sql_backend_handler::tests::*}; + + #[tokio::test] + async fn test_list_users_no_filter() { + let fixture = TestFixture::new().await; + let users = get_user_names(&fixture.handler, None).await; + assert_eq!(users, vec!["bob", "john", "nogroup", "patrick"]); + } + + #[tokio::test] + async fn test_list_users_user_id_filter() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::UserId(UserId::new("bob"))), + ) + .await; + assert_eq!(users, vec!["bob"]); + } + + #[tokio::test] + async fn test_list_users_display_name_filter() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Equality( + "display_name".to_string(), + "display bob".to_string(), + )), + ) + .await; + assert_eq!(users, vec!["bob"]); + } + + #[tokio::test] + async fn test_list_users_other_filter() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Equality( + "first_name".to_string(), + "first bob".to_string(), + )), + ) + .await; + assert_eq!(users, vec!["bob"]); + } + + #[tokio::test] + async fn test_list_users_false_filter() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Not(Box::new(UserRequestFilter::And( + vec![], + )))), + ) + .await; + assert_eq!(users, Vec::::new()); + } + + #[tokio::test] + async fn test_list_users_member_of() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::MemberOf("Best Group".to_string())), + ) + .await; + assert_eq!(users, vec!["bob", "patrick"]); + } + + #[tokio::test] + async fn test_list_users_member_of_id() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::MemberOfId(fixture.groups[0])), + ) + .await; + assert_eq!(users, vec!["bob", "patrick"]); + } + + #[tokio::test] + #[should_panic] + async fn test_list_users_invalid_userid_filter() { + let fixture = TestFixture::new().await; + get_user_names( + &fixture.handler, + Some(UserRequestFilter::Equality( + "user_id".to_string(), + "first bob".to_string(), + )), + ) + .await; + } + + #[tokio::test] + async fn test_list_users_filter_or() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Or(vec![ + UserRequestFilter::UserId(UserId::new("bob")), + UserRequestFilter::UserId(UserId::new("John")), + ])), + ) + .await; + assert_eq!(users, vec!["bob", "john"]); + } + + #[tokio::test] + async fn test_list_users_filter_many_or() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Or(vec![ + UserRequestFilter::Or(vec![]), + UserRequestFilter::Or(vec![ + UserRequestFilter::UserId(UserId::new("bob")), + UserRequestFilter::UserId(UserId::new("John")), + UserRequestFilter::UserId(UserId::new("random")), + ]), + ])), + ) + .await; + assert_eq!(users, vec!["bob", "john"]); + } + + #[tokio::test] + async fn test_list_users_filter_not() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Not(Box::new(UserRequestFilter::UserId( + UserId::new("bob"), + )))), + ) + .await; + assert_eq!(users, vec!["john", "nogroup", "patrick"]); + } + + #[tokio::test] + async fn test_list_users_with_groups() { + let fixture = TestFixture::new().await; + let users = fixture + .handler + .list_users(None, true) + .await + .unwrap() + .into_iter() + .map(|u| { + ( + u.user.user_id.to_string(), + u.user.display_name.to_string(), + u.groups + .unwrap() + .into_iter() + .map(|g| g.group_id) + .collect::>(), + ) + }) + .collect::>(); + assert_eq!( + users, + vec![ + ( + "bob".to_string(), + "display bob".to_string(), + vec![fixture.groups[0]] + ), + ( + "john".to_string(), + "display John".to_string(), + vec![fixture.groups[1]] + ), + ("nogroup".to_string(), "display NoGroup".to_string(), vec![]), + ( + "patrick".to_string(), + "display patrick".to_string(), + vec![fixture.groups[0], fixture.groups[1]] + ), + ] + ); + } + + #[tokio::test] + async fn test_list_users_groups_have_different_creation_date_than_users() { + let fixture = TestFixture::new().await; + let users = fixture + .handler + .list_users(None, true) + .await + .unwrap() + .into_iter() + .map(|u| { + ( + u.user.creation_date, + u.groups + .unwrap() + .into_iter() + .map(|g| g.creation_date) + .collect::>(), + ) + }) + .collect::>(); + for (user_date, groups) in users { + for group_date in groups { + assert_ne!(user_date, group_date); + } + } + } + + #[tokio::test] + async fn test_get_user_details() { + let handler = SqlBackendHandler::new(get_default_config(), get_initialized_db().await); + insert_user_no_password(&handler, "bob").await; + { + let user = handler.get_user_details(&UserId::new("bob")).await.unwrap(); + assert_eq!(user.user_id.as_str(), "bob"); + } + { + handler + .get_user_details(&UserId::new("John")) + .await + .unwrap_err(); + } + } + + #[tokio::test] + async fn test_user_lowercase() { + let handler = SqlBackendHandler::new(get_default_config(), get_initialized_db().await); + insert_user_no_password(&handler, "Bob").await; + { + let user = handler.get_user_details(&UserId::new("bOb")).await.unwrap(); + assert_eq!(user.user_id.as_str(), "bob"); + } + { + handler + .get_user_details(&UserId::new("John")) + .await + .unwrap_err(); + } + } + + #[tokio::test] + async fn test_delete_user() { + let fixture = TestFixture::new().await; + fixture + .handler + .delete_user(&UserId::new("bob")) + .await + .unwrap(); + + assert_eq!( + get_user_names(&fixture.handler, None).await, + vec!["john", "nogroup", "patrick"] + ); + + // Insert new user and remove two + insert_user_no_password(&fixture.handler, "NewBoi").await; + fixture + .handler + .delete_user(&UserId::new("nogroup")) + .await + .unwrap(); + fixture + .handler + .delete_user(&UserId::new("NewBoi")) + .await + .unwrap(); + + assert_eq!( + get_user_names(&fixture.handler, None).await, + vec!["john", "patrick"] + ); + } + + #[tokio::test] + async fn test_get_user_groups() { + let fixture = TestFixture::new().await; + let get_group_ids = |user: &'static str| async { + let mut groups = fixture + .handler + .get_user_groups(&UserId::new(user)) + .await + .unwrap() + .into_iter() + .map(|g| g.group_id) + .collect::>(); + groups.sort_by(|g1, g2| g1.0.cmp(&g2.0)); + groups + }; + assert_eq!(get_group_ids("bob").await, vec![fixture.groups[0]]); + assert_eq!( + get_group_ids("patrick").await, + vec![fixture.groups[0], fixture.groups[1]] + ); + assert_eq!(get_group_ids("nogroup").await, vec![]); + } + + #[tokio::test] + async fn test_update_user_all_values() { + let fixture = TestFixture::new().await; + + fixture + .handler + .update_user(UpdateUserRequest { + user_id: UserId::new("bob"), + email: Some("email".to_string()), + display_name: Some("display_name".to_string()), + first_name: Some("first_name".to_string()), + last_name: Some("last_name".to_string()), + avatar: Some(JpegPhoto::default()), + }) + .await + .unwrap(); + + let user = fixture + .handler + .get_user_details(&UserId::new("bob")) + .await + .unwrap(); + assert_eq!(user.email, "email"); + assert_eq!(user.display_name, "display_name"); + assert_eq!(user.first_name, "first_name"); + assert_eq!(user.last_name, "last_name"); + assert_eq!(user.avatar, JpegPhoto::default()); + } + + #[tokio::test] + async fn test_update_user_some_values() { + let fixture = TestFixture::new().await; + + fixture + .handler + .update_user(UpdateUserRequest { + user_id: UserId::new("bob"), + first_name: Some("first_name".to_string()), + last_name: Some(String::new()), + ..Default::default() + }) + .await + .unwrap(); + + let user = fixture + .handler + .get_user_details(&UserId::new("bob")) + .await + .unwrap(); + assert_eq!(user.display_name, "display bob"); + assert_eq!(user.first_name, "first_name"); + assert_eq!(user.last_name, ""); + } + + #[tokio::test] + async fn test_remove_user_from_group() { + let fixture = TestFixture::new().await; + + fixture + .handler + .remove_user_from_group(&UserId::new("bob"), fixture.groups[0]) + .await + .unwrap(); + + assert_eq!( + get_user_names( + &fixture.handler, + Some(UserRequestFilter::MemberOfId(fixture.groups[0])), + ) + .await, + vec!["patrick"] + ); + } +} diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index e6e15d1..4038fce 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -467,22 +467,27 @@ mod tests { async fn bind(&self, request: BindRequest) -> Result<()>; } #[async_trait] - impl BackendHandler for TestBackendHandler { - async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + impl GroupBackendHandler for TestBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; - async fn get_user_groups(&self, user: &UserId) -> Result>; - async fn create_user(&self, request: CreateUserRequest) -> Result<()>; - async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; + } + #[async_trait] + impl UserBackendHandler for TestBackendHandler { + async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + async fn get_user_details(&self, user_id: &UserId) -> Result; + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; + async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; + async fn get_user_groups(&self, user_id: &UserId) -> Result>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; } #[async_trait] + impl BackendHandler for TestBackendHandler {} + #[async_trait] impl OpaqueHandler for TestBackendHandler { async fn login_start( &self, diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index 093ec2b..8bdf97d 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -34,30 +34,24 @@ mockall::mock! { async fn bind(&self, request: BindRequest) -> Result<()>; } #[async_trait] - impl BackendHandler for TestTcpBackendHandler { - async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + impl GroupBackendHandler for TestTcpBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; - async fn get_user_groups(&self, user: &UserId) -> Result>; - async fn create_user(&self, request: CreateUserRequest) -> Result<()>; - async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; + } + #[async_trait] + impl UserBackendHandler for TestBackendHandler { + async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + async fn get_user_details(&self, user_id: &UserId) -> Result; + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; + async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; + async fn get_user_groups(&self, user_id: &UserId) -> Result>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; } #[async_trait] - impl TcpBackendHandler for TestTcpBackendHandler { - async fn get_jwt_blacklist(&self) -> anyhow::Result>; - async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>; - async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result; - async fn blacklist_jwts(&self, user: &UserId) -> Result>; - async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; - async fn start_password_reset(&self, user: &UserId) -> Result>; - async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; - async fn delete_password_reset_token(&self, token: &str) -> Result<()>; - } + impl BackendHandler for TestTcpBackendHandler {} } diff --git a/server/src/main.rs b/server/src/main.rs index d315799..87735dc 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -6,7 +6,7 @@ use std::time::Duration; use crate::{ domain::{ - handler::{BackendHandler, CreateUserRequest, GroupRequestFilter}, + handler::{CreateUserRequest, GroupBackendHandler, GroupRequestFilter, UserBackendHandler}, sql_backend_handler::SqlBackendHandler, sql_opaque_handler::register_password, sql_tables::PoolOptions,