server: Improve equality handling in filters

Now the columns are checked and mapped to user columns, to avoid any
ambiguity.

Fixes #341.
This commit is contained in:
Valentin Tolmer 2022-10-19 08:34:00 +02:00 committed by nitnelave
parent 8d19678e39
commit 4c69f917e7
8 changed files with 86 additions and 51 deletions

View File

@ -1,4 +1,4 @@
use super::error::*; use super::{error::*, sql_tables::UserColumn};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet; use std::collections::HashSet;
@ -201,7 +201,7 @@ pub enum UserRequestFilter {
Or(Vec<UserRequestFilter>), Or(Vec<UserRequestFilter>),
Not(Box<UserRequestFilter>), Not(Box<UserRequestFilter>),
UserId(UserId), UserId(UserId),
Equality(String, String), Equality(UserColumn, String),
// Check if a user belongs to a group identified by name. // Check if a user belongs to a group identified by name.
MemberOf(String), MemberOf(String),
// Same, by id. // Same, by id.

View File

@ -6,11 +6,14 @@ use tracing::{debug, info, instrument, warn};
use crate::domain::{ use crate::domain::{
handler::{BackendHandler, Group, GroupRequestFilter, UserId, Uuid}, handler::{BackendHandler, Group, GroupRequestFilter, UserId, Uuid},
ldap::error::LdapError, ldap::error::LdapError,
sql_tables::GroupColumn,
}; };
use super::{ use super::{
error::LdapResult, error::LdapResult,
utils::{expand_attribute_wildcards, get_user_id_from_distinguished_name, map_field, LdapInfo}, utils::{
expand_attribute_wildcards, get_user_id_from_distinguished_name, map_group_field, LdapInfo,
},
}; };
fn get_group_attribute( fn get_group_attribute(
@ -123,11 +126,11 @@ fn convert_group_filter(
vec![], vec![],
)))), )))),
}, },
_ => match map_field(field) { _ => match map_group_field(field) {
Some("display_name") | Some("user_id") => { Some(GroupColumn::DisplayName) => {
Ok(GroupRequestFilter::DisplayName(value.to_string())) Ok(GroupRequestFilter::DisplayName(value.to_string()))
} }
Some("uuid") => Ok(GroupRequestFilter::Uuid( Some(GroupColumn::Uuid) => Ok(GroupRequestFilter::Uuid(
Uuid::try_from(value.as_str()).map_err(|e| LdapError { Uuid::try_from(value.as_str()).map_err(|e| LdapError {
code: LdapResultCode::InappropriateMatching, code: LdapResultCode::InappropriateMatching,
message: format!("Invalid UUID: {:#}", e), message: format!("Invalid UUID: {:#}", e),

View File

@ -6,11 +6,12 @@ use tracing::{debug, info, instrument, warn};
use crate::domain::{ use crate::domain::{
handler::{BackendHandler, GroupDetails, User, UserId, UserRequestFilter}, handler::{BackendHandler, GroupDetails, User, UserId, UserRequestFilter},
ldap::{error::LdapError, utils::expand_attribute_wildcards}, ldap::{error::LdapError, utils::expand_attribute_wildcards},
sql_tables::UserColumn,
}; };
use super::{ use super::{
error::LdapResult, error::LdapResult,
utils::{get_group_id_from_distinguished_name, map_field, LdapInfo}, utils::{get_group_id_from_distinguished_name, map_user_field, LdapInfo},
}; };
fn get_user_attribute( fn get_user_attribute(
@ -142,17 +143,9 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult<
vec![], vec![],
)))), )))),
}, },
_ => match map_field(field) { _ => match map_user_field(field) {
Some(field) => { Some(UserColumn::UserId) => Ok(UserRequestFilter::UserId(UserId::new(value))),
if field == "user_id" { Some(field) => Ok(UserRequestFilter::Equality(field, value.clone())),
Ok(UserRequestFilter::UserId(UserId::new(value)))
} else {
Ok(UserRequestFilter::Equality(
field.to_string(),
value.clone(),
))
}
}
None => { None => {
if !ldap_info.ignored_user_attributes.contains(field) { if !ldap_info.ignored_user_attributes.contains(field) {
warn!( warn!(

View File

@ -2,7 +2,10 @@ use itertools::Itertools;
use ldap3_proto::LdapResultCode; use ldap3_proto::LdapResultCode;
use tracing::{debug, instrument, warn}; use tracing::{debug, instrument, warn};
use crate::domain::handler::UserId; use crate::domain::{
handler::UserId,
sql_tables::{GroupColumn, UserColumn},
};
use super::error::{LdapError, LdapResult}; use super::error::{LdapError, LdapResult};
@ -134,17 +137,31 @@ pub fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)])
true true
} }
pub fn map_field(field: &str) -> Option<&'static str> { pub fn map_user_field(field: &str) -> Option<UserColumn> {
assert!(field == field.to_ascii_lowercase()); assert!(field == field.to_ascii_lowercase());
Some(match field { Some(match field {
"uid" => "user_id", "uid" | "user_id" | "id" => UserColumn::UserId,
"mail" => "email", "mail" | "email" => UserColumn::Email,
"cn" | "displayname" => "display_name", "cn" | "displayname" | "display_name" => UserColumn::DisplayName,
"givenname" => "first_name", "givenname" | "first_name" => UserColumn::FirstName,
"sn" => "last_name", "sn" | "last_name" => UserColumn::LastName,
"avatar" => "avatar", "avatar" => UserColumn::Avatar,
"creationdate" | "createtimestamp" | "modifytimestamp" => "creation_date", "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => {
"entryuuid" => "uuid", UserColumn::CreationDate
}
"entryuuid" | "uuid" => UserColumn::Uuid,
_ => return None,
})
}
pub fn map_group_field(field: &str) -> Option<GroupColumn> {
assert!(field == field.to_ascii_lowercase());
Some(match field {
"cn" | "displayname" | "uid" | "display_name" => GroupColumn::DisplayName,
"creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => {
GroupColumn::CreationDate
}
"entryuuid" | "uuid" => GroupColumn::Uuid,
_ => return None, _ => return None,
}) })
} }

View File

@ -3,6 +3,7 @@ use super::{
sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1}, sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1},
}; };
use sea_query::*; use sea_query::*;
use serde::{Deserialize, Serialize};
pub use super::sql_migrations::create_group; pub use super::sql_migrations::create_group;
@ -51,7 +52,7 @@ impl From<SchemaVersion> for Value {
} }
} }
#[derive(Iden)] #[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Users { pub enum Users {
Table, Table,
UserId, UserId,
@ -67,7 +68,9 @@ pub enum Users {
Uuid, Uuid,
} }
#[derive(Iden)] pub type UserColumn = Users;
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Groups { pub enum Groups {
Table, Table,
GroupId, GroupId,
@ -76,6 +79,8 @@ pub enum Groups {
Uuid, Uuid,
} }
pub type GroupColumn = Groups;
#[derive(Iden)] #[derive(Iden)]
pub enum Memberships { pub enum Memberships {
Table, Table,

View File

@ -54,14 +54,10 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) {
), ),
Equality(s1, s2) => ( Equality(s1, s2) => (
RequiresGroup(false), RequiresGroup(false),
if s1 == Users::DisplayName.to_string() { if s1 == Users::UserId {
Expr::col((Users::Table, Users::DisplayName))
.eq(s2)
.into_condition()
} else if s1 == Users::UserId.to_string() {
panic!("User id should be wrapped") panic!("User id should be wrapped")
} else { } else {
Expr::expr(Expr::cust(&s1)).eq(s2).into_condition() Expr::col((Users::Table, s1)).eq(s2).into_condition()
}, },
), ),
MemberOf(group) => ( MemberOf(group) => (
@ -360,7 +356,9 @@ impl UserBackendHandler for SqlBackendHandler {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{handler::JpegPhoto, sql_backend_handler::tests::*}; use crate::domain::{
handler::JpegPhoto, sql_backend_handler::tests::*, sql_tables::UserColumn,
};
#[tokio::test] #[tokio::test]
async fn test_list_users_no_filter() { async fn test_list_users_no_filter() {
@ -386,7 +384,7 @@ mod tests {
let users = get_user_names( let users = get_user_names(
&fixture.handler, &fixture.handler,
Some(UserRequestFilter::Equality( Some(UserRequestFilter::Equality(
"display_name".to_string(), UserColumn::DisplayName,
"display bob".to_string(), "display bob".to_string(),
)), )),
) )
@ -400,7 +398,7 @@ mod tests {
let users = get_user_names( let users = get_user_names(
&fixture.handler, &fixture.handler,
Some(UserRequestFilter::Equality( Some(UserRequestFilter::Equality(
"first_name".to_string(), UserColumn::FirstName,
"first bob".to_string(), "first bob".to_string(),
)), )),
) )
@ -432,6 +430,20 @@ mod tests {
assert_eq!(users, vec!["bob", "patrick"]); assert_eq!(users, vec!["bob", "patrick"]);
} }
#[tokio::test]
async fn test_list_users_member_of_and_uuid() {
let fixture = TestFixture::new().await;
let users = get_user_names(
&fixture.handler,
Some(UserRequestFilter::Or(vec![
UserRequestFilter::MemberOf("Best Group".to_string()),
UserRequestFilter::Equality(UserColumn::Uuid, "abc".to_string()),
])),
)
.await;
assert_eq!(users, vec!["bob", "patrick"]);
}
#[tokio::test] #[tokio::test]
async fn test_list_users_member_of_id() { async fn test_list_users_member_of_id() {
let fixture = TestFixture::new().await; let fixture = TestFixture::new().await;
@ -450,7 +462,7 @@ mod tests {
get_user_names( get_user_names(
&fixture.handler, &fixture.handler,
Some(UserRequestFilter::Equality( Some(UserRequestFilter::Equality(
"user_id".to_string(), UserColumn::UserId,
"first bob".to_string(), "first bob".to_string(),
)), )),
) )

View File

@ -1,4 +1,8 @@
use crate::domain::handler::{BackendHandler, GroupDetails, GroupId, UserId}; use crate::domain::{
handler::{BackendHandler, GroupDetails, GroupId, UserId},
ldap::utils::map_user_field,
sql_tables::UserColumn,
};
use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::{debug, debug_span, Instrument}; use tracing::{debug, debug_span, Instrument};
@ -50,10 +54,14 @@ impl TryInto<DomainRequestFilter> for RequestFilter {
return Err("Multiple fields specified in request filter".to_string()); return Err("Multiple fields specified in request filter".to_string());
} }
if let Some(e) = self.eq { if let Some(e) = self.eq {
if e.field.to_lowercase() == "uid" { if let Some(column) = map_user_field(&e.field) {
return Ok(DomainRequestFilter::UserId(UserId::new(&e.value))); if column == UserColumn::UserId {
return Ok(DomainRequestFilter::UserId(UserId::new(&e.value)));
}
return Ok(DomainRequestFilter::Equality(column, e.value));
} else {
return Err(format!("Unknown request filter: {}", &e.field));
} }
return Ok(DomainRequestFilter::Equality(e.field, e.value));
} }
if let Some(c) = self.any { if let Some(c) = self.any {
return Ok(DomainRequestFilter::Or( return Ok(DomainRequestFilter::Or(
@ -451,11 +459,8 @@ mod tests {
mock.expect_list_users() mock.expect_list_users()
.with( .with(
eq(Some(UserRequestFilter::Or(vec![ eq(Some(UserRequestFilter::Or(vec![
UserRequestFilter::Equality("id".to_string(), "bob".to_string()), UserRequestFilter::UserId(UserId::new("bob")),
UserRequestFilter::Equality( UserRequestFilter::Equality(UserColumn::Email, "robert@bobbers.on".to_string()),
"email".to_string(),
"robert@bobbers.on".to_string(),
),
]))), ]))),
eq(false), eq(false),
) )

View File

@ -447,7 +447,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::{
domain::{error::Result, handler::*, opaque_handler::*}, domain::{error::Result, handler::*, opaque_handler::*, sql_tables::UserColumn},
uuid, uuid,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -1370,7 +1370,7 @@ mod tests {
.with( .with(
eq(Some(UserRequestFilter::And(vec![UserRequestFilter::Or( eq(Some(UserRequestFilter::And(vec![UserRequestFilter::Or(
vec![UserRequestFilter::Not(Box::new( vec![UserRequestFilter::Not(Box::new(
UserRequestFilter::Equality("first_name".to_string(), "bob".to_string()), UserRequestFilter::Equality(UserColumn::FirstName, "bob".to_string()),
))], ))],
)]))), )]))),
eq(false), eq(false),