ldap: Add support for memberOf and wildcards

This commit is contained in:
Valentin Tolmer 2021-09-24 22:27:07 +02:00 committed by nitnelave
parent c0d866b77b
commit 09a23a1e59
3 changed files with 128 additions and 30 deletions

View File

@ -48,6 +48,8 @@ pub enum RequestFilter {
Or(Vec<RequestFilter>), Or(Vec<RequestFilter>),
Not(Box<RequestFilter>), Not(Box<RequestFilter>),
Equality(String, String), Equality(String, String),
// Check if a user belongs to a group.
MemberOf(String),
} }
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]

View File

@ -18,24 +18,53 @@ impl SqlBackendHandler {
} }
} }
fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { struct RequiresGroup(bool);
// Returns the condition for the SQL query, and whether it requires joining with the groups table.
fn get_filter_expr(filter: RequestFilter) -> (RequiresGroup, SimpleExpr) {
use RequestFilter::*; use RequestFilter::*;
fn get_repeated_filter( fn get_repeated_filter(
fs: Vec<RequestFilter>, fs: Vec<RequestFilter>,
field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr,
) -> SimpleExpr { ) -> (RequiresGroup, SimpleExpr) {
let mut requires_group = false;
let mut it = fs.into_iter(); let mut it = fs.into_iter();
let first_expr = match it.next() { let first_expr = match it.next() {
None => return Expr::value(true), None => return (RequiresGroup(false), Expr::value(true)),
Some(f) => get_filter_expr(f), Some(f) => {
let (group, filter) = get_filter_expr(f);
requires_group |= group.0;
filter
}
}; };
it.fold(first_expr, |e, f| field(e, get_filter_expr(f))) let filter = it.fold(first_expr, |e, f| {
let (group, filters) = get_filter_expr(f);
requires_group |= group.0;
field(e, filters)
});
(RequiresGroup(requires_group), filter)
} }
match filter { match filter {
And(fs) => get_repeated_filter(fs, &SimpleExpr::and), And(fs) => get_repeated_filter(fs, &SimpleExpr::and),
Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), Or(fs) => get_repeated_filter(fs, &SimpleExpr::or),
Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))), Not(f) => {
Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2), let (requires_group, filters) = get_filter_expr(*f);
(requires_group, Expr::not(Expr::expr(filters)))
}
Equality(s1, s2) => (
RequiresGroup(false),
if s1 == Users::DisplayName.to_string() {
Expr::col((Users::Table, Users::DisplayName)).eq(s2)
} else if s1 == Users::UserId.to_string() {
Expr::col((Users::Table, Users::UserId)).eq(s2)
} else {
Expr::expr(Expr::cust(&s1)).eq(s2)
},
),
MemberOf(group) => (
RequiresGroup(true),
Expr::col((Groups::Table, Groups::DisplayName)).eq(group),
),
} }
} }
@ -44,21 +73,38 @@ impl BackendHandler for SqlBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>> { async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>> {
let query = { let query = {
let mut query_builder = Query::select() let mut query_builder = Query::select()
.column(Users::UserId) .column((Users::Table, Users::UserId))
.column(Users::Email) .column(Users::Email)
.column(Users::DisplayName) .column((Users::Table, Users::DisplayName))
.column(Users::FirstName) .column(Users::FirstName)
.column(Users::LastName) .column(Users::LastName)
.column(Users::Avatar) .column(Users::Avatar)
.column(Users::CreationDate) .column(Users::CreationDate)
.from(Users::Table) .from(Users::Table)
.order_by(Users::UserId, Order::Asc) .order_by((Users::Table, Users::UserId), Order::Asc)
.to_owned(); .to_owned();
if let Some(filter) = filters { if let Some(filter) = filters {
if filter == RequestFilter::Not(Box::new(RequestFilter::And(Vec::new()))) {
return Ok(Vec::new());
}
if filter != RequestFilter::And(Vec::new()) if filter != RequestFilter::And(Vec::new())
&& filter != RequestFilter::Or(Vec::new()) && filter != RequestFilter::Or(Vec::new())
{ {
query_builder.and_where(get_filter_expr(filter)); let (RequiresGroup(requires_group), condition) = get_filter_expr(filter);
query_builder.and_where(condition);
if requires_group {
query_builder
.left_join(
Memberships::Table,
Expr::tbl(Users::Table, Users::UserId)
.equals(Memberships::Table, Memberships::UserId),
)
.left_join(
Groups::Table,
Expr::tbl(Memberships::Table, Memberships::GroupId)
.equals(Groups::Table, Groups::GroupId),
);
}
} }
} }

View File

@ -29,6 +29,31 @@ fn parse_distinguished_name(dn: &str) -> Result<Vec<(String, String)>> {
.collect() .collect()
} }
fn get_group_id_from_distinguished_name(
dn: &str,
base_tree: &[(String, String)],
base_dn_str: &str,
) -> Result<String> {
let parts = parse_distinguished_name(dn)?;
if !is_subtree(&parts, base_tree) {
bail!("Not a subtree of the base tree");
}
if parts.len() == base_tree.len() + 2 {
if parts[1].0 != "ou" || parts[1].1 != "groups" || parts[0].0 != "cn" {
bail!(
r#"Unexpected user DN format. Expected: "cn=groupname,ou=groups,{}""#,
base_dn_str
);
}
Ok(parts[0].1.to_string())
} else {
bail!(
r#"Unexpected user DN format. Expected: "cn=groupname,ou=groups,{}""#,
base_dn_str
);
}
}
fn get_user_id_from_distinguished_name( fn get_user_id_from_distinguished_name(
dn: &str, dn: &str,
base_tree: &[(String, String)], base_tree: &[(String, String)],
@ -128,22 +153,6 @@ fn map_field(field: &str) -> Result<String> {
}) })
} }
fn convert_filter(filter: &LdapFilter) -> Result<RequestFilter> {
match filter {
LdapFilter::And(filters) => Ok(RequestFilter::And(
filters.iter().map(convert_filter).collect::<Result<_>>()?,
)),
LdapFilter::Or(filters) => Ok(RequestFilter::Or(
filters.iter().map(convert_filter).collect::<Result<_>>()?,
)),
LdapFilter::Not(filter) => Ok(RequestFilter::Not(Box::new(convert_filter(&*filter)?))),
LdapFilter::Equality(field, value) => {
Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
}
_ => bail!("Unsupported filter"),
}
}
pub struct LdapHandler<Backend: BackendHandler + LoginHandler> { pub struct LdapHandler<Backend: BackendHandler + LoginHandler> {
dn: String, dn: String,
backend_handler: Backend, backend_handler: Backend,
@ -214,12 +223,12 @@ impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
// Search path is not in our tree, just return an empty success. // Search path is not in our tree, just return an empty success.
return vec![lsr.gen_success()]; return vec![lsr.gen_success()];
} }
let filters = match convert_filter(&lsr.filter) { let filters = match self.convert_filter(&lsr.filter) {
Ok(f) => Some(f), Ok(f) => Some(f),
Err(_) => { Err(e) => {
return vec![lsr.gen_error( return vec![lsr.gen_error(
LdapResultCode::UnwillingToPerform, LdapResultCode::UnwillingToPerform,
"Unsupported filter".to_string(), format!("Unsupported filter: {}", e),
)] )]
} }
}; };
@ -263,6 +272,47 @@ impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
}; };
Some(result) Some(result)
} }
fn convert_filter(&self, filter: &LdapFilter) -> Result<RequestFilter> {
match filter {
LdapFilter::And(filters) => Ok(RequestFilter::And(
filters
.iter()
.map(|f| self.convert_filter(f))
.collect::<Result<_>>()?,
)),
LdapFilter::Or(filters) => Ok(RequestFilter::Or(
filters
.iter()
.map(|f| self.convert_filter(f))
.collect::<Result<_>>()?,
)),
LdapFilter::Not(filter) => {
Ok(RequestFilter::Not(Box::new(self.convert_filter(&*filter)?)))
}
LdapFilter::Equality(field, value) => {
if field == "memberOf" {
let group_name = get_group_id_from_distinguished_name(
value,
&self.base_dn,
&self.base_dn_str,
)?;
Ok(RequestFilter::MemberOf(group_name))
} else {
Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
}
}
LdapFilter::Present(field) => {
// Check that it's a field we support.
if field == "objectclass" || map_field(field).is_ok() {
Ok(RequestFilter::And(Vec::new()))
} else {
Ok(RequestFilter::Not(Box::new(RequestFilter::And(Vec::new()))))
}
}
_ => bail!("Unsupported filter: {:?}", filter),
}
}
} }
#[cfg(test)] #[cfg(test)]