diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 1af5ff0..3dd676b 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -4,16 +4,19 @@ use crate::infra::configuration::Configuration; use async_trait::async_trait; use futures_util::StreamExt; use futures_util::TryStreamExt; -pub use lldap_model::*; use log::*; use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder}; use sqlx::Row; +use std::collections::HashSet; + +pub use lldap_model::*; #[async_trait] pub trait BackendHandler: Clone + Send { async fn bind(&self, request: BindRequest) -> Result<()>; async fn list_users(&self, request: ListUsersRequest) -> Result>; async fn list_groups(&self) -> Result>; + async fn get_user_groups(&self, user: String) -> Result>; } #[derive(Debug, Clone)] @@ -155,6 +158,33 @@ impl BackendHandler for SqlBackendHandler { Ok(groups) } + + async fn get_user_groups(&self, user: String) -> Result> { + let query: String = Query::select() + .column(Groups::DisplayName) + .from(Groups::Table) + .inner_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .and_where(Expr::col(Memberships::UserId).eq(user)) + .to_string(SqliteQueryBuilder); + + sqlx::query(&query) + // Extract the group id from the row. + .map(|row: DbRow| row.get::("display_name")) + .fetch(&self.sql_pool) + // Collect the vector of rows, each potentially an error. + .collect::>>() + .await + .into_iter() + // Transform it into a single result (the first error if any), and group the group_ids + // into a HashSet. + .collect::>>() + // Map the sqlx::Error into a domain::Error. + .map_err(Error::DatabaseError) + } } #[cfg(test)] @@ -168,6 +198,7 @@ mockall::mock! { async fn bind(&self, request: BindRequest) -> Result<()>; async fn list_users(&self, request: ListUsersRequest) -> Result>; async fn list_groups(&self) -> Result>; + async fn get_user_groups(&self, user: String) -> Result>; } } @@ -369,4 +400,39 @@ mod tests { ] ); } + + #[tokio::test] + async fn test_get_user_groups() { + let sql_pool = get_initialized_db().await; + insert_user(&sql_pool, "bob", "bob00").await; + insert_user(&sql_pool, "patrick", "pass").await; + insert_user(&sql_pool, "John", "Pa33w0rd!").await; + insert_group(&sql_pool, 1, "Group1").await; + insert_group(&sql_pool, 2, "Group2").await; + insert_membership(&sql_pool, 1, "bob").await; + insert_membership(&sql_pool, 1, "patrick").await; + insert_membership(&sql_pool, 2, "patrick").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool); + let mut bob_groups = HashSet::new(); + bob_groups.insert("Group1".to_string()); + let mut patrick_groups = HashSet::new(); + patrick_groups.insert("Group1".to_string()); + patrick_groups.insert("Group2".to_string()); + assert_eq!( + handler.get_user_groups("bob".to_string()).await.unwrap(), + bob_groups + ); + assert_eq!( + handler + .get_user_groups("patrick".to_string()) + .await + .unwrap(), + patrick_groups + ); + assert_eq!( + handler.get_user_groups("John".to_string()).await.unwrap(), + HashSet::new() + ); + } } diff --git a/src/domain/sql_tables.rs b/src/domain/sql_tables.rs index ed53c70..cc74834 100644 --- a/src/domain/sql_tables.rs +++ b/src/domain/sql_tables.rs @@ -2,6 +2,7 @@ use sea_query::*; pub type Pool = sqlx::sqlite::SqlitePool; pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions; +pub type DbRow = sqlx::sqlite::SqliteRow; #[derive(Iden)] pub enum Users {