diff --git a/Cargo.toml b/Cargo.toml index 2d7bea5..980d206 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ tracing = "*" tracing-actix-web = "0.3.0-beta.2" tracing-log = "*" tracing-subscriber = "*" +async-trait = "0.1.48" [dependencies.figment] features = ["toml", "env"] diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 12d6a1a..ed463f4 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -1,6 +1,8 @@ use crate::infra::configuration::Configuration; use anyhow::{bail, Result}; +use async_trait::async_trait; use sqlx::any::AnyPool; +use sqlx::Row; #[cfg_attr(test, derive(PartialEq, Eq, Debug))] pub struct BindRequest { @@ -24,9 +26,10 @@ pub struct User { pub creation_date: chrono::NaiveDateTime, } +#[async_trait] pub trait BackendHandler: Clone + Send { - fn bind(&mut self, request: BindRequest) -> Result<()>; - fn list_users(&mut self, request: ListUsersRequest) -> Result>; + async fn bind(&mut self, request: BindRequest) -> Result<()>; + async fn list_users(&mut self, request: ListUsersRequest) -> Result>; } #[derive(Debug, Clone)] @@ -46,19 +49,34 @@ impl SqlBackendHandler { } } +fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool { + encrypted_password == clear_password +} + +#[async_trait] impl BackendHandler for SqlBackendHandler { - fn bind(&mut self, request: BindRequest) -> Result<()> { - if request.name == self.config.ldap_user_dn - && request.password == self.config.ldap_user_pass - { - self.authenticated = true; - Ok(()) - } else { - bail!(r#"Authentication error for "{}""#, request.name) + async fn bind(&mut self, request: BindRequest) -> Result<()> { + if request.name == self.config.ldap_user_dn { + if request.password == self.config.ldap_user_pass { + self.authenticated = true; + return Ok(()); + } else { + bail!(r#"Authentication error for "{}""#, request.name) + } } + if let Ok(row) = sqlx::query("SELECT password FROM users WHERE user_id = ?") + .bind(&request.name) + .fetch_one(&self.sql_pool) + .await + { + if passwords_match(&request.password, &row.get::("password")) { + return Ok(()); + } + } + bail!(r#"Authentication error for "{}""#, request.name) } - fn list_users(&mut self, request: ListUsersRequest) -> Result> { + async fn list_users(&mut self, request: ListUsersRequest) -> Result> { Ok(Vec::new()) } } @@ -69,8 +87,9 @@ mockall::mock! { impl Clone for TestBackendHandler { fn clone(&self) -> Self; } + #[async_trait] impl BackendHandler for TestBackendHandler { - fn bind(&mut self, request: BindRequest) -> Result<()>; - fn list_users(&mut self, request: ListUsersRequest) -> Result>; + async fn bind(&mut self, request: BindRequest) -> Result<()>; + async fn list_users(&mut self, request: ListUsersRequest) -> Result>; } } diff --git a/src/infra/ldap_handler.rs b/src/infra/ldap_handler.rs index ba7d485..ccef962 100644 --- a/src/infra/ldap_handler.rs +++ b/src/infra/ldap_handler.rs @@ -82,10 +82,11 @@ pub struct LdapHandler { backend_handler: Backend, pub base_dn: Vec<(String, String)>, base_dn_str: String, + ldap_user_dn: String, } impl LdapHandler { - pub fn new(backend_handler: Backend, ldap_base_dn: String) -> Self { + pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self { Self { dn: "Unauthenticated".to_string(), backend_handler, @@ -96,16 +97,19 @@ impl LdapHandler { ) }), base_dn_str: ldap_base_dn, + ldap_user_dn, } } - pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { + pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { match self .backend_handler .bind(crate::domain::handler::BindRequest { name: sbr.dn.clone(), password: sbr.pw.clone(), - }) { + }) + .await + { Ok(()) => { self.dn = sbr.dn.clone(); sbr.gen_success() @@ -114,7 +118,13 @@ impl LdapHandler { } } - pub fn do_search(&mut self, lsr: &SearchRequest) -> Vec { + pub async fn do_search(&mut self, lsr: &SearchRequest) -> Vec { + if self.dn != self.ldap_user_dn { + return vec![lsr.gen_error( + LdapResultCode::InsufficentAccessRights, + r#"Current user is not allowed to query LDAP"#.to_string(), + )]; + } let dn_parts = match parse_distinguished_name(&lsr.base) { Ok(dn) => dn, Err(_) => { @@ -128,7 +138,7 @@ impl LdapHandler { // Search path is not in our tree, just return an empty success. return vec![lsr.gen_success()]; } - let users = match self.backend_handler.list_users(ListUsersRequest {}) { + let users = match self.backend_handler.list_users(ListUsersRequest {}).await { Ok(users) => users, Err(e) => { return vec![lsr.gen_error( @@ -156,10 +166,10 @@ impl LdapHandler { } } - pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> { + pub async fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> { let result = match server_op { - ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)], - ServerOps::Search(sr) => self.do_search(&sr), + ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr).await], + ServerOps::Search(sr) => self.do_search(&sr).await, ServerOps::Unbind(_) => { // No need to notify on unbind (per rfc4511) return None; @@ -176,9 +186,10 @@ mod tests { use crate::domain::handler::MockTestBackendHandler; use chrono::NaiveDateTime; use mockall::predicate::eq; + use tokio; - #[test] - fn test_bind() { + #[tokio::test] + async fn test_bind() { let mut mock = MockTestBackendHandler::new(); mock.expect_bind() .with(eq(crate::domain::handler::BindRequest { @@ -187,7 +198,8 @@ mod tests { })) .times(1) .return_once(|_| Ok(())); - let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string()); + let mut ldap_handler = + LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); let request = WhoamiRequest { msgid: 1 }; assert_eq!( @@ -200,7 +212,7 @@ mod tests { dn: "test".to_string(), pw: "pass".to_string(), }; - assert_eq!(ldap_handler.do_bind(&request), request.gen_success()); + assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success()); let request = WhoamiRequest { msgid: 3 }; assert_eq!( @@ -209,6 +221,54 @@ mod tests { ); } + #[tokio::test] + async fn test_bind_invalid_credentials() { + let mut mock = MockTestBackendHandler::new(); + mock.expect_bind() + .with(eq(crate::domain::handler::BindRequest { + name: "test".to_string(), + password: "pass".to_string(), + })) + .times(1) + .return_once(|_| Ok(())); + let mut ldap_handler = + LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string()); + + let request = WhoamiRequest { msgid: 1 }; + assert_eq!( + ldap_handler.do_whoami(&request), + request.gen_operror("Unauthenticated") + ); + + let request = SimpleBindRequest { + msgid: 2, + dn: "test".to_string(), + pw: "pass".to_string(), + }; + assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success()); + + let request = WhoamiRequest { msgid: 3 }; + assert_eq!( + ldap_handler.do_whoami(&request), + request.gen_success("dn: test") + ); + + let request = SearchRequest { + msgid: 2, + base: "ou=people,dc=example,dc=com".to_string(), + scope: LdapSearchScope::Base, + filter: LdapFilter::And(vec![]), + attrs: vec![], + }; + assert_eq!( + ldap_handler.do_search(&request).await, + vec![request.gen_error( + LdapResultCode::InsufficentAccessRights, + r#"Current user is not allowed to query LDAP"#.to_string() + )] + ); + } + #[test] fn test_is_subtree() { let subtree1 = &[ @@ -237,8 +297,8 @@ mod tests { ); } - #[test] - fn test_search() { + #[tokio::test] + async fn test_search() { let mut mock = MockTestBackendHandler::new(); mock.expect_bind().return_once(|_| Ok(())); mock.expect_list_users() @@ -264,13 +324,14 @@ mod tests { }, ]) }); - let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string()); + let mut ldap_handler = + LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); let request = SimpleBindRequest { msgid: 1, dn: "test".to_string(), pw: "pass".to_string(), }; - assert_eq!(ldap_handler.do_bind(&request), request.gen_success()); + assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success()); let request = SearchRequest { msgid: 2, base: "ou=people,dc=example,dc=com".to_string(), @@ -286,7 +347,7 @@ mod tests { ], }; assert_eq!( - ldap_handler.do_search(&request), + ldap_handler.do_search(&request).await, vec![ request.gen_result_entry(LdapSearchResultEntry { dn: "cn=bob_1,dc=example,dc=com".to_string(), diff --git a/src/infra/ldap_server.rs b/src/infra/ldap_server.rs index f8504a2..c7f3678 100644 --- a/src/infra/ldap_server.rs +++ b/src/infra/ldap_server.rs @@ -34,7 +34,7 @@ async fn handle_incoming_message( } }; - match session.handle_ldap_message(server_op) { + match session.handle_ldap_message(server_op).await { None => return Ok(false), Some(result) => { for rmsg in result.into_iter() { @@ -62,20 +62,23 @@ where use futures_util::StreamExt; let ldap_base_dn = config.ldap_base_dn.clone(); + let ldap_user_dn = config.ldap_user_dn.clone(); Ok( server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || { let backend_handler = backend_handler.clone(); let ldap_base_dn = ldap_base_dn.clone(); + let ldap_user_dn = ldap_user_dn.clone(); pipeline_factory(fn_service(move |mut stream: TcpStream| { let backend_handler = backend_handler.clone(); let ldap_base_dn = ldap_base_dn.clone(); + let ldap_user_dn = ldap_user_dn.clone(); async move { // Configure the codec etc. let (r, w) = stream.split(); let mut requests = FramedRead::new(r, LdapCodec); let mut resp = FramedWrite::new(w, LdapCodec); - let mut session = LdapHandler::new(backend_handler, ldap_base_dn); + let mut session = LdapHandler::new(backend_handler, ldap_base_dn, ldap_user_dn); while let Some(msg) = requests.next().await { if !handle_incoming_message(msg, &mut resp, &mut session).await? {