Add support for non-admin bind

This commit is contained in:
Valentin Tolmer 2021-04-07 20:14:21 +02:00
parent 31e8998ac3
commit 6abe94af13
4 changed files with 116 additions and 32 deletions

View File

@ -27,6 +27,7 @@ tracing = "*"
tracing-actix-web = "0.3.0-beta.2"
tracing-log = "*"
tracing-subscriber = "*"
async-trait = "0.1.48"
[dependencies.figment]
features = ["toml", "env"]

View File

@ -1,6 +1,8 @@
use crate::infra::configuration::Configuration;
use anyhow::{bail, Result};
use async_trait::async_trait;
use sqlx::any::AnyPool;
use sqlx::Row;
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
pub struct BindRequest {
@ -24,9 +26,10 @@ pub struct User {
pub creation_date: chrono::NaiveDateTime,
}
#[async_trait]
pub trait BackendHandler: Clone + Send {
fn bind(&mut self, request: BindRequest) -> Result<()>;
fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn bind(&mut self, request: BindRequest) -> Result<()>;
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
}
#[derive(Debug, Clone)]
@ -46,19 +49,34 @@ impl SqlBackendHandler {
}
}
fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool {
encrypted_password == clear_password
}
#[async_trait]
impl BackendHandler for SqlBackendHandler {
fn bind(&mut self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn
&& request.password == self.config.ldap_user_pass
{
self.authenticated = true;
Ok(())
} else {
bail!(r#"Authentication error for "{}""#, request.name)
async fn bind(&mut self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass {
self.authenticated = true;
return Ok(());
} else {
bail!(r#"Authentication error for "{}""#, request.name)
}
}
if let Ok(row) = sqlx::query("SELECT password FROM users WHERE user_id = ?")
.bind(&request.name)
.fetch_one(&self.sql_pool)
.await
{
if passwords_match(&request.password, &row.get::<String, _>("password")) {
return Ok(());
}
}
bail!(r#"Authentication error for "{}""#, request.name)
}
fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>> {
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>> {
Ok(Vec::new())
}
}
@ -69,8 +87,9 @@ mockall::mock! {
impl Clone for TestBackendHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl BackendHandler for TestBackendHandler {
fn bind(&mut self, request: BindRequest) -> Result<()>;
fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn bind(&mut self, request: BindRequest) -> Result<()>;
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
}
}

View File

@ -82,10 +82,11 @@ pub struct LdapHandler<Backend: BackendHandler> {
backend_handler: Backend,
pub base_dn: Vec<(String, String)>,
base_dn_str: String,
ldap_user_dn: String,
}
impl<Backend: BackendHandler> LdapHandler<Backend> {
pub fn new(backend_handler: Backend, ldap_base_dn: String) -> Self {
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
Self {
dn: "Unauthenticated".to_string(),
backend_handler,
@ -96,16 +97,19 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
)
}),
base_dn_str: ldap_base_dn,
ldap_user_dn,
}
}
pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
match self
.backend_handler
.bind(crate::domain::handler::BindRequest {
name: sbr.dn.clone(),
password: sbr.pw.clone(),
}) {
})
.await
{
Ok(()) => {
self.dn = sbr.dn.clone();
sbr.gen_success()
@ -114,7 +118,13 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
}
}
pub fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> {
pub async fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> {
if self.dn != self.ldap_user_dn {
return vec![lsr.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string(),
)];
}
let dn_parts = match parse_distinguished_name(&lsr.base) {
Ok(dn) => dn,
Err(_) => {
@ -128,7 +138,7 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
// Search path is not in our tree, just return an empty success.
return vec![lsr.gen_success()];
}
let users = match self.backend_handler.list_users(ListUsersRequest {}) {
let users = match self.backend_handler.list_users(ListUsersRequest {}).await {
Ok(users) => users,
Err(e) => {
return vec![lsr.gen_error(
@ -156,10 +166,10 @@ impl<Backend: BackendHandler> LdapHandler<Backend> {
}
}
pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> {
pub async fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> {
let result = match server_op {
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)],
ServerOps::Search(sr) => self.do_search(&sr),
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr).await],
ServerOps::Search(sr) => self.do_search(&sr).await,
ServerOps::Unbind(_) => {
// No need to notify on unbind (per rfc4511)
return None;
@ -176,9 +186,10 @@ mod tests {
use crate::domain::handler::MockTestBackendHandler;
use chrono::NaiveDateTime;
use mockall::predicate::eq;
use tokio;
#[test]
fn test_bind() {
#[tokio::test]
async fn test_bind() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
@ -187,7 +198,8 @@ mod tests {
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string());
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
@ -200,7 +212,7 @@ mod tests {
dn: "test".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request), request.gen_success());
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
@ -209,6 +221,54 @@ mod tests {
);
}
#[tokio::test]
async fn test_bind_invalid_credentials() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "test".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: test")
);
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![]),
attrs: vec![],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string()
)]
);
}
#[test]
fn test_is_subtree() {
let subtree1 = &[
@ -237,8 +297,8 @@ mod tests {
);
}
#[test]
fn test_search() {
#[tokio::test]
async fn test_search() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind().return_once(|_| Ok(()));
mock.expect_list_users()
@ -264,13 +324,14 @@ mod tests {
},
])
});
let mut ldap_handler = LdapHandler::new(mock, "dc=example,dc=com".to_string());
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = SimpleBindRequest {
msgid: 1,
dn: "test".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request), request.gen_success());
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
@ -286,7 +347,7 @@ mod tests {
],
};
assert_eq!(
ldap_handler.do_search(&request),
ldap_handler.do_search(&request).await,
vec![
request.gen_result_entry(LdapSearchResultEntry {
dn: "cn=bob_1,dc=example,dc=com".to_string(),

View File

@ -34,7 +34,7 @@ async fn handle_incoming_message<Backend: BackendHandler>(
}
};
match session.handle_ldap_message(server_op) {
match session.handle_ldap_message(server_op).await {
None => return Ok(false),
Some(result) => {
for rmsg in result.into_iter() {
@ -62,20 +62,23 @@ where
use futures_util::StreamExt;
let ldap_base_dn = config.ldap_base_dn.clone();
let ldap_user_dn = config.ldap_user_dn.clone();
Ok(
server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || {
let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
pipeline_factory(fn_service(move |mut stream: TcpStream| {
let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
async move {
// Configure the codec etc.
let (r, w) = stream.split();
let mut requests = FramedRead::new(r, LdapCodec);
let mut resp = FramedWrite::new(w, LdapCodec);
let mut session = LdapHandler::new(backend_handler, ldap_base_dn);
let mut session = LdapHandler::new(backend_handler, ldap_base_dn, ldap_user_dn);
while let Some(msg) = requests.next().await {
if !handle_incoming_message(msg, &mut resp, &mut session).await? {