diff --git a/schema.graphql b/schema.graphql index 7f9eb58..ec40c5c 100644 --- a/schema.graphql +++ b/schema.graphql @@ -29,6 +29,8 @@ input RequestFilter { all: [RequestFilter!] not: RequestFilter eq: EqualityConstraint + memberOf: String + memberOfId: Int } "DateTime" diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 2f1f198..02d841a 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -48,8 +48,10 @@ pub enum RequestFilter { Or(Vec), Not(Box), Equality(String, String), - // Check if a user belongs to a group. + // Check if a user belongs to a group identified by name. MemberOf(String), + // Same, by id. + MemberOfId(GroupId), } #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index b398546..89e0893 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -65,6 +65,10 @@ fn get_filter_expr(filter: RequestFilter) -> (RequiresGroup, SimpleExpr) { RequiresGroup(true), Expr::col((Groups::Table, Groups::DisplayName)).eq(group), ), + MemberOfId(group_id) => ( + RequiresGroup(true), + Expr::col((Groups::Table, Groups::GroupId)).eq(group_id), + ), } } diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index a8dbcb7..203c8d4 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,4 +1,4 @@ -use crate::domain::handler::{BackendHandler, GroupIdAndName}; +use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName}; use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use serde::{Deserialize, Serialize}; use std::convert::TryInto; @@ -16,6 +16,8 @@ pub struct RequestFilter { all: Option>, not: Option>, eq: Option, + member_of: Option, + member_of_id: Option, } impl TryInto for RequestFilter { @@ -34,6 +36,12 @@ impl TryInto for RequestFilter { if self.eq.is_some() { field_count += 1; } + if self.member_of.is_some() { + field_count += 1; + } + if self.member_of_id.is_some() { + field_count += 1; + } if field_count == 0 { return Err("No field specified in request filter".to_string()); } @@ -60,6 +68,12 @@ impl TryInto for RequestFilter { if let Some(c) = self.not { return Ok(DomainRequestFilter::Not(Box::new((*c).try_into()?))); } + if let Some(group) = self.member_of { + return Ok(DomainRequestFilter::MemberOf(group)); + } + if let Some(group_id) = self.member_of_id { + return Ok(DomainRequestFilter::MemberOfId(GroupId(group_id))); + } unreachable!(); } } @@ -239,10 +253,7 @@ impl From for Group { #[cfg(test)] mod tests { use super::*; - use crate::{ - domain::handler::{GroupId, GroupIdAndName, MockTestBackendHandler}, - infra::auth_service::ValidationResults, - }; + use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults}; use juniper::{ execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType, RootNode, Variables, diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index e862da9..1e44b3f 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -318,8 +318,7 @@ impl LdapHandler { #[cfg(test)] mod tests { use super::*; - use crate::domain::handler::BindRequest; - use crate::domain::handler::MockTestBackendHandler; + use crate::domain::handler::{BindRequest, MockTestBackendHandler}; use mockall::predicate::eq; use tokio; @@ -665,14 +664,17 @@ mod tests { msgid: 2, base: "ou=people,dc=example,dc=com".to_string(), scope: LdapSearchScope::Base, - filter: LdapFilter::Present("uid".to_string()), + filter: LdapFilter::Substring( + "uid".to_string(), + ldap3_server::proto::LdapSubstringFilter::default(), + ), attrs: vec!["objectClass".to_string()], }; assert_eq!( ldap_handler.do_search(&request).await, vec![request.gen_error( LdapResultCode::UnwillingToPerform, - "Unsupported filter".to_string() + "Unsupported filter: Unsupported filter: Substring(\"uid\", LdapSubstringFilter { initial: None, any: [], final_: None })".to_string() )] ); }