Add support for non-admin bind

This commit is contained in:
Valentin Tolmer 2021-04-07 20:14:21 +02:00
parent 31e8998ac3
commit 6abe94af13
4 changed files with 116 additions and 32 deletions

View File

@ -27,6 +27,7 @@ tracing = "*"
tracing-actix-web = "0.3.0-beta.2" tracing-actix-web = "0.3.0-beta.2"
tracing-log = "*" tracing-log = "*"
tracing-subscriber = "*" tracing-subscriber = "*"
async-trait = "0.1.48"
[dependencies.figment] [dependencies.figment]
features = ["toml", "env"] features = ["toml", "env"]

View File

@ -1,6 +1,8 @@
use crate::infra::configuration::Configuration; use crate::infra::configuration::Configuration;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use async_trait::async_trait;
use sqlx::any::AnyPool; use sqlx::any::AnyPool;
use sqlx::Row;
#[cfg_attr(test, derive(PartialEq, Eq, Debug))] #[cfg_attr(test, derive(PartialEq, Eq, Debug))]
pub struct BindRequest { pub struct BindRequest {
@ -24,9 +26,10 @@ pub struct User {
pub creation_date: chrono::NaiveDateTime, pub creation_date: chrono::NaiveDateTime,
} }
#[async_trait]
pub trait BackendHandler: Clone + Send { pub trait BackendHandler: Clone + Send {
fn bind(&mut self, request: BindRequest) -> Result<()>; async fn bind(&mut self, request: BindRequest) -> Result<()>;
fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -46,19 +49,34 @@ impl SqlBackendHandler {
} }
} }
fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool {
encrypted_password == clear_password
}
#[async_trait]
impl BackendHandler for SqlBackendHandler { impl BackendHandler for SqlBackendHandler {
fn bind(&mut self, request: BindRequest) -> Result<()> { async fn bind(&mut self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn if request.name == self.config.ldap_user_dn {
&& request.password == self.config.ldap_user_pass if request.password == self.config.ldap_user_pass {
{
self.authenticated = true; self.authenticated = true;
Ok(()) return Ok(());
} else { } else {
bail!(r#"Authentication error for "{}""#, request.name) bail!(r#"Authentication error for "{}""#, request.name)
} }
} }
if let Ok(row) = sqlx::query("SELECT password FROM users WHERE user_id = ?")
.bind(&request.name)
.fetch_one(&self.sql_pool)
.await
{
if passwords_match(&request.password, &row.get::<String, _>("password")) {
return Ok(());
}
}
bail!(r#"Authentication error for "{}""#, request.name)
}
fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>> { async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>> {
Ok(Vec::new()) Ok(Vec::new())
} }
} }
@ -69,8 +87,9 @@ mockall::mock! {
impl Clone for TestBackendHandler { impl Clone for TestBackendHandler {
fn clone(&self) -> Self; fn clone(&self) -> Self;
} }
#[async_trait]
impl BackendHandler for TestBackendHandler { impl BackendHandler for TestBackendHandler {
fn bind(&mut self, request: BindRequest) -> Result<()>; async fn bind(&mut self, request: BindRequest) -> Result<()>;
fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
} }
} }

View File

@ -82,10 +82,11 @@ pub struct LdapHandler<Backend: BackendHandler> {
backend_handler: Backend, backend_handler: Backend,
pub base_dn: Vec<(String, String)>, pub base_dn: Vec<(String, String)>,
base_dn_str: String, base_dn_str: String,
ldap_user_dn: String,
} }
impl<Backend: BackendHandler> LdapHandler<Backend> { impl<Backend: BackendHandler> LdapHandler<Backend> {
pub fn new(backend_handler: Backend, ldap_base_dn: String) -> Self { pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
Self { Self {
dn: "Unauthenticated".to_string(), dn: "Unauthenticated".to_string(),
backend_handler, backend_handler,
@ -96,16 +97,19 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
) )
}), }),
base_dn_str: ldap_base_dn, base_dn_str: ldap_base_dn,
ldap_user_dn,
} }
} }
pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
match self match self
.backend_handler .backend_handler
.bind(crate::domain::handler::BindRequest { .bind(crate::domain::handler::BindRequest {
name: sbr.dn.clone(), name: sbr.dn.clone(),
password: sbr.pw.clone(), password: sbr.pw.clone(),
}) { })
.await
{
Ok(()) => { Ok(()) => {
self.dn = sbr.dn.clone(); self.dn = sbr.dn.clone();
sbr.gen_success() sbr.gen_success()
@ -114,7 +118,13 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
} }
} }
pub fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> { pub async fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> {
if self.dn != self.ldap_user_dn {
return vec![lsr.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string(),
)];
}
let dn_parts = match parse_distinguished_name(&lsr.base) { let dn_parts = match parse_distinguished_name(&lsr.base) {
Ok(dn) => dn, Ok(dn) => dn,
Err(_) => { Err(_) => {
@ -128,7 +138,7 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
// Search path is not in our tree, just return an empty success. // Search path is not in our tree, just return an empty success.
return vec![lsr.gen_success()]; return vec![lsr.gen_success()];
} }
let users = match self.backend_handler.list_users(ListUsersRequest {}) { let users = match self.backend_handler.list_users(ListUsersRequest {}).await {
Ok(users) => users, Ok(users) => users,
Err(e) => { Err(e) => {
return vec![lsr.gen_error( return vec![lsr.gen_error(
@ -156,10 +166,10 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
} }
} }
pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> { pub async fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> {
let result = match server_op { let result = match server_op {
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)], ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr).await],
ServerOps::Search(sr) => self.do_search(&sr), ServerOps::Search(sr) => self.do_search(&sr).await,
ServerOps::Unbind(_) => { ServerOps::Unbind(_) => {
// No need to notify on unbind (per rfc4511) // No need to notify on unbind (per rfc4511)
return None; return None;
@ -176,9 +186,10 @@ mod tests {
use crate::domain::handler::MockTestBackendHandler; use crate::domain::handler::MockTestBackendHandler;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use mockall::predicate::eq; use mockall::predicate::eq;
use tokio;
#[test] #[tokio::test]
fn test_bind() { async fn test_bind() {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_bind() mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest { .with(eq(crate::domain::handler::BindRequest {
@ -187,7 +198,8 @@ mod tests {
})) }))
.times(1) .times(1)
.return_once(|_| Ok(())); .return_once(|_| Ok(()));
let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string()); let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = WhoamiRequest { msgid: 1 }; let request = WhoamiRequest { msgid: 1 };
assert_eq!( assert_eq!(
@ -200,7 +212,7 @@ mod tests {
dn: "test".to_string(), dn: "test".to_string(),
pw: "pass".to_string(), pw: "pass".to_string(),
}; };
assert_eq!(ldap_handler.do_bind(&request), request.gen_success()); assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 }; let request = WhoamiRequest { msgid: 3 };
assert_eq!( assert_eq!(
@ -209,6 +221,54 @@ mod tests {
); );
} }
#[tokio::test]
async fn test_bind_invalid_credentials() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "test".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: test")
);
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![]),
attrs: vec![],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string()
)]
);
}
#[test] #[test]
fn test_is_subtree() { fn test_is_subtree() {
let subtree1 = &[ let subtree1 = &[
@ -237,8 +297,8 @@ mod tests {
); );
} }
#[test] #[tokio::test]
fn test_search() { async fn test_search() {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_bind().return_once(|_| Ok(())); mock.expect_bind().return_once(|_| Ok(()));
mock.expect_list_users() mock.expect_list_users()
@ -264,13 +324,14 @@ mod tests {
}, },
]) ])
}); });
let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string()); let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = SimpleBindRequest { let request = SimpleBindRequest {
msgid: 1, msgid: 1,
dn: "test".to_string(), dn: "test".to_string(),
pw: "pass".to_string(), pw: "pass".to_string(),
}; };
assert_eq!(ldap_handler.do_bind(&request), request.gen_success()); assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = SearchRequest { let request = SearchRequest {
msgid: 2, msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(), base: "ou=people,dc=example,dc=com".to_string(),
@ -286,7 +347,7 @@ mod tests {
], ],
}; };
assert_eq!( assert_eq!(
ldap_handler.do_search(&request), ldap_handler.do_search(&request).await,
vec![ vec![
request.gen_result_entry(LdapSearchResultEntry { request.gen_result_entry(LdapSearchResultEntry {
dn: "cn=bob_1,dc=example,dc=com".to_string(), dn: "cn=bob_1,dc=example,dc=com".to_string(),

View File

@ -34,7 +34,7 @@ async fn handle_incoming_message<Backend: BackendHandler>(
} }
}; };
match session.handle_ldap_message(server_op) { match session.handle_ldap_message(server_op).await {
None => return Ok(false), None => return Ok(false),
Some(result) => { Some(result) => {
for rmsg in result.into_iter() { for rmsg in result.into_iter() {
@ -62,20 +62,23 @@ where
use futures_util::StreamExt; use futures_util::StreamExt;
let ldap_base_dn = config.ldap_base_dn.clone(); let ldap_base_dn = config.ldap_base_dn.clone();
let ldap_user_dn = config.ldap_user_dn.clone();
Ok( Ok(
server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || { server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || {
let backend_handler = backend_handler.clone(); let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone(); let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
pipeline_factory(fn_service(move |mut stream: TcpStream| { pipeline_factory(fn_service(move |mut stream: TcpStream| {
let backend_handler = backend_handler.clone(); let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone(); let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
async move { async move {
// Configure the codec etc. // Configure the codec etc.
let (r, w) = stream.split(); let (r, w) = stream.split();
let mut requests = FramedRead::new(r, LdapCodec); let mut requests = FramedRead::new(r, LdapCodec);
let mut resp = FramedWrite::new(w, LdapCodec); let mut resp = FramedWrite::new(w, LdapCodec);
let mut session = LdapHandler::new(backend_handler, ldap_base_dn); let mut session = LdapHandler::new(backend_handler, ldap_base_dn, ldap_user_dn);
while let Some(msg) = requests.next().await { while let Some(msg) = requests.next().await {
if !handle_incoming_message(msg, &mut resp, &mut session).await? { if !handle_incoming_message(msg, &mut resp, &mut session).await? {