From 31e8998ac326adf9f24d64a2f855c4b5a87006ca Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Mon, 22 Mar 2021 09:59:58 +0100 Subject: [PATCH] Add attribute list handling Also, fix various clippy warnings --- src/domain/handler.rs | 1 - src/infra/ldap_handler.rs | 140 +++++++++++++++++++++----------------- src/infra/ldap_server.rs | 5 +- src/infra/tcp_server.rs | 4 +- 4 files changed, 81 insertions(+), 69 deletions(-) diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 5442389..12d6a1a 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -11,7 +11,6 @@ pub struct BindRequest { #[cfg_attr(test, derive(PartialEq, Eq, Debug))] pub struct ListUsersRequest { // filters - pub attrs: Vec, } #[cfg_attr(test, derive(PartialEq, Eq, Debug))] diff --git a/src/infra/ldap_handler.rs b/src/infra/ldap_handler.rs index 26ff7e5..ba7d485 100644 --- a/src/infra/ldap_handler.rs +++ b/src/infra/ldap_handler.rs @@ -2,17 +2,15 @@ use crate::domain::handler::{BackendHandler, ListUsersRequest, User}; use anyhow::{bail, Result}; use ldap3_server::simple::*; -fn make_dn_pair<'a, I>(mut iter: I) -> Result<(String, String)> +fn make_dn_pair(mut iter: I) -> Result<(String, String)> where I: Iterator, { let pair = ( iter.next() - .ok_or(anyhow::Error::msg("Empty DN element"))? - .clone(), + .ok_or_else(|| anyhow::Error::msg("Empty DN element"))?, iter.next() - .ok_or(anyhow::Error::msg("Missing DN value"))? - .clone(), + .ok_or_else(|| anyhow::Error::msg("Missing DN value"))?, ); if let Some(e) = iter.next() { bail!( @@ -26,48 +24,47 @@ where } fn parse_distinguished_name(dn: &str) -> Result> { - dn.split(",") - .map(|s| make_dn_pair(s.split("=").map(String::from))) + dn.split(',') + .map(|s| make_dn_pair(s.split('=').map(String::from))) .collect() } -fn make_ldap_search_result_entry(user: User, base_dn_str: &str) -> LdapSearchResultEntry { - LdapSearchResultEntry { - dn: format!("cn={},{}", user.user_id, base_dn_str), - attributes: vec![ - LdapPartialAttribute { - atype: "objectClass".to_string(), - vals: vec![ - "inetOrgPerson".to_string(), - "posixAccount".to_string(), - "mailAccount".to_string(), - ], - }, - LdapPartialAttribute { - atype: "uid".to_string(), - vals: vec![user.user_id], - }, - LdapPartialAttribute { - atype: "mail".to_string(), - vals: vec![user.email], - }, - LdapPartialAttribute { - atype: "givenName".to_string(), - vals: vec![user.first_name], - }, - LdapPartialAttribute { - atype: "sn".to_string(), - vals: vec![user.last_name], - }, - LdapPartialAttribute { - atype: "cn".to_string(), - vals: vec![user.display_name], - }, - ], +fn get_attribute(user: &User, attribute: &str) -> Result> { + match attribute { + "objectClass" => Ok(vec![ + "inetOrgPerson".to_string(), + "posixAccount".to_string(), + "mailAccount".to_string(), + ]), + "uid" => Ok(vec![user.user_id.to_string()]), + "mail" => Ok(vec![user.email.to_string()]), + "givenName" => Ok(vec![user.first_name.to_string()]), + "sn" => Ok(vec![user.last_name.to_string()]), + "cn" => Ok(vec![user.display_name.to_string()]), + _ => bail!("Unsupported attribute: {}", attribute), } } -fn is_subtree(subtree: &Vec<(String, String)>, base_tree: &Vec<(String, String)>) -> bool { +fn make_ldap_search_result_entry( + user: User, + base_dn_str: &str, + attributes: &[String], +) -> Result { + Ok(LdapSearchResultEntry { + dn: format!("cn={},{}", user.user_id, base_dn_str), + attributes: attributes + .iter() + .map(|a| { + Ok(LdapPartialAttribute { + atype: a.to_string(), + vals: get_attribute(&user, a)?, + }) + }) + .collect::>>()?, + }) +} + +fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)]) -> bool { if subtree.len() < base_tree.len() { return false; } @@ -92,10 +89,12 @@ impl LdapHandler { Self { dn: "Unauthenticated".to_string(), backend_handler, - base_dn: parse_distinguished_name(&ldap_base_dn).expect(&format!( - "Invalid value for ldap_base_dn in configuration: {}", - ldap_base_dn - )), + base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| { + panic!( + "Invalid value for ldap_base_dn in configuration: {}", + ldap_base_dn + ) + }), base_dn_str: ldap_base_dn, } } @@ -126,11 +125,10 @@ impl LdapHandler { } }; if !is_subtree(&dn_parts, &self.base_dn) { + // Search path is not in our tree, just return an empty success. return vec![lsr.gen_success()]; } - let users = match self.backend_handler.list_users(ListUsersRequest { - attrs: lsr.attrs.clone(), - }) { + let users = match self.backend_handler.list_users(ListUsersRequest {}) { Ok(users) => users, Err(e) => { return vec![lsr.gen_error( @@ -139,12 +137,15 @@ impl LdapHandler { )] } }; - let mut res = users + + users .into_iter() - .map(|u| lsr.gen_result_entry(make_ldap_search_result_entry(u, &self.base_dn_str))) - .collect::>(); - res.push(lsr.gen_success()); - res + .map(|u| make_ldap_search_result_entry(u, &self.base_dn_str, &lsr.attrs)) + .map(|entry| Ok(lsr.gen_result_entry(entry?))) + // If the processing succeeds, add a success message at the end. + .chain(std::iter::once(Ok(lsr.gen_success()))) + .collect::>>() + .unwrap_or_else(|e| vec![lsr.gen_error(LdapResultCode::NoSuchAttribute, e.to_string())]) } pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg { @@ -210,22 +211,22 @@ mod tests { #[test] fn test_is_subtree() { - let subtree1 = vec![ + let subtree1 = &[ ("ou".to_string(), "people".to_string()), ("dc".to_string(), "example".to_string()), ("dc".to_string(), "com".to_string()), ]; - let root = vec![ + let root = &[ ("dc".to_string(), "example".to_string()), ("dc".to_string(), "com".to_string()), ]; - assert!(is_subtree(&subtree1, &root)); - assert!(!is_subtree(&vec![], &root)); + assert!(is_subtree(subtree1, root)); + assert!(!is_subtree(&[], root)); } #[test] fn test_parse_distinguished_name() { - let parsed_dn = vec![ + let parsed_dn = &[ ("ou".to_string(), "people".to_string()), ("dc".to_string(), "example".to_string()), ("dc".to_string(), "com".to_string()), @@ -241,7 +242,7 @@ mod tests { let mut mock = MockTestBackendHandler::new(); mock.expect_bind().return_once(|_| Ok(())); mock.expect_list_users() - .with(eq(ListUsersRequest { attrs: vec![] })) + .with(eq(ListUsersRequest {})) .times(1) .return_once(|_| { Ok(vec![ @@ -275,7 +276,14 @@ mod tests { base: "ou=people,dc=example,dc=com".to_string(), scope: LdapSearchScope::Base, filter: LdapFilter::And(vec![]), - attrs: vec![], + attrs: vec![ + "objectClass".to_string(), + "uid".to_string(), + "mail".to_string(), + "givenName".to_string(), + "sn".to_string(), + "cn".to_string(), + ], }; assert_eq!( ldap_handler.do_search(&request), @@ -285,7 +293,11 @@ mod tests { attributes: vec![ LdapPartialAttribute { atype: "objectClass".to_string(), - vals: vec!["inetOrgPerson".to_string(), "posixAccount".to_string(), "mailAccount".to_string()] + vals: vec![ + "inetOrgPerson".to_string(), + "posixAccount".to_string(), + "mailAccount".to_string() + ] }, LdapPartialAttribute { atype: "uid".to_string(), @@ -314,7 +326,11 @@ mod tests { attributes: vec![ LdapPartialAttribute { atype: "objectClass".to_string(), - vals: vec!["inetOrgPerson".to_string(), "posixAccount".to_string(), "mailAccount".to_string()] + vals: vec![ + "inetOrgPerson".to_string(), + "posixAccount".to_string(), + "mailAccount".to_string() + ] }, LdapPartialAttribute { atype: "uid".to_string(), diff --git a/src/infra/ldap_server.rs b/src/infra/ldap_server.rs index 6560e04..f8504a2 100644 --- a/src/infra/ldap_server.rs +++ b/src/infra/ldap_server.rs @@ -20,10 +20,7 @@ async fn handle_incoming_message( ) -> Result { use futures_util::SinkExt; use std::convert::TryFrom; - let server_op = match msg - .map_err(|_e| ()) - .and_then(|msg| ServerOps::try_from(msg)) - { + let server_op = match msg.map_err(|_e| ()).and_then(ServerOps::try_from) { Ok(a_value) => a_value, Err(an_error) => { let _err = resp diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index 4fabfdd..8289e9e 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -23,7 +23,7 @@ where let count = Arc::new(AtomicUsize::new(0)); - Ok(server_builder + server_builder .bind("http", ("0.0.0.0", config.http_port), move || { let count = Arc::clone(&count); let num2 = Arc::clone(&count); @@ -73,5 +73,5 @@ where "While bringing up the TCP server with port {}", config.http_port ) - })?) + }) }