From ce034fbc747c44f88d02fef541354850564ab7de Mon Sep 17 00:00:00 2001
From: Valentin Tolmer <valentin@tolmer.fr>
Date: Fri, 24 Sep 2021 22:27:07 +0200
Subject: [PATCH] ldap: Add support for memberOf and wildcards

---
 server/src/domain/handler.rs             |  2 +
 server/src/domain/sql_backend_handler.rs | 68 +++++++++++++++---
 server/src/infra/ldap_handler.rs         | 88 +++++++++++++++++++-----
 3 files changed, 128 insertions(+), 30 deletions(-)

diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs
index 1eff98f..2f1f198 100644
--- a/server/src/domain/handler.rs
+++ b/server/src/domain/handler.rs
@@ -48,6 +48,8 @@ pub enum RequestFilter {
     Or(Vec<RequestFilter>),
     Not(Box<RequestFilter>),
     Equality(String, String),
+    // Check if a user belongs to a group.
+    MemberOf(String),
 }
 
 #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs
index a73998a..b398546 100644
--- a/server/src/domain/sql_backend_handler.rs
+++ b/server/src/domain/sql_backend_handler.rs
@@ -18,24 +18,53 @@ impl SqlBackendHandler {
     }
 }
 
-fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
+struct RequiresGroup(bool);
+
+// Returns the condition for the SQL query, and whether it requires joining with the groups table.
+fn get_filter_expr(filter: RequestFilter) -> (RequiresGroup, SimpleExpr) {
     use RequestFilter::*;
     fn get_repeated_filter(
         fs: Vec<RequestFilter>,
         field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr,
-    ) -> SimpleExpr {
+    ) -> (RequiresGroup, SimpleExpr) {
+        let mut requires_group = false;
         let mut it = fs.into_iter();
         let first_expr = match it.next() {
-            None => return Expr::value(true),
-            Some(f) => get_filter_expr(f),
+            None => return (RequiresGroup(false), Expr::value(true)),
+            Some(f) => {
+                let (group, filter) = get_filter_expr(f);
+                requires_group |= group.0;
+                filter
+            }
         };
-        it.fold(first_expr, |e, f| field(e, get_filter_expr(f)))
+        let filter = it.fold(first_expr, |e, f| {
+            let (group, filters) = get_filter_expr(f);
+            requires_group |= group.0;
+            field(e, filters)
+        });
+        (RequiresGroup(requires_group), filter)
     }
     match filter {
         And(fs) => get_repeated_filter(fs, &SimpleExpr::and),
         Or(fs) => get_repeated_filter(fs, &SimpleExpr::or),
-        Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))),
-        Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2),
+        Not(f) => {
+            let (requires_group, filters) = get_filter_expr(*f);
+            (requires_group, Expr::not(Expr::expr(filters)))
+        }
+        Equality(s1, s2) => (
+            RequiresGroup(false),
+            if s1 == Users::DisplayName.to_string() {
+                Expr::col((Users::Table, Users::DisplayName)).eq(s2)
+            } else if s1 == Users::UserId.to_string() {
+                Expr::col((Users::Table, Users::UserId)).eq(s2)
+            } else {
+                Expr::expr(Expr::cust(&s1)).eq(s2)
+            },
+        ),
+        MemberOf(group) => (
+            RequiresGroup(true),
+            Expr::col((Groups::Table, Groups::DisplayName)).eq(group),
+        ),
     }
 }
 
