diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 1eff98f..2f1f198 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -48,6 +48,8 @@ pub enum RequestFilter { Or(Vec), Not(Box), Equality(String, String), + // Check if a user belongs to a group. + MemberOf(String), } #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index a73998a..b398546 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -18,24 +18,53 @@ impl SqlBackendHandler { } } -fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { +struct RequiresGroup(bool); + +// Returns the condition for the SQL query, and whether it requires joining with the groups table. +fn get_filter_expr(filter: RequestFilter) -> (RequiresGroup, SimpleExpr) { use RequestFilter::*; fn get_repeated_filter( fs: Vec, field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, - ) -> SimpleExpr { + ) -> (RequiresGroup, SimpleExpr) { + let mut requires_group = false; let mut it = fs.into_iter(); let first_expr = match it.next() { - None => return Expr::value(true), - Some(f) => get_filter_expr(f), + None => return (RequiresGroup(false), Expr::value(true)), + Some(f) => { + let (group, filter) = get_filter_expr(f); + requires_group |= group.0; + filter + } }; - it.fold(first_expr, |e, f| field(e, get_filter_expr(f))) + let filter = it.fold(first_expr, |e, f| { + let (group, filters) = get_filter_expr(f); + requires_group |= group.0; + field(e, filters) + }); + (RequiresGroup(requires_group), filter) } match filter { And(fs) => get_repeated_filter(fs, &SimpleExpr::and), Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), - Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))), - Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2), + Not(f) => { + let (requires_group, filters) = get_filter_expr(*f); + (requires_group, Expr::not(Expr::expr(filters))) + } + Equality(s1, s2) => ( + RequiresGroup(false), + if s1 == Users::DisplayName.to_string() { + Expr::col((Users::Table, Users::DisplayName)).eq(s2) + } else if s1 == Users::UserId.to_string() { + Expr::col((Users::Table, Users::UserId)).eq(s2) + } else { + Expr::expr(Expr::cust(&s1)).eq(s2) + }, + ), + MemberOf(group) => ( + RequiresGroup(true), + Expr::col((Groups::Table, Groups::DisplayName)).eq(group), + ), } } @@ -44,21 +73,38 @@ impl BackendHandler for SqlBackendHandler { async fn list_users(&self, filters: Option) -> Result> { let query = { let mut query_builder = Query::select() - .column(Users::UserId) + .column((Users::Table, Users::UserId)) .column(Users::Email) - .column(Users::DisplayName) + .column((Users::Table, Users::DisplayName)) .column(Users::FirstName) .column(Users::LastName) .column(Users::Avatar) .column(Users::CreationDate) .from(Users::Table) - .order_by(Users::UserId, Order::Asc) + .order_by((Users::Table, Users::UserId), Order::Asc) .to_owned(); if let Some(filter) = filters { + if filter == RequestFilter::Not(Box::new(RequestFilter::And(Vec::new()))) { + return Ok(Vec::new()); + } if filter != RequestFilter::And(Vec::new()) && filter != RequestFilter::Or(Vec::new()) { - query_builder.and_where(get_filter_expr(filter)); + let (RequiresGroup(requires_group), condition) = get_filter_expr(filter); + query_builder.and_where(condition); + if requires_group { + query_builder + .left_join( + Memberships::Table, + Expr::tbl(Users::Table, Users::UserId) + .equals(Memberships::Table, Memberships::UserId), + ) + .left_join( + Groups::Table, + Expr::tbl(Memberships::Table, Memberships::GroupId) + .equals(Groups::Table, Groups::GroupId), + ); + } } } diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index 6ac8608..e862da9 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -29,6 +29,31 @@ fn parse_distinguished_name(dn: &str) -> Result> { .collect() } +fn get_group_id_from_distinguished_name( + dn: &str, + base_tree: &[(String, String)], + base_dn_str: &str, +) -> Result { + let parts = parse_distinguished_name(dn)?; + if !is_subtree(&parts, base_tree) { + bail!("Not a subtree of the base tree"); + } + if parts.len() == base_tree.len() + 2 { + if parts[1].0 != "ou" || parts[1].1 != "groups" || parts[0].0 != "cn" { + bail!( + r#"Unexpected user DN format. Expected: "cn=groupname,ou=groups,{}""#, + base_dn_str + ); + } + Ok(parts[0].1.to_string()) + } else { + bail!( + r#"Unexpected user DN format. Expected: "cn=groupname,ou=groups,{}""#, + base_dn_str + ); + } +} + fn get_user_id_from_distinguished_name( dn: &str, base_tree: &[(String, String)], @@ -128,22 +153,6 @@ fn map_field(field: &str) -> Result { }) } -fn convert_filter(filter: &LdapFilter) -> Result { - match filter { - LdapFilter::And(filters) => Ok(RequestFilter::And( - filters.iter().map(convert_filter).collect::>()?, - )), - LdapFilter::Or(filters) => Ok(RequestFilter::Or( - filters.iter().map(convert_filter).collect::>()?, - )), - LdapFilter::Not(filter) => Ok(RequestFilter::Not(Box::new(convert_filter(&*filter)?))), - LdapFilter::Equality(field, value) => { - Ok(RequestFilter::Equality(map_field(field)?, value.clone())) - } - _ => bail!("Unsupported filter"), - } -} - pub struct LdapHandler { dn: String, backend_handler: Backend, @@ -214,12 +223,12 @@ impl LdapHandler { // Search path is not in our tree, just return an empty success. return vec![lsr.gen_success()]; } - let filters = match convert_filter(&lsr.filter) { + let filters = match self.convert_filter(&lsr.filter) { Ok(f) => Some(f), - Err(_) => { + Err(e) => { return vec![lsr.gen_error( LdapResultCode::UnwillingToPerform, - "Unsupported filter".to_string(), + format!("Unsupported filter: {}", e), )] } }; @@ -263,6 +272,47 @@ impl LdapHandler { }; Some(result) } + + fn convert_filter(&self, filter: &LdapFilter) -> Result { + match filter { + LdapFilter::And(filters) => Ok(RequestFilter::And( + filters + .iter() + .map(|f| self.convert_filter(f)) + .collect::>()?, + )), + LdapFilter::Or(filters) => Ok(RequestFilter::Or( + filters + .iter() + .map(|f| self.convert_filter(f)) + .collect::>()?, + )), + LdapFilter::Not(filter) => { + Ok(RequestFilter::Not(Box::new(self.convert_filter(&*filter)?))) + } + LdapFilter::Equality(field, value) => { + if field == "memberOf" { + let group_name = get_group_id_from_distinguished_name( + value, + &self.base_dn, + &self.base_dn_str, + )?; + Ok(RequestFilter::MemberOf(group_name)) + } else { + Ok(RequestFilter::Equality(map_field(field)?, value.clone())) + } + } + LdapFilter::Present(field) => { + // Check that it's a field we support. + if field == "objectclass" || map_field(field).is_ok() { + Ok(RequestFilter::And(Vec::new())) + } else { + Ok(RequestFilter::Not(Box::new(RequestFilter::And(Vec::new())))) + } + } + _ => bail!("Unsupported filter: {:?}", filter), + } + } } #[cfg(test)]