diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index d7a5b68..b394494 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -1,4 +1,4 @@ -use super::error::*; +use super::{error::*, sql_tables::UserColumn}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::collections::HashSet; @@ -201,7 +201,7 @@ pub enum UserRequestFilter { Or(Vec), Not(Box), UserId(UserId), - Equality(String, String), + Equality(UserColumn, String), // Check if a user belongs to a group identified by name. MemberOf(String), // Same, by id. diff --git a/server/src/domain/ldap/group.rs b/server/src/domain/ldap/group.rs index 23198e6..e41bd7f 100644 --- a/server/src/domain/ldap/group.rs +++ b/server/src/domain/ldap/group.rs @@ -6,11 +6,14 @@ use tracing::{debug, info, instrument, warn}; use crate::domain::{ handler::{BackendHandler, Group, GroupRequestFilter, UserId, Uuid}, ldap::error::LdapError, + sql_tables::GroupColumn, }; use super::{ error::LdapResult, - utils::{expand_attribute_wildcards, get_user_id_from_distinguished_name, map_field, LdapInfo}, + utils::{ + expand_attribute_wildcards, get_user_id_from_distinguished_name, map_group_field, LdapInfo, + }, }; fn get_group_attribute( @@ -123,11 +126,11 @@ fn convert_group_filter( vec![], )))), }, - _ => match map_field(field) { - Some("display_name") | Some("user_id") => { + _ => match map_group_field(field) { + Some(GroupColumn::DisplayName) => { Ok(GroupRequestFilter::DisplayName(value.to_string())) } - Some("uuid") => Ok(GroupRequestFilter::Uuid( + Some(GroupColumn::Uuid) => Ok(GroupRequestFilter::Uuid( Uuid::try_from(value.as_str()).map_err(|e| LdapError { code: LdapResultCode::InappropriateMatching, message: format!("Invalid UUID: {:#}", e), diff --git a/server/src/domain/ldap/user.rs b/server/src/domain/ldap/user.rs index 49e1c56..5419600 100644 --- a/server/src/domain/ldap/user.rs +++ b/server/src/domain/ldap/user.rs @@ -6,11 +6,12 @@ use tracing::{debug, info, instrument, warn}; use crate::domain::{ handler::{BackendHandler, GroupDetails, User, UserId, UserRequestFilter}, ldap::{error::LdapError, utils::expand_attribute_wildcards}, + sql_tables::UserColumn, }; use super::{ error::LdapResult, - utils::{get_group_id_from_distinguished_name, map_field, LdapInfo}, + utils::{get_group_id_from_distinguished_name, map_user_field, LdapInfo}, }; fn get_user_attribute( @@ -142,17 +143,9 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult< vec![], )))), }, - _ => match map_field(field) { - Some(field) => { - if field == "user_id" { - Ok(UserRequestFilter::UserId(UserId::new(value))) - } else { - Ok(UserRequestFilter::Equality( - field.to_string(), - value.clone(), - )) - } - } + _ => match map_user_field(field) { + Some(UserColumn::UserId) => Ok(UserRequestFilter::UserId(UserId::new(value))), + Some(field) => Ok(UserRequestFilter::Equality(field, value.clone())), None => { if !ldap_info.ignored_user_attributes.contains(field) { warn!( diff --git a/server/src/domain/ldap/utils.rs b/server/src/domain/ldap/utils.rs index 62db574..05fefb4 100644 --- a/server/src/domain/ldap/utils.rs +++ b/server/src/domain/ldap/utils.rs @@ -2,7 +2,10 @@ use itertools::Itertools; use ldap3_proto::LdapResultCode; use tracing::{debug, instrument, warn}; -use crate::domain::handler::UserId; +use crate::domain::{ + handler::UserId, + sql_tables::{GroupColumn, UserColumn}, +}; use super::error::{LdapError, LdapResult}; @@ -134,17 +137,31 @@ pub fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)]) true } -pub fn map_field(field: &str) -> Option<&'static str> { +pub fn map_user_field(field: &str) -> Option { assert!(field == field.to_ascii_lowercase()); Some(match field { - "uid" => "user_id", - "mail" => "email", - "cn" | "displayname" => "display_name", - "givenname" => "first_name", - "sn" => "last_name", - "avatar" => "avatar", - "creationdate" | "createtimestamp" | "modifytimestamp" => "creation_date", - "entryuuid" => "uuid", + "uid" | "user_id" | "id" => UserColumn::UserId, + "mail" | "email" => UserColumn::Email, + "cn" | "displayname" | "display_name" => UserColumn::DisplayName, + "givenname" | "first_name" => UserColumn::FirstName, + "sn" | "last_name" => UserColumn::LastName, + "avatar" => UserColumn::Avatar, + "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => { + UserColumn::CreationDate + } + "entryuuid" | "uuid" => UserColumn::Uuid, + _ => return None, + }) +} + +pub fn map_group_field(field: &str) -> Option { + assert!(field == field.to_ascii_lowercase()); + Some(match field { + "cn" | "displayname" | "uid" | "display_name" => GroupColumn::DisplayName, + "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => { + GroupColumn::CreationDate + } + "entryuuid" | "uuid" => GroupColumn::Uuid, _ => return None, }) } diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs index 3ff1d0c..b409a1b 100644 --- a/server/src/domain/sql_tables.rs +++ b/server/src/domain/sql_tables.rs @@ -3,6 +3,7 @@ use super::{ sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1}, }; use sea_query::*; +use serde::{Deserialize, Serialize}; pub use super::sql_migrations::create_group; @@ -51,7 +52,7 @@ impl From for Value { } } -#[derive(Iden)] +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] pub enum Users { Table, UserId, @@ -67,7 +68,9 @@ pub enum Users { Uuid, } -#[derive(Iden)] +pub type UserColumn = Users; + +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] pub enum Groups { Table, GroupId, @@ -76,6 +79,8 @@ pub enum Groups { Uuid, } +pub type GroupColumn = Groups; + #[derive(Iden)] pub enum Memberships { Table, diff --git a/server/src/domain/sql_user_backend_handler.rs b/server/src/domain/sql_user_backend_handler.rs index 92e4fa3..ef3298f 100644 --- a/server/src/domain/sql_user_backend_handler.rs +++ b/server/src/domain/sql_user_backend_handler.rs @@ -54,14 +54,10 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) { ), Equality(s1, s2) => ( RequiresGroup(false), - if s1 == Users::DisplayName.to_string() { - Expr::col((Users::Table, Users::DisplayName)) - .eq(s2) - .into_condition() - } else if s1 == Users::UserId.to_string() { + if s1 == Users::UserId { panic!("User id should be wrapped") } else { - Expr::expr(Expr::cust(&s1)).eq(s2).into_condition() + Expr::col((Users::Table, s1)).eq(s2).into_condition() }, ), MemberOf(group) => ( @@ -360,7 +356,9 @@ impl UserBackendHandler for SqlBackendHandler { #[cfg(test)] mod tests { use super::*; - use crate::domain::{handler::JpegPhoto, sql_backend_handler::tests::*}; + use crate::domain::{ + handler::JpegPhoto, sql_backend_handler::tests::*, sql_tables::UserColumn, + }; #[tokio::test] async fn test_list_users_no_filter() { @@ -386,7 +384,7 @@ mod tests { let users = get_user_names( &fixture.handler, Some(UserRequestFilter::Equality( - "display_name".to_string(), + UserColumn::DisplayName, "display bob".to_string(), )), ) @@ -400,7 +398,7 @@ mod tests { let users = get_user_names( &fixture.handler, Some(UserRequestFilter::Equality( - "first_name".to_string(), + UserColumn::FirstName, "first bob".to_string(), )), ) @@ -432,6 +430,20 @@ mod tests { assert_eq!(users, vec!["bob", "patrick"]); } + #[tokio::test] + async fn test_list_users_member_of_and_uuid() { + let fixture = TestFixture::new().await; + let users = get_user_names( + &fixture.handler, + Some(UserRequestFilter::Or(vec![ + UserRequestFilter::MemberOf("Best Group".to_string()), + UserRequestFilter::Equality(UserColumn::Uuid, "abc".to_string()), + ])), + ) + .await; + assert_eq!(users, vec!["bob", "patrick"]); + } + #[tokio::test] async fn test_list_users_member_of_id() { let fixture = TestFixture::new().await; @@ -450,7 +462,7 @@ mod tests { get_user_names( &fixture.handler, Some(UserRequestFilter::Equality( - "user_id".to_string(), + UserColumn::UserId, "first bob".to_string(), )), ) diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index 3955dbe..d8e221a 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,4 +1,8 @@ -use crate::domain::handler::{BackendHandler, GroupDetails, GroupId, UserId}; +use crate::domain::{ + handler::{BackendHandler, GroupDetails, GroupId, UserId}, + ldap::utils::map_user_field, + sql_tables::UserColumn, +}; use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use serde::{Deserialize, Serialize}; use tracing::{debug, debug_span, Instrument}; @@ -50,10 +54,14 @@ impl TryInto for RequestFilter { return Err("Multiple fields specified in request filter".to_string()); } if let Some(e) = self.eq { - if e.field.to_lowercase() == "uid" { - return Ok(DomainRequestFilter::UserId(UserId::new(&e.value))); + if let Some(column) = map_user_field(&e.field) { + if column == UserColumn::UserId { + return Ok(DomainRequestFilter::UserId(UserId::new(&e.value))); + } + return Ok(DomainRequestFilter::Equality(column, e.value)); + } else { + return Err(format!("Unknown request filter: {}", &e.field)); } - return Ok(DomainRequestFilter::Equality(e.field, e.value)); } if let Some(c) = self.any { return Ok(DomainRequestFilter::Or( @@ -451,11 +459,8 @@ mod tests { mock.expect_list_users() .with( eq(Some(UserRequestFilter::Or(vec![ - UserRequestFilter::Equality("id".to_string(), "bob".to_string()), - UserRequestFilter::Equality( - "email".to_string(), - "robert@bobbers.on".to_string(), - ), + UserRequestFilter::UserId(UserId::new("bob")), + UserRequestFilter::Equality(UserColumn::Email, "robert@bobbers.on".to_string()), ]))), eq(false), ) diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index 4038fce..bd1b27b 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -447,7 +447,7 @@ impl LdapHandler