@@ -44,21 +73,38 @@ impl BackendHandler for SqlBackendHandler {
     async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>> {
         let query = {
             let mut query_builder = Query::select()
-                .column(Users::UserId)
+                .column((Users::Table, Users::UserId))
                 .column(Users::Email)
-                .column(Users::DisplayName)
+                .column((Users::Table, Users::DisplayName))
                 .column(Users::FirstName)
                 .column(Users::LastName)
                 .column(Users::Avatar)
                 .column(Users::CreationDate)
                 .from(Users::Table)
-                .order_by(Users::UserId, Order::Asc)
+                .order_by((Users::Table, Users::UserId), Order::Asc)
                 .to_owned();
             if let Some(filter) = filters {
+                if filter == RequestFilter::Not(Box::new(RequestFilter::And(Vec::new()))) {
+                    return Ok(Vec::new());
+                }
                 if filter != RequestFilter::And(Vec::new())
                     && filter != RequestFilter::Or(Vec::new())
                 {
-                    query_builder.and_where(get_filter_expr(filter));
+                    let (RequiresGroup(requires_group), condition) = get_filter_expr(filter);
+                    query_builder.and_where(condition);
+                    if requires_group {
+                        query_builder
+                            .left_join(
+                                Memberships::Table,
+                                Expr::tbl(Users::Table, Users::UserId)
+                                    .equals(Memberships::Table, Memberships::UserId),
+                            )
+                            .left_join(
+                                Groups::Table,
+                                Expr::tbl(Memberships::Table, Memberships::GroupId)
+                                    .equals(Groups::Table, Groups::GroupId),
+                            );
+                    }
                 }
             }
 
diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs
index 6ac8608..e862da9 100644
--- a/server/src/infra/ldap_handler.rs
+++ b/server/src/infra/ldap_handler.rs
@@ -29,6 +29,31 @@ fn parse_distinguished_name(dn: &str) -> Result<Vec<(String, String)>> {
         .collect()
 }
 
+fn get_group_id_from_distinguished_name(
+    dn: &str,
+    base_tree: &[(String, String)],
+    base_dn_str: &str,
+) -> Result<String> {
+    let parts = parse_distinguished_name(dn)?;
+    if !is_subtree(&parts, base_tree) {
+        bail!("Not a subtree of the base tree");
+    }
+    if parts.len() == base_tree.len() + 2 {
+        if parts[1].0 != "ou" || parts[1].1 != "groups" || parts[0].0 != "cn" {
+            bail!(
+                r#"Unexpected user DN format. Expected: "cn=groupname,ou=groups,{}""#,
+                base_dn_str
+            );
+        }
+        Ok(parts[0].1.to_string())
+    } else {
+        bail!(
+            r#"Unexpected user DN format. Expected: "cn=groupname,ou=groups,{}""#,
+            base_dn_str
+        );
+    }
+}
+
 fn get_user_id_from_distinguished_name(
     dn: &str,
     base_tree: &[(String, String)],
@@ -128,22 +153,6 @@ fn map_field(field: &str) -> Result<String> {
     })
 }
 
-fn convert_filter(filter: &LdapFilter) -> Result<RequestFilter> {
-    match filter {
-        LdapFilter::And(filters) => Ok(RequestFilter::And(
-            filters.iter().map(convert_filter).collect::<Result<_>>()?,
-        )),
-        LdapFilter::Or(filters) => Ok(RequestFilter::Or(
-            filters.iter().map(convert_filter).collect::<Result<_>>()?,
-        )),
-        LdapFilter::Not(filter) => Ok(RequestFilter::Not(Box::new(convert_filter(&*filter)?))),
-        LdapFilter::Equality(field, value) => {
-            Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
-        }
-        _ => bail!("Unsupported filter"),
-    }
-}
-
 pub struct LdapHandler<Backend: BackendHandler + LoginHandler> {
     dn: String,
     backend_handler: Backend,
@@ -214,12 +223,12 @@ impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
             // Search path is not in our tree, just return an empty success.
             return vec![lsr.gen_success()];
         }
-        let filters = match convert_filter(&lsr.filter) {
+        let filters = match self.convert_filter(&lsr.filter) {
             Ok(f) => Some(f),
-            Err(_) => {
+            Err(e) => {
                 return vec![lsr.gen_error(
                     LdapResultCode::UnwillingToPerform,
-                    "Unsupported filter".to_string(),
+                    format!("Unsupported filter: {}", e),
                 )]
             }
         };
@@ -263,6 +272,47 @@ impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
         };
         Some(result)
     }
+
+    fn convert_filter(&self, filter: &LdapFilter) -> Result<RequestFilter> {
+        match filter {
+            LdapFilter::And(filters) => Ok(RequestFilter::And(
+                filters
+                    .iter()
+                    .map(|f| self.convert_filter(f))
+                    .collect::<Result<_>>()?,
+            )),
+            LdapFilter::Or(filters) => Ok(RequestFilter::Or(
+                filters
+                    .iter()
+                    .map(|f| self.convert_filter(f))
+                    .collect::<Result<_>>()?,
+            )),
+            LdapFilter::Not(filter) => {
+                Ok(RequestFilter::Not(Box::new(self.convert_filter(&*filter)?)))
+            }
+            LdapFilter::Equality(field, value) => {
+                if field == "memberOf" {
+                    let group_name = get_group_id_from_distinguished_name(
+                        value,
+                        &self.base_dn,
+                        &self.base_dn_str,
+                    )?;
+                    Ok(RequestFilter::MemberOf(group_name))
+                } else {
+                    Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
+                }
+            }
+            LdapFilter::Present(field) => {
+                // Check that it's a field we support.
+                if field == "objectclass" || map_field(field).is_ok() {
+                    Ok(RequestFilter::And(Vec::new()))
+                } else {
+                    Ok(RequestFilter::Not(Box::new(RequestFilter::And(Vec::new()))))
+                }
+            }
+            _ => bail!("Unsupported filter: {:?}", filter),
+        }
+    }
 }
 
 #[cfg(test)]