From 86b89a00cc78575a8ed6bf8889653037f65111da Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Thu, 11 Mar 2021 10:50:15 +0100 Subject: [PATCH] Separate ldap_handler, add tests --- src/domain/handler.rs | 15 +++++ src/infra/ldap_handler.rs | 123 ++++++++++++++++++++++++++++++++++++++ src/infra/ldap_server.rs | 82 ++----------------------- src/infra/mod.rs | 1 + 4 files changed, 143 insertions(+), 78 deletions(-) create mode 100644 src/infra/ldap_handler.rs diff --git a/src/domain/handler.rs b/src/domain/handler.rs index f17618a..587c843 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -2,13 +2,16 @@ use crate::infra::configuration::Configuration; use anyhow::{bail, Result}; use sqlx::any::AnyPool; +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] pub struct BindRequest { pub name: String, pub password: String, } +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] pub struct SearchRequest {} +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] pub struct SearchResponse {} pub trait BackendHandler: Clone + Send { @@ -47,3 +50,15 @@ impl BackendHandler for SqlBackendHandler { Ok(SearchResponse {}) } } + +#[cfg(test)] +mockall::mock! { + pub TestBackendHandler{} + impl Clone for TestBackendHandler { + fn clone(&self) -> Self; + } + impl BackendHandler for TestBackendHandler { + fn bind(&mut self, request: BindRequest) -> Result<()>; + fn search(&mut self, request: SearchRequest) -> Result; + } +} diff --git a/src/infra/ldap_handler.rs b/src/infra/ldap_handler.rs new file mode 100644 index 0000000..2c9a29a --- /dev/null +++ b/src/infra/ldap_handler.rs @@ -0,0 +1,123 @@ +use crate::domain::handler::BackendHandler; +use ldap3_server::simple::*; + +pub struct LdapHandler { + dn: String, + backend_handler: Backend, +} + +impl LdapHandler { + pub fn new(backend_handler: Backend) -> Self { + Self { + dn: "Unauthenticated".to_string(), + backend_handler, + } + } + + pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { + match self + .backend_handler + .bind(crate::domain::handler::BindRequest { + name: sbr.dn.clone(), + password: sbr.pw.clone(), + }) { + Ok(()) => { + self.dn = sbr.dn.clone(); + sbr.gen_success() + } + Err(_) => sbr.gen_invalid_cred(), + } + } + + pub fn do_search(&mut self, lsr: &SearchRequest) -> Vec { + vec![ + lsr.gen_result_entry(LdapSearchResultEntry { + dn: "cn=hello,dc=example,dc=com".to_string(), + attributes: vec![ + LdapPartialAttribute { + atype: "objectClass".to_string(), + vals: vec!["cursed".to_string()], + }, + LdapPartialAttribute { + atype: "cn".to_string(), + vals: vec!["hello".to_string()], + }, + ], + }), + lsr.gen_result_entry(LdapSearchResultEntry { + dn: "cn=world,dc=example,dc=com".to_string(), + attributes: vec![ + LdapPartialAttribute { + atype: "objectClass".to_string(), + vals: vec!["cursed".to_string()], + }, + LdapPartialAttribute { + atype: "cn".to_string(), + vals: vec!["world".to_string()], + }, + ], + }), + lsr.gen_success(), + ] + } + + pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg { + if self.dn == "Unauthenticated" { + wr.gen_operror("Unauthenticated") + } else { + wr.gen_success(format!("dn: {}", self.dn).as_str()) + } + } + + pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> { + let result = match server_op { + ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)], + ServerOps::Search(sr) => self.do_search(&sr), + ServerOps::Unbind(_) => { + // No need to notify on unbind (per rfc4511) + return None; + } + ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)], + }; + Some(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::handler::MockTestBackendHandler; + use mockall::{mock, predicate::*}; + + #[test] + fn test_bind() { + let mut mock = MockTestBackendHandler::new(); + mock.expect_bind() + .with(eq(crate::domain::handler::BindRequest { + name: "test".to_string(), + password: "pass".to_string(), + })) + .times(1) + .return_once(|_| Ok(())); + let mut ldap_handler = LdapHandler::new(mock); + + let request = WhoamiRequest { msgid: 1 }; + assert_eq!( + ldap_handler.do_whoami(&request), + request.gen_operror("Unauthenticated") + ); + + let request = SimpleBindRequest { + msgid: 2, + dn: "test".to_string(), + pw: "pass".to_string(), + }; + assert_eq!(ldap_handler.do_bind(&request), request.gen_success()); + + let request = WhoamiRequest { msgid: 3 }; + assert_eq!( + ldap_handler.do_whoami(&request), + request.gen_success("dn: test") + ); + } +} diff --git a/src/infra/ldap_server.rs b/src/infra/ldap_server.rs index eae7770..41ad24b 100644 --- a/src/infra/ldap_server.rs +++ b/src/infra/ldap_server.rs @@ -1,89 +1,18 @@ use crate::domain::handler::BackendHandler; use crate::infra::configuration::Configuration; +use crate::infra::ldap_handler::LdapHandler; use actix_rt::net::TcpStream; use actix_server::ServerBuilder; use actix_service::{fn_service, pipeline_factory}; use anyhow::bail; use anyhow::Result; use futures_util::future::ok; +use ldap3_server::simple::*; +use ldap3_server::LdapCodec; use log::*; use tokio::net::tcp::WriteHalf; use tokio_util::codec::{FramedRead, FramedWrite}; -use ldap3_server::simple::*; -use ldap3_server::LdapCodec; - -pub struct LdapHandler { - dn: String, - backend_handler: Backend, -} - -impl LdapHandler { - pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { - match self - .backend_handler - .bind(crate::domain::handler::BindRequest { - name: sbr.dn.clone(), - password: sbr.pw.clone(), - }) { - Ok(()) => { - self.dn = sbr.dn.clone(); - sbr.gen_success() - } - Err(_) => sbr.gen_invalid_cred(), - } - } - - pub fn do_search(&mut self, lsr: &SearchRequest) -> Vec { - vec![ - lsr.gen_result_entry(LdapSearchResultEntry { - dn: "cn=hello,dc=example,dc=com".to_string(), - attributes: vec![ - LdapPartialAttribute { - atype: "objectClass".to_string(), - vals: vec!["cursed".to_string()], - }, - LdapPartialAttribute { - atype: "cn".to_string(), - vals: vec!["hello".to_string()], - }, - ], - }), - lsr.gen_result_entry(LdapSearchResultEntry { - dn: "cn=world,dc=example,dc=com".to_string(), - attributes: vec![ - LdapPartialAttribute { - atype: "objectClass".to_string(), - vals: vec!["cursed".to_string()], - }, - LdapPartialAttribute { - atype: "cn".to_string(), - vals: vec!["world".to_string()], - }, - ], - }), - lsr.gen_success(), - ] - } - - pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg { - wr.gen_success(format!("dn: {}", self.dn).as_str()) - } - - pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> { - let result = match server_op { - ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)], - ServerOps::Search(sr) => self.do_search(&sr), - ServerOps::Unbind(_) => { - // No need to notify on unbind (per rfc4511) - return None; - } - ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)], - }; - Some(result) - } -} - async fn handle_incoming_message( msg: Result, resp: &mut FramedWrite, LdapCodec>, @@ -146,10 +75,7 @@ where let mut requests = FramedRead::new(r, LdapCodec); let mut resp = FramedWrite::new(w, LdapCodec); - let mut session = LdapHandler { - dn: "Unauthenticated".to_string(), - backend_handler, - }; + let mut session = LdapHandler::new(backend_handler); while let Some(msg) = requests.next().await { if !handle_incoming_message(msg, &mut resp, &mut session).await? { diff --git a/src/infra/mod.rs b/src/infra/mod.rs index 69ca5ca..177b2e5 100644 --- a/src/infra/mod.rs +++ b/src/infra/mod.rs @@ -1,5 +1,6 @@ pub mod cli; pub mod configuration; +pub mod ldap_handler; pub mod ldap_server; pub mod logging; pub mod tcp_server;