diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index bbb3315..8862dc3 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -54,6 +54,17 @@ pub enum UserRequestFilter { MemberOfId(GroupId), } +#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] +pub enum GroupRequestFilter { + And(Vec), + Or(Vec), + Not(Box), + DisplayName(String), + GroupId(GroupId), + // Check if the group contains a user identified by uid. + Member(String), +} + #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] pub struct CreateUserRequest { // Same fields as User, but no creation_date, and with password. @@ -94,7 +105,7 @@ pub struct GroupIdAndName(pub GroupId, pub String); #[async_trait] pub trait BackendHandler: Clone + Send { async fn list_users(&self, filters: Option) -> Result>; - async fn list_groups(&self) -> Result>; + async fn list_groups(&self, filters: Option) -> Result>; async fn get_user_details(&self, user_id: &str) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; @@ -117,7 +128,7 @@ mockall::mock! { #[async_trait] impl BackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option) -> Result>; - async fn list_groups(&self) -> Result>; + async fn list_groups(&self, filters: Option) -> Result>; async fn get_user_details(&self, user_id: &str) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index 8b5312d..2211434 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -21,7 +21,7 @@ impl SqlBackendHandler { struct RequiresGroup(bool); // Returns the condition for the SQL query, and whether it requires joining with the groups table. -fn get_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr) { +fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr) { use UserRequestFilter::*; fn get_repeated_filter( fs: Vec, @@ -32,13 +32,13 @@ fn get_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr) { let first_expr = match it.next() { None => return (RequiresGroup(false), Expr::value(true)), Some(f) => { - let (group, filter) = get_filter_expr(f); + let (group, filter) = get_user_filter_expr(f); requires_group |= group.0; filter } }; let filter = it.fold(first_expr, |e, f| { - let (group, filters) = get_filter_expr(f); + let (group, filters) = get_user_filter_expr(f); requires_group |= group.0; field(e, filters) }); @@ -48,7 +48,7 @@ fn get_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr) { And(fs) => get_repeated_filter(fs, &SimpleExpr::and), Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), Not(f) => { - let (requires_group, filters) = get_filter_expr(*f); + let (requires_group, filters) = get_user_filter_expr(*f); (requires_group, Expr::not(Expr::expr(filters))) } Equality(s1, s2) => ( @@ -72,6 +72,37 @@ fn get_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr) { } } +// Returns the condition for the SQL query, and whether it requires joining with the groups table. +fn get_group_filter_expr(filter: GroupRequestFilter) -> SimpleExpr { + use GroupRequestFilter::*; + fn get_repeated_filter( + fs: Vec, + field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, + ) -> SimpleExpr { + let mut it = fs.into_iter(); + let first_expr = match it.next() { + None => return Expr::value(true), + Some(f) => get_group_filter_expr(f), + }; + it.fold(first_expr, |e, f| field(e, get_group_filter_expr(f))) + } + match filter { + And(fs) => get_repeated_filter(fs, &SimpleExpr::and), + Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), + Not(f) => Expr::not(Expr::expr(get_group_filter_expr(*f))), + DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)).eq(name), + GroupId(id) => Expr::col((Groups::Table, Groups::GroupId)).eq(id.0), + // WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user)) + Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)).in_subquery( + Query::select() + .column(Memberships::GroupId) + .from(Memberships::Table) + .and_where(Expr::col(Memberships::UserId).eq(user)) + .take(), + ), + } +} + #[async_trait] impl BackendHandler for SqlBackendHandler { async fn list_users(&self, filters: Option) -> Result> { @@ -88,17 +119,13 @@ impl BackendHandler for SqlBackendHandler { .order_by((Users::Table, Users::UserId), Order::Asc) .to_owned(); if let Some(filter) = filters { - if filter - == UserRequestFilter::Not(Box::new( - UserRequestFilter::And(Vec::new()), - )) - { + if filter == UserRequestFilter::Not(Box::new(UserRequestFilter::And(Vec::new()))) { return Ok(Vec::new()); } if filter != UserRequestFilter::And(Vec::new()) && filter != UserRequestFilter::Or(Vec::new()) { - let (RequiresGroup(requires_group), condition) = get_filter_expr(filter); + let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter); query_builder.and_where(condition); if requires_group { query_builder @@ -127,20 +154,36 @@ impl BackendHandler for SqlBackendHandler { Ok(results.into_iter().collect::>>()?) } - async fn list_groups(&self) -> Result> { - let query: String = Query::select() - .column((Groups::Table, Groups::GroupId)) - .column(Groups::DisplayName) - .column(Memberships::UserId) - .from(Groups::Table) - .left_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .order_by(Groups::DisplayName, Order::Asc) - .order_by(Memberships::UserId, Order::Asc) - .to_string(DbQueryBuilder {}); + async fn list_groups(&self, filters: Option) -> Result> { + let query: String = { + let mut query_builder = Query::select() + .column((Groups::Table, Groups::GroupId)) + .column(Groups::DisplayName) + .column(Memberships::UserId) + .from(Groups::Table) + .left_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .order_by(Groups::DisplayName, Order::Asc) + .order_by(Memberships::UserId, Order::Asc) + .to_owned(); + + if let Some(filter) = filters { + if filter == GroupRequestFilter::Not(Box::new(GroupRequestFilter::And(Vec::new()))) + { + return Ok(Vec::new()); + } + if filter != GroupRequestFilter::And(Vec::new()) + && filter != GroupRequestFilter::Or(Vec::new()) + { + query_builder.and_where(get_group_filter_expr(filter)); + } + } + + query_builder.to_string(DbQueryBuilder {}) + }; // For group_by. use itertools::Itertools; @@ -546,10 +589,9 @@ mod tests { } { let users = handler - .list_users(Some(UserRequestFilter::Not(Box::new(UserRequestFilter::Equality( - "user_id".to_string(), - "bob".to_string(), - ))))) + .list_users(Some(UserRequestFilter::Not(Box::new( + UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()), + )))) .await .unwrap() .into_iter() @@ -575,7 +617,7 @@ mod tests { insert_membership(&handler, group_2, "patrick").await; insert_membership(&handler, group_2, "John").await; assert_eq!( - handler.list_groups().await.unwrap(), + handler.list_groups(None).await.unwrap(), vec![ Group { id: group_1, @@ -594,6 +636,43 @@ mod tests { }, ] ); + assert_eq!( + handler + .list_groups(Some(GroupRequestFilter::Or(vec![ + GroupRequestFilter::DisplayName("Empty Group".to_string()), + GroupRequestFilter::Member("bob".to_string()), + ]))) + .await + .unwrap(), + vec![ + Group { + id: group_1, + display_name: "Best Group".to_string(), + users: vec!["bob".to_string(), "patrick".to_string()] + }, + Group { + id: group_3, + display_name: "Empty Group".to_string(), + users: vec![] + }, + ] + ); + assert_eq!( + handler + .list_groups(Some(GroupRequestFilter::And(vec![ + GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( + "value".to_string() + ))), + GroupRequestFilter::GroupId(group_1), + ]))) + .await + .unwrap(), + vec![Group { + id: group_1, + display_name: "Best Group".to_string(), + users: vec!["bob".to_string(), "patrick".to_string()] + }] + ); } #[tokio::test] diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index 26593b1..3804127 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -134,7 +134,7 @@ impl Query { } Ok(context .handler - .list_groups() + .list_groups(None) .await .map(|v| v.into_iter().map(Into::into).collect())?) } diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index ad707ed..9e47659 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -1,12 +1,11 @@ use crate::domain::{ handler::{ - BackendHandler, BindRequest, Group, GroupIdAndName, LoginHandler, UserRequestFilter, User, + BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, + UserRequestFilter, }, opaque_handler::OpaqueHandler, }; use anyhow::{bail, Context, Result}; -use futures::stream::StreamExt; -use futures_util::TryStreamExt; use ldap3_server::proto::{ LdapBindCred, LdapBindRequest, LdapBindResponse, LdapExtendedRequest, LdapExtendedResponse, LdapFilter, LdapOp, LdapPartialAttribute, LdapPasswordModifyRequest, LdapResult, @@ -479,8 +478,8 @@ impl LdapHandler Vec { - let for_user = match self.get_group_filter(&request.filter) { - Ok(u) => u, + let filter = match self.convert_group_filter(&request.filter) { + Ok(f) => f, Err(e) => { return vec![make_search_error( LdapResultCode::UnwillingToPerform, @@ -489,55 +488,13 @@ impl LdapHandler( - backend_handler: &Backend, - g: &GroupIdAndName, - ) -> Result { - let users = backend_handler - .list_users(Some(UserRequestFilter::MemberOfId(g.0))) - .await?; - Ok(Group { - id: g.0, - display_name: g.1.clone(), - users: users.into_iter().map(|u| u.user_id).collect(), - }) - } - - let groups: Vec = if let Some(user) = for_user { - let groups_without_users = match self.backend_handler.get_user_groups(&user).await { - Ok(groups) => groups, - Err(e) => { - return vec![make_search_error( - LdapResultCode::Other, - format!( - r#"Error while listing user groups: "{}": {:#}"#, - request.base, e - ), - )] - } - }; - match tokio_stream::iter(groups_without_users.iter()) - .then(|g| async move { get_users_for_group::(&self.backend_handler, g).await }) - .try_collect::>() - .await - { - Ok(groups) => groups, - Err(e) => { - return vec![make_search_error( - LdapResultCode::Other, - format!(r#"Error while listing user groups: "{}": {:#}"#, request.base, e), - )] - } - } - } else { - match self.backend_handler.list_groups().await { - Ok(groups) => groups, - Err(e) => { - return vec![make_search_error( - LdapResultCode::Other, - format!(r#"Error while listing groups "{}": {:#}"#, request.base, e), - )] - } + let groups = match self.backend_handler.list_groups(Some(filter)).await { + Ok(groups) => groups, + Err(e) => { + return vec![make_search_error( + LdapResultCode::Other, + format!(r#"Error while listing groups "{}": {:#}"#, request.base, e), + )] } }; @@ -582,7 +539,7 @@ impl LdapHandler Result> { + fn convert_group_filter(&self, filter: &LdapFilter) -> Result { match filter { LdapFilter::Equality(field, value) => { if field == "member" || field.to_lowercase() == "uniquemember" { @@ -591,16 +548,33 @@ impl LdapHandler v - .iter() - .fold(Ok(None), |o, f| Ok(o?.xor(self.get_group_filter(f)?))), + LdapFilter::And(filters) => Ok(GroupRequestFilter::And( + filters + .iter() + .map(|f| self.convert_group_filter(f)) + .collect::>()?, + )), + LdapFilter::Or(filters) => Ok(GroupRequestFilter::Or( + filters + .iter() + .map(|f| self.convert_group_filter(f)) + .collect::>()?, + )), + LdapFilter::Not(filter) => Ok(GroupRequestFilter::Not(Box::new( + self.convert_group_filter(&*filter)?, + ))), _ => bail!("Unsupported group filter: {:?}", filter), } } @@ -638,10 +612,15 @@ impl LdapHandler { @@ -649,7 +628,9 @@ impl LdapHandler bail!("Unsupported user filter: {:?}", filter), @@ -679,7 +660,7 @@ mod tests { #[async_trait] impl BackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option) -> Result>; - async fn list_groups(&self) -> Result>; + async fn list_groups(&self, filters: Option) -> Result>; async fn get_user_details(&self, user_id: &str) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; async fn get_user_groups(&self, user: &str) -> Result>; @@ -1048,20 +1029,23 @@ mod tests { #[tokio::test] async fn test_search_groups() { let mut mock = MockTestBackendHandler::new(); - mock.expect_list_groups().times(1).return_once(|| { - Ok(vec![ - Group { - id: GroupId(1), - display_name: "group_1".to_string(), - users: vec!["bob".to_string(), "john".to_string()], - }, - Group { - id: GroupId(3), - display_name: "bestgroup".to_string(), - users: vec!["john".to_string()], - }, - ]) - }); + mock.expect_list_groups() + .with(eq(Some(GroupRequestFilter::And(vec![])))) + .times(1) + .return_once(|_| { + Ok(vec![ + Group { + id: GroupId(1), + display_name: "group_1".to_string(), + users: vec!["bob".to_string(), "john".to_string()], + }, + Group { + id: GroupId(3), + display_name: "bestgroup".to_string(), + users: vec!["john".to_string()], + }, + ]) + }); let mut ldap_handler = setup_bound_handler(mock).await; let request = make_search_request( "ou=groups,dc=example,dc=com", @@ -1124,27 +1108,25 @@ mod tests { #[tokio::test] async fn test_search_groups_filter() { let mut mock = MockTestBackendHandler::new(); - mock.expect_get_user_groups() - .with(eq("bob")) + mock.expect_list_groups() + .with(eq(Some(GroupRequestFilter::And(vec![ + GroupRequestFilter::DisplayName("group_1".to_string()), + GroupRequestFilter::Member("bob".to_string()), + GroupRequestFilter::And(vec![]), + ])))) .times(1) .return_once(|_| { - let mut set = HashSet::new(); - set.insert(GroupIdAndName(GroupId(1), "group_1".to_string())); - Ok(set) - }); - mock.expect_list_users() - .with(eq(Some(UserRequestFilter::MemberOfId(GroupId(1))))) - .times(1) - .return_once(|_| { - Ok(vec![User { - user_id: "bob".to_string(), - ..Default::default() + Ok(vec![Group { + display_name: "group_1".to_string(), + id: GroupId(1), + users: vec![], }]) }); let mut ldap_handler = setup_bound_handler(mock).await; let request = make_search_request( "ou=groups,dc=example,dc=com", LdapFilter::And(vec![ + LdapFilter::Equality("cn".to_string(), "group_1".to_string()), LdapFilter::Equality( "uniqueMember".to_string(), "cn=bob,ou=people,dc=example,dc=com".to_string(), @@ -1168,21 +1150,117 @@ mod tests { ); } + #[tokio::test] + async fn test_search_groups_filter_2() { + let mut mock = MockTestBackendHandler::new(); + mock.expect_list_groups() + .with(eq(Some(GroupRequestFilter::Or(vec![ + GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( + "group_2".to_string(), + ))), + ])))) + .times(1) + .return_once(|_| { + Ok(vec![Group { + display_name: "group_1".to_string(), + id: GroupId(1), + users: vec![], + }]) + }); + let mut ldap_handler = setup_bound_handler(mock).await; + let request = make_search_request( + "ou=groups,dc=example,dc=com", + LdapFilter::Or(vec![LdapFilter::Not(Box::new(LdapFilter::Equality( + "displayname".to_string(), + "group_2".to_string(), + )))]), + vec!["cn"], + ); + assert_eq!( + ldap_handler.do_search(&request).await, + vec![ + LdapOp::SearchResultEntry(LdapSearchResultEntry { + dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(), + attributes: vec![LdapPartialAttribute { + atype: "cn".to_string(), + vals: vec!["group_1".to_string()] + },], + }), + make_search_success(), + ] + ); + } + + #[tokio::test] + async fn test_search_groups_error() { + let mut mock = MockTestBackendHandler::new(); + mock.expect_list_groups() + .with(eq(Some(GroupRequestFilter::Or(vec![ + GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( + "group_2".to_string(), + ))), + ])))) + .times(1) + .return_once(|_| { + Err(crate::domain::error::DomainError::InternalError( + "Error getting groups".to_string(), + )) + }); + let mut ldap_handler = setup_bound_handler(mock).await; + let request = make_search_request( + "ou=groups,dc=example,dc=com", + LdapFilter::Or(vec![LdapFilter::Not(Box::new(LdapFilter::Equality( + "displayname".to_string(), + "group_2".to_string(), + )))]), + vec!["cn"], + ); + assert_eq!( + ldap_handler.do_search(&request).await, + vec![make_search_error( + LdapResultCode::Other, + r#"Error while listing groups "ou=groups,dc=example,dc=com": Internal error: `Error getting groups`"#.to_string() + )] + ); + } + + #[tokio::test] + async fn test_search_groups_filter_error() { + let mut ldap_handler = setup_bound_handler(MockTestBackendHandler::new()).await; + let request = make_search_request( + "ou=groups,dc=example,dc=com", + LdapFilter::And(vec![LdapFilter::Equality( + "whatever".to_string(), + "group_1".to_string(), + )]), + vec!["cn"], + ); + assert_eq!( + ldap_handler.do_search(&request).await, + vec![make_search_error( + LdapResultCode::UnwillingToPerform, + "Unsupported group filter: Unknown field: whatever".to_string() + )] + ); + } + #[tokio::test] async fn test_search_filters() { let mut mock = MockTestBackendHandler::new(); mock.expect_list_users() - .with(eq(Some(UserRequestFilter::And(vec![UserRequestFilter::Or(vec![ - UserRequestFilter::Not(Box::new(UserRequestFilter::Equality( - "user_id".to_string(), - "bob".to_string(), - ))), - UserRequestFilter::And(vec![]), - UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))), - UserRequestFilter::And(vec![]), - UserRequestFilter::And(vec![]), - UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))), - ])])))) + .with(eq(Some(UserRequestFilter::And(vec![ + UserRequestFilter::Or(vec![ + UserRequestFilter::Not(Box::new(UserRequestFilter::Equality( + "user_id".to_string(), + "bob".to_string(), + ))), + UserRequestFilter::And(vec![]), + UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))), + UserRequestFilter::And(vec![]), + UserRequestFilter::And(vec![]), + UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))), + ]), + ])))) .times(1) .return_once(|_| Ok(vec![])); let mut ldap_handler = setup_bound_handler(mock).await; @@ -1256,12 +1334,11 @@ mod tests { async fn test_search_filters_lowercase() { let mut mock = MockTestBackendHandler::new(); mock.expect_list_users() - .with(eq(Some(UserRequestFilter::And(vec![UserRequestFilter::Or(vec![ - UserRequestFilter::Not(Box::new(UserRequestFilter::Equality( - "first_name".to_string(), - "bob".to_string(), - ))), - ])])))) + .with(eq(Some(UserRequestFilter::And(vec![ + UserRequestFilter::Or(vec![UserRequestFilter::Not(Box::new( + UserRequestFilter::Equality("first_name".to_string(), "bob".to_string()), + ))]), + ])))) .times(1) .return_once(|_| { Ok(vec![User { @@ -1309,13 +1386,16 @@ mod tests { ..Default::default() }]) }); - mock.expect_list_groups().times(1).return_once(|| { - Ok(vec![Group { - id: GroupId(1), - display_name: "group_1".to_string(), - users: vec!["bob".to_string(), "john".to_string()], - }]) - }); + mock.expect_list_groups() + .with(eq(Some(GroupRequestFilter::And(vec![])))) + .times(1) + .return_once(|_| { + Ok(vec![Group { + id: GroupId(1), + display_name: "group_1".to_string(), + users: vec!["bob".to_string(), "john".to_string()], + }]) + }); let mut ldap_handler = setup_bound_handler(mock).await; let request = make_search_request( "dc=example,dc=com", diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index f0dfa02..79fa102 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -36,7 +36,7 @@ mockall::mock! { #[async_trait] impl BackendHandler for TestTcpBackendHandler { async fn list_users(&self, filters: Option) -> Result>; - async fn list_groups(&self) -> Result>; + async fn list_groups(&self, filters: Option) -> Result>; async fn get_user_details(&self, user_id: &str) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; async fn get_user_groups(&self, user: &str) -> Result>;