diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 3dd676b..877e6f0 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -160,6 +160,11 @@ impl BackendHandler for SqlBackendHandler { } async fn get_user_groups(&self, user: String) -> Result> { + if user == self.config.ldap_user_dn { + let mut groups = HashSet::new(); + groups.insert("lldap_admin".to_string()); + return Ok(groups); + } let query: String = Query::select() .column(Groups::DisplayName) .from(Groups::Table) diff --git a/src/infra/configuration.rs b/src/infra/configuration.rs index 1ca0f9f..04b60cf 100644 --- a/src/infra/configuration.rs +++ b/src/infra/configuration.rs @@ -30,7 +30,8 @@ impl Default for Configuration { secret_pepper: String::from("secretsecretpepper"), jwt_secret: String::from("secretjwtsecret"), ldap_base_dn: String::from("dc=example,dc=com"), - ldap_user_dn: String::from("cn=admin,dc=example,dc=com"), + // cn=admin,dc=example,dc=com + ldap_user_dn: String::from("admin"), ldap_user_pass: String::from("password"), database_url: String::from("sqlite://users.db?mode=rwc"), verbose: false, diff --git a/src/infra/ldap_handler.rs b/src/infra/ldap_handler.rs index 47cd13f..c45648a 100644 --- a/src/infra/ldap_handler.rs +++ b/src/infra/ldap_handler.rs @@ -29,6 +29,37 @@ fn parse_distinguished_name(dn: &str) -> Result> { .collect() } +fn get_user_id_from_distinguished_name( + dn: &str, + base_tree: &[(String, String)], + base_dn_str: &str, + ldap_user_dn: &str, +) -> Result { + let parts = parse_distinguished_name(dn)?; + if !is_subtree(&parts, base_tree) { + bail!("Not a subtree of the base tree"); + } + if parts.len() == base_tree.len() + 1 { + if dn != ldap_user_dn { + bail!(r#"Wrong admin DN. Expected: "{}""#, ldap_user_dn); + } + Ok(parts[0].1.to_string()) + } else if parts.len() == base_tree.len() + 2 { + if parts[1].0 != "ou" || parts[1].1 != "people" || parts[0].0 != "cn" { + bail!( + r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#, + base_dn_str + ); + } + Ok(parts[0].1.to_string()) + } else { + bail!( + r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#, + base_dn_str + ); + } +} + fn get_attribute(user: &User, attribute: &str) -> Result> { match attribute { "objectClass" => Ok(vec![ @@ -132,16 +163,25 @@ impl LdapHandler { ldap_base_dn ) }), + ldap_user_dn: format!("cn={},{}", ldap_user_dn, &ldap_base_dn), base_dn_str: ldap_base_dn, - ldap_user_dn, } } pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { + let user_id = match get_user_id_from_distinguished_name( + &sbr.dn, + &self.base_dn, + &self.base_dn_str, + &self.ldap_user_dn, + ) { + Ok(s) => s, + Err(e) => return sbr.gen_error(LdapResultCode::NamingViolation, e.to_string()), + }; match self .backend_handler .bind(crate::domain::handler::BindRequest { - name: sbr.dn.clone(), + name: user_id, password: sbr.pw.clone(), }) .await @@ -232,6 +272,7 @@ impl LdapHandler { #[cfg(test)] mod tests { use super::*; + use crate::domain::handler::BindRequest; use crate::domain::handler::MockTestBackendHandler; use chrono::NaiveDateTime; use mockall::predicate::eq; @@ -240,12 +281,17 @@ mod tests { async fn setup_bound_handler( mut mock: MockTestBackendHandler, ) -> LdapHandler { - mock.expect_bind().return_once(|_| Ok(())); + mock.expect_bind() + .with(eq(BindRequest { + name: "test".to_string(), + password: "pass".to_string(), + })) + .return_once(|_| Ok(())); let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); let request = SimpleBindRequest { msgid: 1, - dn: "test".to_string(), + dn: "cn=test,dc=example,dc=com".to_string(), pw: "pass".to_string(), }; ldap_handler.do_bind(&request).await; @@ -254,6 +300,39 @@ mod tests { #[tokio::test] async fn test_bind() { + let mut mock = MockTestBackendHandler::new(); + mock.expect_bind() + .with(eq(crate::domain::handler::BindRequest { + name: "bob".to_string(), + password: "pass".to_string(), + })) + .times(1) + .return_once(|_| Ok(())); + let mut ldap_handler = + LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); + + let request = WhoamiRequest { msgid: 1 }; + assert_eq!( + ldap_handler.do_whoami(&request), + request.gen_operror("Unauthenticated") + ); + + let request = SimpleBindRequest { + msgid: 2, + dn: "cn=bob,ou=people,dc=example,dc=com".to_string(), + pw: "pass".to_string(), + }; + assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success()); + + let request = WhoamiRequest { msgid: 3 }; + assert_eq!( + ldap_handler.do_whoami(&request), + request.gen_success("dn: cn=bob,ou=people,dc=example,dc=com") + ); + } + + #[tokio::test] + async fn test_admin_bind() { let mut mock = MockTestBackendHandler::new(); mock.expect_bind() .with(eq(crate::domain::handler::BindRequest { @@ -273,7 +352,7 @@ mod tests { let request = SimpleBindRequest { msgid: 2, - dn: "test".to_string(), + dn: "cn=test,dc=example,dc=com".to_string(), pw: "pass".to_string(), }; assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success()); @@ -281,7 +360,7 @@ mod tests { let request = WhoamiRequest { msgid: 3 }; assert_eq!( ldap_handler.do_whoami(&request), - request.gen_success("dn: test") + request.gen_success("dn: cn=test,dc=example,dc=com") ); } @@ -306,7 +385,7 @@ mod tests { let request = SimpleBindRequest { msgid: 2, - dn: "test".to_string(), + dn: "cn=test,ou=people,dc=example,dc=com".to_string(), pw: "pass".to_string(), }; assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success()); @@ -314,7 +393,7 @@ mod tests { let request = WhoamiRequest { msgid: 3 }; assert_eq!( ldap_handler.do_whoami(&request), - request.gen_success("dn: test") + request.gen_success("dn: cn=test,ou=people,dc=example,dc=com") ); let request = SearchRequest { @@ -333,6 +412,39 @@ mod tests { ); } + #[tokio::test] + async fn test_bind_invalid_dn() { + let mock = MockTestBackendHandler::new(); + let mut ldap_handler = + LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string()); + + let request = SimpleBindRequest { + msgid: 2, + dn: "cn=bob,dc=example,dc=com".to_string(), + pw: "pass".to_string(), + }; + assert_eq!( + ldap_handler.do_bind(&request).await, + request.gen_error( + LdapResultCode::NamingViolation, + r#"Wrong admin DN. Expected: "cn=admin,dc=example,dc=com""#.to_string() + ) + ); + let request = SimpleBindRequest { + msgid: 2, + dn: "cn=bob,ou=groups,dc=example,dc=com".to_string(), + pw: "pass".to_string(), + }; + assert_eq!( + ldap_handler.do_bind(&request).await, + request.gen_error( + LdapResultCode::NamingViolation, + r#"Unexpected user DN format. Expected: "cn=username,ou=people,dc=example,dc=com""# + .to_string() + ) + ); + } + #[test] fn test_is_subtree() { let subtree1 = &[