From c9997d4c17db165e89d9b8a86b32fc4642ca93c1 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Fri, 17 Feb 2023 15:59:32 +0100 Subject: [PATCH] server: statically enforce access control --- server/src/domain/handler.rs | 34 +- server/src/domain/ldap/group.rs | 30 +- server/src/domain/ldap/user.rs | 20 +- server/src/domain/opaque_handler.rs | 2 +- server/src/domain/sql_backend_handler.rs | 5 +- .../src/domain/sql_group_backend_handler.rs | 9 +- server/src/domain/sql_user_backend_handler.rs | 12 +- server/src/infra/access_control.rs | 317 ++++++++++++++++++ server/src/infra/auth_service.rs | 112 ++----- server/src/infra/graphql/api.rs | 68 +++- server/src/infra/graphql/mutation.rs | 113 +++---- server/src/infra/graphql/query.rs | 107 +++--- server/src/infra/ldap_handler.rs | 185 +++++----- server/src/infra/ldap_server.rs | 9 +- server/src/infra/mod.rs | 1 + server/src/infra/tcp_backend_handler.rs | 12 +- server/src/infra/tcp_server.rs | 30 +- server/src/main.rs | 5 +- 18 files changed, 712 insertions(+), 359 deletions(-) create mode 100644 server/src/infra/access_control.rs diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 1b2cb4f..7ba11bb 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -122,13 +122,17 @@ pub struct UpdateGroupRequest { } #[async_trait] -pub trait LoginHandler: Clone + Send { +pub trait LoginHandler: Send + Sync { async fn bind(&self, request: BindRequest) -> Result<()>; } #[async_trait] -pub trait GroupBackendHandler { +pub trait GroupListerBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; +} + +#[async_trait] +pub trait GroupBackendHandler { async fn get_group_details(&self, group_id: GroupId) -> Result; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; @@ -136,12 +140,16 @@ pub trait GroupBackendHandler { } #[async_trait] -pub trait UserBackendHandler { +pub trait UserListerBackendHandler { async fn list_users( &self, filters: Option, get_groups: bool, ) -> Result>; +} + +#[async_trait] +pub trait UserBackendHandler { async fn get_user_details(&self, user_id: &UserId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; @@ -152,7 +160,15 @@ pub trait UserBackendHandler { } #[async_trait] -pub trait BackendHandler: Clone + Send + GroupBackendHandler + UserBackendHandler {} +pub trait BackendHandler: + Send + + Sync + + GroupBackendHandler + + UserBackendHandler + + UserListerBackendHandler + + GroupListerBackendHandler +{ +} #[cfg(test)] mockall::mock! { @@ -161,16 +177,22 @@ mockall::mock! { fn clone(&self) -> Self; } #[async_trait] - impl GroupBackendHandler for TestBackendHandler { + impl GroupListerBackendHandler for TestBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; + } + #[async_trait] + impl GroupBackendHandler for TestBackendHandler { async fn get_group_details(&self, group_id: GroupId) -> Result; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; } #[async_trait] - impl UserBackendHandler for TestBackendHandler { + impl UserListerBackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + } + #[async_trait] + impl UserBackendHandler for TestBackendHandler { async fn get_user_details(&self, user_id: &UserId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; diff --git a/server/src/domain/ldap/group.rs b/server/src/domain/ldap/group.rs index 00bde72..4f29b7d 100644 --- a/server/src/domain/ldap/group.rs +++ b/server/src/domain/ldap/group.rs @@ -1,10 +1,10 @@ use ldap3_proto::{ proto::LdapOp, LdapFilter, LdapPartialAttribute, LdapResultCode, LdapSearchResultEntry, }; -use tracing::{debug, info, instrument, warn}; +use tracing::{debug, instrument, warn}; use crate::domain::{ - handler::{BackendHandler, GroupRequestFilter}, + handler::{GroupListerBackendHandler, GroupRequestFilter}, ldap::error::LdapError, types::{Group, GroupColumn, UserId, Uuid}, }; @@ -21,7 +21,7 @@ pub fn get_group_attribute( group: &Group, base_dn_str: &str, attribute: &str, - user_filter: &Option<&UserId>, + user_filter: &Option, ignored_group_attributes: &[String], ) -> Option>> { let attribute = attribute.to_ascii_lowercase(); @@ -34,7 +34,7 @@ pub fn get_group_attribute( "member" | "uniquemember" => group .users .iter() - .filter(|u| user_filter.map(|f| *u == f).unwrap_or(true)) + .filter(|u| user_filter.as_ref().map(|f| *u == f).unwrap_or(true)) .map(|u| format!("uid={},ou=people,{}", u, base_dn_str).into_bytes()) .collect(), "1.1" => return None, @@ -81,7 +81,7 @@ fn make_ldap_search_group_result_entry( group: Group, base_dn_str: &str, attributes: &[String], - user_filter: &Option<&UserId>, + user_filter: &Option, ignored_group_attributes: &[String], ) -> LdapSearchResultEntry { let expanded_attributes = expand_group_attribute_wildcards(attributes); @@ -201,25 +201,17 @@ fn convert_group_filter( } #[instrument(skip_all, level = "debug")] -pub async fn get_groups_list( +pub async fn get_groups_list( ldap_info: &LdapInfo, ldap_filter: &LdapFilter, base: &str, - user_filter: &Option<&UserId>, - backend: &mut Backend, + backend: &Backend, ) -> LdapResult> { debug!(?ldap_filter); - let filter = convert_group_filter(ldap_info, ldap_filter)?; - let parsed_filters = match user_filter { - None => filter, - Some(u) => { - info!("Unprivileged search, limiting results"); - GroupRequestFilter::And(vec![filter, GroupRequestFilter::Member((*u).clone())]) - } - }; - debug!(?parsed_filters); + let filters = convert_group_filter(ldap_info, ldap_filter)?; + debug!(?filters); backend - .list_groups(Some(parsed_filters)) + .list_groups(Some(filters)) .await .map_err(|e| LdapError { code: LdapResultCode::Other, @@ -231,7 +223,7 @@ pub fn convert_groups_to_ldap_op<'a>( groups: Vec, attributes: &'a [String], ldap_info: &'a LdapInfo, - user_filter: &'a Option<&'a UserId>, + user_filter: &'a Option, ) -> impl Iterator + 'a { groups.into_iter().map(move |g| { LdapOp::SearchResultEntry(make_ldap_search_group_result_entry( diff --git a/server/src/domain/ldap/user.rs b/server/src/domain/ldap/user.rs index 603bb05..bad6764 100644 --- a/server/src/domain/ldap/user.rs +++ b/server/src/domain/ldap/user.rs @@ -2,10 +2,10 @@ use chrono::TimeZone; use ldap3_proto::{ proto::LdapOp, LdapFilter, LdapPartialAttribute, LdapResultCode, LdapSearchResultEntry, }; -use tracing::{debug, info, instrument, warn}; +use tracing::{debug, instrument, warn}; use crate::domain::{ - handler::{BackendHandler, UserRequestFilter}, + handler::{UserListerBackendHandler, UserRequestFilter}, ldap::{ error::LdapError, utils::{expand_attribute_wildcards, get_user_id_from_distinguished_name}, @@ -217,26 +217,18 @@ fn expand_user_attribute_wildcards(attributes: &[String]) -> Vec<&str> { } #[instrument(skip_all, level = "debug")] -pub async fn get_user_list( +pub async fn get_user_list( ldap_info: &LdapInfo, ldap_filter: &LdapFilter, request_groups: bool, base: &str, - user_filter: &Option<&UserId>, - backend: &mut Backend, + backend: &Backend, ) -> LdapResult> { debug!(?ldap_filter); let filters = convert_user_filter(ldap_info, ldap_filter)?; - let parsed_filters = match user_filter { - None => filters, - Some(u) => { - info!("Unprivileged search, limiting results"); - UserRequestFilter::And(vec![filters, UserRequestFilter::UserId((*u).clone())]) - } - }; - debug!(?parsed_filters); + debug!(?filters); backend - .list_users(Some(parsed_filters), request_groups) + .list_users(Some(filters), request_groups) .await .map_err(|e| LdapError { code: LdapResultCode::Other, diff --git a/server/src/domain/opaque_handler.rs b/server/src/domain/opaque_handler.rs index d5f71dd..13ed81a 100644 --- a/server/src/domain/opaque_handler.rs +++ b/server/src/domain/opaque_handler.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; pub use lldap_auth::{login, registration}; #[async_trait] -pub trait OpaqueHandler: Clone + Send { +pub trait OpaqueHandler: Send + Sync { async fn login_start( &self, request: login::ClientLoginStartRequest, diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index 86181f7..2b0f757 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -1,4 +1,4 @@ -use super::{handler::BackendHandler, sql_tables::DbConnection}; +use crate::domain::{handler::BackendHandler, sql_tables::DbConnection}; use crate::infra::configuration::Configuration; use async_trait::async_trait; @@ -23,7 +23,8 @@ pub mod tests { use crate::{ domain::{ handler::{ - CreateUserRequest, GroupBackendHandler, UserBackendHandler, UserRequestFilter, + CreateUserRequest, GroupBackendHandler, UserBackendHandler, + UserListerBackendHandler, UserRequestFilter, }, sql_tables::init_table, types::{GroupId, UserId}, diff --git a/server/src/domain/sql_group_backend_handler.rs b/server/src/domain/sql_group_backend_handler.rs index afffb5c..6ab2cb3 100644 --- a/server/src/domain/sql_group_backend_handler.rs +++ b/server/src/domain/sql_group_backend_handler.rs @@ -1,6 +1,8 @@ use crate::domain::{ error::{DomainError, Result}, - handler::{GroupBackendHandler, GroupRequestFilter, UpdateGroupRequest}, + handler::{ + GroupBackendHandler, GroupListerBackendHandler, GroupRequestFilter, UpdateGroupRequest, + }, model::{self, GroupColumn, MembershipColumn}, sql_backend_handler::SqlBackendHandler, types::{Group, GroupDetails, GroupId, Uuid}, @@ -57,7 +59,7 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { } #[async_trait] -impl GroupBackendHandler for SqlBackendHandler { +impl GroupListerBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret, err)] async fn list_groups(&self, filters: Option) -> Result> { debug!(?filters); @@ -94,7 +96,10 @@ impl GroupBackendHandler for SqlBackendHandler { }) .collect()) } +} +#[async_trait] +impl GroupBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret, err)] async fn get_group_details(&self, group_id: GroupId) -> Result { debug!(?group_id); diff --git a/server/src/domain/sql_user_backend_handler.rs b/server/src/domain/sql_user_backend_handler.rs index 529c3e9..a53de3d 100644 --- a/server/src/domain/sql_user_backend_handler.rs +++ b/server/src/domain/sql_user_backend_handler.rs @@ -1,6 +1,9 @@ -use super::{ +use crate::domain::{ error::{DomainError, Result}, - handler::{CreateUserRequest, UpdateUserRequest, UserBackendHandler, UserRequestFilter}, + handler::{ + CreateUserRequest, UpdateUserRequest, UserBackendHandler, UserListerBackendHandler, + UserRequestFilter, + }, model::{self, GroupColumn, UserColumn}, sql_backend_handler::SqlBackendHandler, types::{GroupDetails, GroupId, User, UserAndGroups, UserId, Uuid}, @@ -70,7 +73,7 @@ fn to_value(opt_name: &Option) -> ActiveValue> { } #[async_trait] -impl UserBackendHandler for SqlBackendHandler { +impl UserListerBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret, err)] async fn list_users( &self, @@ -135,7 +138,10 @@ impl UserBackendHandler for SqlBackendHandler { .collect()) } } +} +#[async_trait] +impl UserBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret)] async fn get_user_details(&self, user_id: &UserId) -> Result { debug!(?user_id); diff --git a/server/src/infra/access_control.rs b/server/src/infra/access_control.rs new file mode 100644 index 0000000..e4196cc --- /dev/null +++ b/server/src/infra/access_control.rs @@ -0,0 +1,317 @@ +use std::collections::HashSet; + +use async_trait::async_trait; +use tracing::info; + +use crate::domain::{ + error::Result, + handler::{ + BackendHandler, CreateUserRequest, GroupListerBackendHandler, GroupRequestFilter, + UpdateGroupRequest, UpdateUserRequest, UserListerBackendHandler, UserRequestFilter, + }, + types::{Group, GroupDetails, GroupId, User, UserAndGroups, UserId}, +}; + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Permission { + Admin, + PasswordManager, + Readonly, + Regular, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidationResults { + pub user: UserId, + pub permission: Permission, +} + +impl ValidationResults { + #[cfg(test)] + pub fn admin() -> Self { + Self { + user: UserId::new("admin"), + permission: Permission::Admin, + } + } + + #[must_use] + pub fn is_admin(&self) -> bool { + self.permission == Permission::Admin + } + + #[must_use] + pub fn can_read_all(&self) -> bool { + self.permission == Permission::Admin + || self.permission == Permission::Readonly + || self.permission == Permission::PasswordManager + } + + #[must_use] + pub fn can_read(&self, user: &UserId) -> bool { + self.permission == Permission::Admin + || self.permission == Permission::PasswordManager + || self.permission == Permission::Readonly + || &self.user == user + } + + #[must_use] + pub fn can_change_password(&self, user: &UserId, user_is_admin: bool) -> bool { + self.permission == Permission::Admin + || (self.permission == Permission::PasswordManager && !user_is_admin) + || &self.user == user + } + + #[must_use] + pub fn can_write(&self, user: &UserId) -> bool { + self.permission == Permission::Admin || &self.user == user + } +} + +#[async_trait] +pub trait UserReadableBackendHandler { + async fn get_user_details(&self, user_id: &UserId) -> Result; + async fn get_user_groups(&self, user_id: &UserId) -> Result>; +} + +#[async_trait] +pub trait ReadonlyBackendHandler: UserReadableBackendHandler { + async fn list_users( + &self, + filters: Option, + get_groups: bool, + ) -> Result>; + async fn list_groups(&self, filters: Option) -> Result>; + async fn get_group_details(&self, group_id: GroupId) -> Result; +} + +#[async_trait] +pub trait UserWriteableBackendHandler: UserReadableBackendHandler { + async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; +} + +#[async_trait] +pub trait AdminBackendHandler: + UserWriteableBackendHandler + ReadonlyBackendHandler + UserWriteableBackendHandler +{ + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; + async fn create_group(&self, group_name: &str) -> Result; + async fn delete_group(&self, group_id: GroupId) -> Result<()>; +} + +#[async_trait] +impl UserReadableBackendHandler for Handler { + async fn get_user_details(&self, user_id: &UserId) -> Result { + self.get_user_details(user_id).await + } + async fn get_user_groups(&self, user_id: &UserId) -> Result> { + self.get_user_groups(user_id).await + } +} + +#[async_trait] +impl ReadonlyBackendHandler for Handler { + async fn list_users( + &self, + filters: Option, + get_groups: bool, + ) -> Result> { + self.list_users(filters, get_groups).await + } + async fn list_groups(&self, filters: Option) -> Result> { + self.list_groups(filters).await + } + async fn get_group_details(&self, group_id: GroupId) -> Result { + self.get_group_details(group_id).await + } +} + +#[async_trait] +impl UserWriteableBackendHandler for Handler { + async fn update_user(&self, request: UpdateUserRequest) -> Result<()> { + self.update_user(request).await + } +} +#[async_trait] +impl AdminBackendHandler for Handler { + async fn create_user(&self, request: CreateUserRequest) -> Result<()> { + self.create_user(request).await + } + async fn delete_user(&self, user_id: &UserId) -> Result<()> { + self.delete_user(user_id).await + } + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { + self.add_user_to_group(user_id, group_id).await + } + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { + self.remove_user_from_group(user_id, group_id).await + } + async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> { + self.update_group(request).await + } + async fn create_group(&self, group_name: &str) -> Result { + self.create_group(group_name).await + } + async fn delete_group(&self, group_id: GroupId) -> Result<()> { + self.delete_group(group_id).await + } +} + +pub struct AccessControlledBackendHandler { + handler: Handler, +} + +impl Clone for AccessControlledBackendHandler { + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + } + } +} + +impl AccessControlledBackendHandler { + pub fn unsafe_get_handler(&self) -> &Handler { + &self.handler + } +} + +impl AccessControlledBackendHandler { + pub fn new(handler: Handler) -> Self { + Self { handler } + } + + pub fn get_admin_handler( + &self, + validation_result: &ValidationResults, + ) -> Option<&impl AdminBackendHandler> { + validation_result.is_admin().then_some(&self.handler) + } + + pub fn get_readonly_handler( + &self, + validation_result: &ValidationResults, + ) -> Option<&impl ReadonlyBackendHandler> { + validation_result.can_read_all().then_some(&self.handler) + } + + pub fn get_writeable_handler( + &self, + validation_result: &ValidationResults, + user_id: &UserId, + ) -> Option<&impl UserWriteableBackendHandler> { + validation_result + .can_write(user_id) + .then_some(&self.handler) + } + + pub fn get_readable_handler( + &self, + validation_result: &ValidationResults, + user_id: &UserId, + ) -> Option<&impl UserReadableBackendHandler> { + validation_result.can_read(user_id).then_some(&self.handler) + } + + pub fn get_user_restricted_lister_handler( + &self, + validation_result: &ValidationResults, + ) -> UserRestrictedListerBackendHandler<'_, Handler> { + UserRestrictedListerBackendHandler { + handler: &self.handler, + user_filter: if validation_result.can_read_all() { + None + } else { + info!("Unprivileged search, limiting results"); + Some(validation_result.user.clone()) + }, + } + } + + pub async fn get_permissions_for_user(&self, user_id: UserId) -> Result { + let user_groups = self.handler.get_user_groups(&user_id).await?; + Ok(self.get_permissions_from_groups(user_id, user_groups.iter().map(|g| &g.display_name))) + } + + pub fn get_permissions_from_groups<'a, Groups: Iterator + Clone + 'a>( + &self, + user_id: UserId, + groups: Groups, + ) -> ValidationResults { + let is_in_group = |name| groups.clone().any(|g| g == name); + ValidationResults { + user: user_id, + permission: if is_in_group("lldap_admin") { + Permission::Admin + } else if is_in_group("lldap_password_manager") { + Permission::PasswordManager + } else if is_in_group("lldap_strict_readonly") { + Permission::Readonly + } else { + Permission::Regular + }, + } + } +} + +pub struct UserRestrictedListerBackendHandler<'a, Handler> { + handler: &'a Handler, + pub user_filter: Option, +} + +#[async_trait] +impl<'a, Handler: UserListerBackendHandler + Sync> UserListerBackendHandler + for UserRestrictedListerBackendHandler<'a, Handler> +{ + async fn list_users( + &self, + filters: Option, + get_groups: bool, + ) -> Result> { + let user_filter = self + .user_filter + .as_ref() + .map(|u| UserRequestFilter::UserId(u.clone())); + let filters = match (filters, user_filter) { + (None, None) => None, + (None, u) => u, + (f, None) => f, + (Some(f), Some(u)) => Some(UserRequestFilter::And(vec![f, u])), + }; + self.handler.list_users(filters, get_groups).await + } +} + +#[async_trait] +impl<'a, Handler: GroupListerBackendHandler + Sync> GroupListerBackendHandler + for UserRestrictedListerBackendHandler<'a, Handler> +{ + async fn list_groups(&self, filters: Option) -> Result> { + let group_filter = self + .user_filter + .as_ref() + .map(|u| GroupRequestFilter::Member(u.clone())); + let filters = match (filters, group_filter) { + (None, None) => None, + (None, u) => u, + (f, None) => f, + (Some(f), Some(u)) => Some(GroupRequestFilter::And(vec![f, u])), + }; + self.handler.list_groups(filters).await + } +} + +#[async_trait] +pub trait UserAndGroupListerBackendHandler: + UserListerBackendHandler + GroupListerBackendHandler +{ +} + +#[async_trait] +impl<'a, Handler: GroupListerBackendHandler + UserListerBackendHandler + Sync> + UserAndGroupListerBackendHandler for UserRestrictedListerBackendHandler<'a, Handler> +{ +} diff --git a/server/src/infra/auth_service.rs b/server/src/infra/auth_service.rs index 5dc0cc1..4b350e0 100644 --- a/server/src/infra/auth_service.rs +++ b/server/src/infra/auth_service.rs @@ -30,6 +30,7 @@ use crate::{ types::{GroupDetails, UserColumn, UserId}, }, infra::{ + access_control::{ReadonlyBackendHandler, UserReadableBackendHandler, ValidationResults}, tcp_backend_handler::*, tcp_server::{error_to_http_response, AppState, TcpError, TcpResult}, }, @@ -87,11 +88,10 @@ async fn get_refresh( where Backend: TcpBackendHandler + BackendHandler + 'static, { - let backend_handler = &data.backend_handler; let jwt_key = &data.jwt_key; let (refresh_token_hash, user) = get_refresh_token(request)?; let found = data - .backend_handler + .get_tcp_handler() .check_token(refresh_token_hash, &user) .await?; if !found { @@ -99,7 +99,8 @@ where "Invalid refresh token".to_string(), ))); } - Ok(backend_handler + Ok(data + .get_readonly_handler() .get_user_groups(&user) .await .map(|groups| create_jwt(jwt_key, user.to_string(), groups)) @@ -145,7 +146,7 @@ where .get("user_id") .ok_or_else(|| TcpError::BadRequest("Missing user ID".to_string()))?; let user_results = data - .backend_handler + .get_readonly_handler() .list_users( Some(UserRequestFilter::Or(vec![ UserRequestFilter::UserId(UserId::new(user_string)), @@ -163,7 +164,7 @@ where } let user = &user_results[0].user; let token = match data - .backend_handler + .get_tcp_handler() .start_password_reset(&user.user_id) .await? { @@ -216,7 +217,7 @@ where .get("token") .ok_or_else(|| TcpError::BadRequest("Missing reset token".to_owned()))?; let user_id = data - .backend_handler + .get_tcp_handler() .get_user_id_for_password_reset_token(token) .await .map_err(|e| { @@ -224,7 +225,7 @@ where TcpError::NotFoundError("Wrong or expired reset token".to_owned()) })?; let _ = data - .backend_handler + .get_tcp_handler() .delete_password_reset_token(token) .await; let groups = HashSet::new(); @@ -266,10 +267,10 @@ where Backend: TcpBackendHandler + BackendHandler + 'static, { let (refresh_token_hash, user) = get_refresh_token(request)?; - data.backend_handler + data.get_tcp_handler() .delete_refresh_token(refresh_token_hash) .await?; - let new_blacklisted_jwts = data.backend_handler.blacklist_jwts(&user).await?; + let new_blacklisted_jwts = data.get_tcp_handler().blacklist_jwts(&user).await?; let mut jwt_blacklist = data.jwt_blacklist.write().unwrap(); for jwt in new_blacklisted_jwts { jwt_blacklist.insert(jwt); @@ -320,7 +321,7 @@ async fn opaque_login_start( where Backend: OpaqueHandler + 'static, { - data.backend_handler + data.get_opaque_handler() .login_start(request.into_inner()) .await .map(|res| ApiResult::Left(web::Json(res))) @@ -337,8 +338,8 @@ where { // The authentication was successful, we need to fetch the groups to create the JWT // token. - let groups = data.backend_handler.get_user_groups(name).await?; - let (refresh_token, max_age) = data.backend_handler.create_refresh_token(name).await?; + let groups = data.get_readonly_handler().get_user_groups(name).await?; + let (refresh_token, max_age) = data.get_tcp_handler().create_refresh_token(name).await?; let token = create_jwt(&data.jwt_key, name.to_string(), groups); let refresh_token_plus_name = refresh_token + "+" + name.as_str(); @@ -374,7 +375,7 @@ where Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static, { let name = data - .backend_handler + .get_opaque_handler() .login_finish(request.into_inner()) .await?; get_login_successful_response(&data, &name).await @@ -405,7 +406,7 @@ where name: user_id.clone(), password: request.password.clone(), }; - data.backend_handler.bind(bind_request).await?; + data.get_login_handler().bind(bind_request).await?; get_login_successful_response(&data, &user_id).await } @@ -431,7 +432,7 @@ where { let name = request.name.clone(); debug!(%name); - data.backend_handler.bind(request.into_inner()).await?; + data.get_login_handler().bind(request.into_inner()).await?; get_login_successful_response(&data, &name).await } @@ -474,7 +475,7 @@ where .into_inner(); let user_id = UserId::new(®istration_start_request.username); let user_is_admin = data - .backend_handler + .get_readonly_handler() .get_user_groups(&user_id) .await? .iter() @@ -485,7 +486,7 @@ where )); } Ok(data - .backend_handler + .get_opaque_handler() .registration_start(registration_start_request) .await?) } @@ -512,7 +513,7 @@ async fn opaque_register_finish( where Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static, { - data.backend_handler + data.get_opaque_handler() .registration_finish(request.into_inner()) .await?; Ok(HttpResponse::Ok().finish()) @@ -586,64 +587,8 @@ where } } -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub enum Permission { - Admin, - PasswordManager, - Readonly, - Regular, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ValidationResults { - pub user: UserId, - pub permission: Permission, -} - -impl ValidationResults { - #[cfg(test)] - pub fn admin() -> Self { - Self { - user: UserId::new("admin"), - permission: Permission::Admin, - } - } - - #[must_use] - pub fn is_admin(&self) -> bool { - self.permission == Permission::Admin - } - - #[must_use] - pub fn is_admin_or_readonly(&self) -> bool { - self.permission == Permission::Admin - || self.permission == Permission::Readonly - || self.permission == Permission::PasswordManager - } - - #[must_use] - pub fn can_read(&self, user: &UserId) -> bool { - self.permission == Permission::Admin - || self.permission == Permission::PasswordManager - || self.permission == Permission::Readonly - || &self.user == user - } - - #[must_use] - pub fn can_change_password(&self, user: &UserId, user_is_admin: bool) -> bool { - self.permission == Permission::Admin - || (self.permission == Permission::PasswordManager && !user_is_admin) - || &self.user == user - } - - #[must_use] - pub fn can_write(&self, user: &UserId) -> bool { - self.permission == Permission::Admin || &self.user == user - } -} - #[instrument(skip_all, level = "debug", err, ret)] -pub(crate) fn check_if_token_is_valid( +pub(crate) fn check_if_token_is_valid( state: &AppState, token_str: &str, ) -> Result { @@ -666,19 +611,10 @@ pub(crate) fn check_if_token_is_valid( if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) { return Err(ErrorUnauthorized("JWT was logged out")); } - let is_in_group = |name| token.claims().groups.contains(name); - Ok(ValidationResults { - user: UserId::new(&token.claims().user), - permission: if is_in_group("lldap_admin") { - Permission::Admin - } else if is_in_group("lldap_password_manager") { - Permission::PasswordManager - } else if is_in_group("lldap_strict_readonly") { - Permission::Readonly - } else { - Permission::Regular - }, - }) + Ok(state.backend_handler.get_permissions_from_groups( + UserId::new(&token.claims().user), + token.claims().groups.iter(), + )) } pub fn configure_server(cfg: &mut web::ServiceConfig, enable_password_reset: bool) diff --git a/server/src/infra/graphql/api.rs b/server/src/infra/graphql/api.rs index b228993..eda2c25 100644 --- a/server/src/infra/graphql/api.rs +++ b/server/src/infra/graphql/api.rs @@ -1,29 +1,77 @@ use crate::{ - domain::handler::BackendHandler, + domain::{handler::BackendHandler, types::UserId}, infra::{ - auth_service::{check_if_token_is_valid, ValidationResults}, + access_control::{ + AccessControlledBackendHandler, AdminBackendHandler, ReadonlyBackendHandler, + UserReadableBackendHandler, UserWriteableBackendHandler, ValidationResults, + }, + auth_service::check_if_token_is_valid, cli::ExportGraphQLSchemaOpts, + graphql::{mutation::Mutation, query::Query}, tcp_server::AppState, }, }; use actix_web::{web, Error, HttpResponse}; use actix_web_httpauth::extractors::bearer::BearerAuth; -use juniper::{EmptySubscription, RootNode}; +use juniper::{EmptySubscription, FieldError, RootNode}; use juniper_actix::{graphiql_handler, graphql_handler, playground_handler}; - -use super::{mutation::Mutation, query::Query}; +use tracing::debug; pub struct Context { - pub handler: Box, + pub handler: AccessControlledBackendHandler, pub validation_result: ValidationResults, } +pub fn field_error_callback<'a>( + span: &'a tracing::Span, + error_message: &'a str, +) -> impl 'a + FnOnce() -> FieldError { + move || { + span.in_scope(|| debug!("Unauthorized")); + FieldError::from(error_message) + } +} + +impl Context { + #[cfg(test)] + pub fn new_for_tests(handler: Handler, validation_result: ValidationResults) -> Self { + Self { + handler: AccessControlledBackendHandler::new(handler), + validation_result, + } + } + + pub fn get_admin_handler(&self) -> Option<&impl AdminBackendHandler> { + self.handler.get_admin_handler(&self.validation_result) + } + + pub fn get_readonly_handler(&self) -> Option<&impl ReadonlyBackendHandler> { + self.handler.get_readonly_handler(&self.validation_result) + } + + pub fn get_writeable_handler( + &self, + user_id: &UserId, + ) -> Option<&impl UserWriteableBackendHandler> { + self.handler + .get_writeable_handler(&self.validation_result, user_id) + } + + pub fn get_readable_handler( + &self, + user_id: &UserId, + ) -> Option<&impl UserReadableBackendHandler> { + self.handler + .get_readable_handler(&self.validation_result, user_id) + } +} + impl juniper::Context for Context {} type Schema = RootNode<'static, Query, Mutation, EmptySubscription>>; -fn schema() -> Schema { +fn schema() -> Schema { Schema::new( Query::::new(), Mutation::::new(), @@ -58,7 +106,7 @@ async fn playground_route() -> Result { playground_handler("/api/graphql", None).await } -async fn graphql_route( +async fn graphql_route( req: actix_web::HttpRequest, mut payload: actix_web::web::Payload, data: web::Data>, @@ -67,7 +115,7 @@ async fn graphql_route( let bearer = BearerAuth::from_request(&req, &mut payload.0).await?; let validation_result = check_if_token_is_valid(&data, bearer.token())?; let context = Context:: { - handler: Box::new(data.backend_handler.clone()), + handler: data.backend_handler.clone(), validation_result, }; graphql_handler(&schema(), &context, req, payload).await @@ -75,7 +123,7 @@ async fn graphql_route( pub fn configure_endpoint(cfg: &mut web::ServiceConfig) where - Backend: BackendHandler + Sync + 'static, + Backend: BackendHandler + Clone + 'static, { let json_config = web::JsonConfig::default() .limit(4096) diff --git a/server/src/infra/graphql/mutation.rs b/server/src/infra/graphql/mutation.rs index 9a25013..5134334 100644 --- a/server/src/infra/graphql/mutation.rs +++ b/server/src/infra/graphql/mutation.rs @@ -1,6 +1,15 @@ -use crate::domain::{ - handler::{BackendHandler, CreateUserRequest, UpdateGroupRequest, UpdateUserRequest}, - types::{GroupId, JpegPhoto, UserId}, +use crate::{ + domain::{ + handler::{BackendHandler, CreateUserRequest, UpdateGroupRequest, UpdateUserRequest}, + types::{GroupId, JpegPhoto, UserId}, + }, + infra::{ + access_control::{ + AdminBackendHandler, ReadonlyBackendHandler, UserReadableBackendHandler, + UserWriteableBackendHandler, + }, + graphql::api::field_error_callback, + }, }; use anyhow::Context as AnyhowContext; use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject}; @@ -65,19 +74,18 @@ impl Success { } #[graphql_object(context = Context)] -impl Mutation { +impl Mutation { async fn create_user( context: &Context, user: CreateUserInput, ) -> FieldResult> { let span = debug_span!("[GraphQL mutation] create_user"); span.in_scope(|| { - debug!(?user.id); + debug!("{:?}", &user.id); }); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized user creation".into()); - } + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback(&span, "Unauthorized user creation"))?; let user_id = UserId::new(&user.id); let avatar = user .avatar @@ -87,8 +95,7 @@ impl Mutation { .map(JpegPhoto::try_from) .transpose() .context("Provided image is not a valid JPEG")?; - context - .handler + handler .create_user(CreateUserRequest { user_id: user_id.clone(), email: user.email, @@ -99,8 +106,7 @@ impl Mutation { }) .instrument(span.clone()) .await?; - Ok(context - .handler + Ok(handler .get_user_details(&user_id) .instrument(span) .await @@ -115,13 +121,11 @@ impl Mutation { span.in_scope(|| { debug!(?name); }); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized group creation".into()); - } - let group_id = context.handler.create_group(&name).await?; - Ok(context - .handler + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback(&span, "Unauthorized group creation"))?; + let group_id = handler.create_group(&name).await?; + Ok(handler .get_group_details(group_id) .instrument(span) .await @@ -137,10 +141,9 @@ impl Mutation { debug!(?user.id); }); let user_id = UserId::new(&user.id); - if !context.validation_result.can_write(&user_id) { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized user update".into()); - } + let handler = context + .get_writeable_handler(&user_id) + .ok_or_else(field_error_callback(&span, "Unauthorized user update"))?; let avatar = user .avatar .map(base64::decode) @@ -149,8 +152,7 @@ impl Mutation { .map(JpegPhoto::try_from) .transpose() .context("Provided image is not a valid JPEG")?; - context - .handler + handler .update_user(UpdateUserRequest { user_id, email: user.email, @@ -172,16 +174,14 @@ impl Mutation { span.in_scope(|| { debug!(?group.id); }); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized group update".into()); - } + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback(&span, "Unauthorized group update"))?; if group.id == 1 { span.in_scope(|| debug!("Cannot change admin group details")); return Err("Cannot change admin group details".into()); } - context - .handler + handler .update_group(UpdateGroupRequest { group_id: GroupId(group.id), display_name: group.display_name, @@ -200,12 +200,13 @@ impl Mutation { span.in_scope(|| { debug!(?user_id, ?group_id); }); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized group membership modification".into()); - } - context - .handler + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback( + &span, + "Unauthorized group membership modification", + ))?; + handler .add_user_to_group(&UserId::new(&user_id), GroupId(group_id)) .instrument(span) .await?; @@ -221,17 +222,18 @@ impl Mutation { span.in_scope(|| { debug!(?user_id, ?group_id); }); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized group membership modification".into()); - } + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback( + &span, + "Unauthorized group membership modification", + ))?; let user_id = UserId::new(&user_id); if context.validation_result.user == user_id && group_id == 1 { span.in_scope(|| debug!("Cannot remove admin rights for current user")); return Err("Cannot remove admin rights for current user".into()); } - context - .handler + handler .remove_user_from_group(&user_id, GroupId(group_id)) .instrument(span) .await?; @@ -244,19 +246,14 @@ impl Mutation { debug!(?user_id); }); let user_id = UserId::new(&user_id); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized user deletion".into()); - } + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback(&span, "Unauthorized user deletion"))?; if context.validation_result.user == user_id { span.in_scope(|| debug!("Cannot delete current user")); return Err("Cannot delete current user".into()); } - context - .handler - .delete_user(&user_id) - .instrument(span) - .await?; + handler.delete_user(&user_id).instrument(span).await?; Ok(Success::new()) } @@ -265,16 +262,14 @@ impl Mutation { span.in_scope(|| { debug!(?group_id); }); - if !context.validation_result.is_admin() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized group deletion".into()); - } + let handler = context + .get_admin_handler() + .ok_or_else(field_error_callback(&span, "Unauthorized group deletion"))?; if group_id == 1 { span.in_scope(|| debug!("Cannot delete admin group")); return Err("Cannot delete admin group".into()); } - context - .handler + handler .delete_group(GroupId(group_id)) .instrument(span) .await?; diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index 7c97050..6422844 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,7 +1,13 @@ -use crate::domain::{ - handler::BackendHandler, - ldap::utils::map_user_field, - types::{GroupDetails, GroupId, UserColumn, UserId}, +use crate::{ + domain::{ + handler::BackendHandler, + ldap::utils::map_user_field, + types::{GroupDetails, GroupId, UserColumn, UserId}, + }, + infra::{ + access_control::{ReadonlyBackendHandler, UserReadableBackendHandler}, + graphql::api::field_error_callback, + }, }; use chrono::TimeZone; use juniper::{graphql_object, FieldResult, GraphQLInputObject}; @@ -112,7 +118,7 @@ impl Query { } #[graphql_object(context = Context)] -impl Query { +impl Query { fn api_version() -> &'static str { "1.0" } @@ -123,12 +129,13 @@ impl Query { debug!(?user_id); }); let user_id = UserId::new(&user_id); - if !context.validation_result.can_read(&user_id) { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized access to user data".into()); - } - Ok(context - .handler + let handler = context + .get_readable_handler(&user_id) + .ok_or_else(field_error_callback( + &span, + "Unauthorized access to user data", + ))?; + Ok(handler .get_user_details(&user_id) .instrument(span) .await @@ -143,12 +150,13 @@ impl Query { span.in_scope(|| { debug!(?filters); }); - if !context.validation_result.is_admin_or_readonly() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized access to user list".into()); - } - Ok(context - .handler + let handler = context + .get_readonly_handler() + .ok_or_else(field_error_callback( + &span, + "Unauthorized access to user list", + ))?; + Ok(handler .list_users(filters.map(TryInto::try_into).transpose()?, false) .instrument(span) .await @@ -157,12 +165,13 @@ impl Query { async fn groups(context: &Context) -> FieldResult>> { let span = debug_span!("[GraphQL query] groups"); - if !context.validation_result.is_admin_or_readonly() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized access to group list".into()); - } - Ok(context - .handler + let handler = context + .get_readonly_handler() + .ok_or_else(field_error_callback( + &span, + "Unauthorized access to group list", + ))?; + Ok(handler .list_groups(None) .instrument(span) .await @@ -174,12 +183,13 @@ impl Query { span.in_scope(|| { debug!(?group_id); }); - if !context.validation_result.is_admin_or_readonly() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized access to group data".into()); - } - Ok(context - .handler + let handler = context + .get_readonly_handler() + .ok_or_else(field_error_callback( + &span, + "Unauthorized access to group data", + ))?; + Ok(handler .get_group_details(GroupId(group_id)) .instrument(span) .await @@ -205,7 +215,7 @@ impl Default for User { } #[graphql_object(context = Context)] -impl User { +impl User { fn id(&self) -> &str { self.user.user_id.as_str() } @@ -244,8 +254,10 @@ impl User { span.in_scope(|| { debug!(user_id = ?self.user.user_id); }); - Ok(context - .handler + let handler = context + .get_readable_handler(&self.user.user_id) + .expect("We shouldn't be able to get there without readable permission"); + Ok(handler .get_user_groups(&self.user.user_id) .instrument(span) .await @@ -283,7 +295,7 @@ pub struct Group { } #[graphql_object(context = Context)] -impl Group { +impl Group { fn id(&self) -> i32 { self.group_id } @@ -302,12 +314,13 @@ impl Group { span.in_scope(|| { debug!(name = %self.display_name); }); - if !context.validation_result.is_admin_or_readonly() { - span.in_scope(|| debug!("Unauthorized")); - return Err("Unauthorized access to group data".into()); - } - Ok(context - .handler + let handler = context + .get_readonly_handler() + .ok_or_else(field_error_callback( + &span, + "Unauthorized access to group data", + ))?; + Ok(handler .list_users( Some(DomainRequestFilter::MemberOfId(GroupId(self.group_id))), false, @@ -347,7 +360,9 @@ impl From for Group { #[cfg(test)] mod tests { use super::*; - use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults}; + use crate::{ + domain::handler::MockTestBackendHandler, infra::access_control::ValidationResults, + }; use chrono::TimeZone; use juniper::{ execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType, @@ -406,10 +421,8 @@ mod tests { .with(eq(UserId::new("bob"))) .return_once(|_| Ok(groups)); - let context = Context:: { - handler: Box::new(mock), - validation_result: ValidationResults::admin(), - }; + let context = + Context::::new_for_tests(mock, ValidationResults::admin()); let schema = schema(Query::::new()); assert_eq!( @@ -486,10 +499,8 @@ mod tests { ]) }); - let context = Context:: { - handler: Box::new(mock), - validation_result: ValidationResults::admin(), - }; + let context = + Context::::new_for_tests(mock, ValidationResults::admin()); let schema = schema(Query::::new()); assert_eq!( diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index e24bcf0..1478ec8 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -12,7 +12,10 @@ use crate::{ opaque_handler::OpaqueHandler, types::{Group, JpegPhoto, UserAndGroups, UserId}, }, - infra::auth_service::{Permission, ValidationResults}, + infra::access_control::{ + AccessControlledBackendHandler, AdminBackendHandler, UserAndGroupListerBackendHandler, + UserReadableBackendHandler, ValidationResults, + }, }; use anyhow::Result; use ldap3_proto::proto::{ @@ -175,15 +178,27 @@ fn root_dse_response(base_dn: &str) -> LdapOp { }) } -pub struct LdapHandler { +pub struct LdapHandler { user_info: Option, - backend_handler: Backend, + backend_handler: AccessControlledBackendHandler, ldap_info: LdapInfo, } +impl LdapHandler { + pub fn get_login_handler(&self) -> &impl LoginHandler { + self.backend_handler.unsafe_get_handler() + } +} + +impl LdapHandler { + pub fn get_opaque_handler(&self) -> &impl OpaqueHandler { + self.backend_handler.unsafe_get_handler() + } +} + impl LdapHandler { pub fn new( - backend_handler: Backend, + backend_handler: AccessControlledBackendHandler, mut ldap_base_dn: String, ignored_user_attributes: Vec, ignored_group_attributes: Vec, @@ -206,6 +221,16 @@ impl LdapHandler Self { + Self::new( + AccessControlledBackendHandler::new(backend_handler), + ldap_base_dn.to_string(), + vec![], + vec![], + ) + } + #[instrument(skip_all, level = "debug")] pub async fn do_bind(&mut self, request: &LdapBindRequest) -> (LdapResultCode, String) { debug!("DN: {}", &request.dn); @@ -219,7 +244,7 @@ impl LdapHandler LdapHandler { - let user_groups = self.backend_handler.get_user_groups(&user_id).await; - let is_in_group = |name| { - user_groups - .as_ref() - .map(|groups| groups.iter().any(|g| g.display_name == name)) - .unwrap_or(false) - }; - self.user_info = Some(ValidationResults { - user: user_id, - permission: if is_in_group("lldap_admin") { - Permission::Admin - } else if is_in_group("lldap_password_manager") { - Permission::PasswordManager - } else if is_in_group("lldap_strict_readonly") { - Permission::Readonly - } else { - Permission::Regular - }, - }); + self.user_info = self + .backend_handler + .get_permissions_for_user(user_id) + .await + .ok(); debug!("Success!"); (LdapResultCode::Success, "".to_string()) } @@ -253,7 +264,12 @@ impl LdapHandler Result<()> { + async fn change_password( + &self, + backend_handler: &B, + user: &UserId, + password: &str, + ) -> Result<()> { use lldap_auth::*; let mut rng = rand::rngs::OsRng; let registration_start_request = @@ -262,7 +278,7 @@ impl LdapHandler LdapHandler LdapHandler { let user_is_admin = self .backend_handler + .get_readable_handler(credentials, &uid) + .expect("Unexpected permission error") .get_user_groups(&uid) .await .map_err(|e| LdapError { @@ -313,7 +331,10 @@ impl LdapHandler LdapHandler LdapResult> { - let user_info = self.user_info.as_ref().ok_or_else(|| LdapError { - code: LdapResultCode::InsufficentAccessRights, - message: "No user currently bound".to_string(), - })?; - Ok(if user_info.is_admin_or_readonly() { - None - } else { - Some(user_info.user.clone()) - }) - } - pub async fn do_search_or_dse( &mut self, request: &LdapSearchRequest, @@ -382,22 +391,22 @@ impl LdapHandler, ) -> LdapResult<(Option>, Option>)> { let dn_parts = parse_distinguished_name(&request.base.to_ascii_lowercase())?; let scope = get_search_scope(&self.ldap_info.base_dn, &dn_parts); debug!(?request.base, ?scope); // Disambiguate the lifetimes. - fn cast<'a, T, R, B: 'a>(x: T) -> T + fn cast<'a, T, R>(x: T) -> T where - T: Fn(&'a mut B, &'a LdapFilter) -> R + 'a, + T: Fn(&'a LdapFilter) -> R + 'a, { x } - let get_user_list = cast(|backend_handler: &mut Backend, filter: &LdapFilter| async { + let get_user_list = cast(|filter: &LdapFilter| async { let need_groups = request .attrs .iter() @@ -407,47 +416,27 @@ impl LdapHandler ( - Some(get_user_list(&mut self.backend_handler, &request.filter).await?), - Some(get_group_list(&mut self.backend_handler, &request.filter).await?), - ), - SearchScope::Users => ( - Some(get_user_list(&mut self.backend_handler, &request.filter).await?), - None, - ), - SearchScope::Groups => ( - None, - Some(get_group_list(&mut self.backend_handler, &request.filter).await?), + Some(get_user_list(&request.filter).await?), + Some(get_group_list(&request.filter).await?), ), + SearchScope::Users => (Some(get_user_list(&request.filter).await?), None), + SearchScope::Groups => (None, Some(get_group_list(&request.filter).await?)), SearchScope::User(filter) => { let filter = LdapFilter::And(vec![request.filter.clone(), filter]); - ( - Some(get_user_list(&mut self.backend_handler, &filter).await?), - None, - ) + (Some(get_user_list(&filter).await?), None) } SearchScope::Group(filter) => { let filter = LdapFilter::And(vec![request.filter.clone(), filter]); - ( - None, - Some(get_group_list(&mut self.backend_handler, &filter).await?), - ) + (None, Some(get_group_list(&filter).await?)) } SearchScope::Unknown => { warn!( @@ -468,10 +457,15 @@ impl LdapHandler LdapResult> { - let user_filter = self.get_user_permission_filter()?; - let user_filter = user_filter.as_ref(); - let (users, groups) = self.do_search_internal(request, &user_filter).await?; + pub async fn do_search(&self, request: &LdapSearchRequest) -> LdapResult> { + let user_info = self.user_info.as_ref().ok_or_else(|| LdapError { + code: LdapResultCode::InsufficentAccessRights, + message: "No user currently bound".to_string(), + })?; + let backend_handler = self + .backend_handler + .get_user_restricted_lister_handler(user_info); + let (users, groups) = self.do_search_internal(&backend_handler, request).await?; let mut results = Vec::new(); if let Some(users) = users { @@ -486,7 +480,7 @@ impl LdapHandler LdapHandler LdapResult> { - if !self + let backend_handler = self .user_info .as_ref() - .map(|u| u.is_admin()) - .unwrap_or(false) - { - return Err(LdapError { + .and_then(|u| self.backend_handler.get_admin_handler(u)) + .ok_or_else(|| LdapError { code: LdapResultCode::InsufficentAccessRights, message: "Unauthorized write".to_string(), - }); - } + })?; let user_id = get_user_id_from_distinguished_name( &request.dn, &self.ldap_info.base_dn, @@ -552,7 +543,7 @@ impl LdapHandler Result<()>; } #[async_trait] - impl GroupBackendHandler for TestBackendHandler { + impl GroupListerBackendHandler for TestBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; + } + #[async_trait] + impl GroupBackendHandler for TestBackendHandler { async fn get_group_details(&self, group_id: GroupId) -> Result; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; } #[async_trait] - impl UserBackendHandler for TestBackendHandler { + impl UserListerBackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + } + #[async_trait] + impl UserBackendHandler for TestBackendHandler { async fn get_user_details(&self, user_id: &UserId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; @@ -768,8 +765,7 @@ mod tests { }); Ok(set) }); - let mut ldap_handler = - LdapHandler::new(mock, "dc=Example,dc=com".to_string(), vec![], vec![]); + let mut ldap_handler = LdapHandler::new_for_tests(mock, "dc=Example,dc=com"); let request = LdapBindRequest { dn: "uid=test,ou=people,dc=example,dc=coM".to_string(), cred: LdapBindCred::Simple("pass".to_string()), @@ -812,8 +808,7 @@ mod tests { mock.expect_get_user_groups() .with(eq(UserId::new("bob"))) .return_once(|_| Ok(HashSet::new())); - let mut ldap_handler = - LdapHandler::new(mock, "dc=eXample,dc=com".to_string(), vec![], vec![]); + let mut ldap_handler = LdapHandler::new_for_tests(mock, "dc=eXample,dc=com"); let request = LdapOp::BindRequest(LdapBindRequest { dn: "uid=bob,ou=people,dc=example,dc=com".to_string(), @@ -855,8 +850,7 @@ mod tests { }); Ok(set) }); - let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), vec![], vec![]); + let mut ldap_handler = LdapHandler::new_for_tests(mock, "dc=example,dc=com"); let request = LdapBindRequest { dn: "uid=test,ou=people,dc=example,dc=com".to_string(), @@ -997,8 +991,7 @@ mod tests { #[tokio::test] async fn test_bind_invalid_dn() { let mock = MockTestBackendHandler::new(); - let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), vec![], vec![]); + let mut ldap_handler = LdapHandler::new_for_tests(mock, "dc=example,dc=com"); let request = LdapBindRequest { dn: "cn=bob,dc=example,dc=com".to_string(), diff --git a/server/src/infra/ldap_server.rs b/server/src/infra/ldap_server.rs index 80b8cf4..31551f9 100644 --- a/server/src/infra/ldap_server.rs +++ b/server/src/infra/ldap_server.rs @@ -3,7 +3,10 @@ use crate::{ handler::{BackendHandler, LoginHandler}, opaque_handler::OpaqueHandler, }, - infra::{configuration::Configuration, ldap_handler::LdapHandler}, + infra::{ + access_control::AccessControlledBackendHandler, configuration::Configuration, + ldap_handler::LdapHandler, + }, }; use actix_rt::net::TcpStream; use actix_server::ServerBuilder; @@ -73,7 +76,7 @@ where let mut resp = FramedWrite::new(w, LdapCodec); let mut session = LdapHandler::new( - backend_handler, + AccessControlledBackendHandler::new(backend_handler), ldap_base_dn, ignored_user_attributes, ignored_group_attributes, @@ -145,7 +148,7 @@ pub fn build_ldap_server( server_builder: ServerBuilder, ) -> Result where - Backend: BackendHandler + LoginHandler + OpaqueHandler + 'static, + Backend: BackendHandler + LoginHandler + OpaqueHandler + Clone + 'static, { let context = ( backend_handler, diff --git a/server/src/infra/mod.rs b/server/src/infra/mod.rs index f0b85f9..33a2e58 100644 --- a/server/src/infra/mod.rs +++ b/server/src/infra/mod.rs @@ -1,3 +1,4 @@ +pub mod access_control; pub mod auth_service; pub mod cli; pub mod configuration; diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index f4930af..c153a96 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -4,7 +4,7 @@ use std::collections::HashSet; use crate::domain::{error::Result, types::UserId}; #[async_trait] -pub trait TcpBackendHandler { +pub trait TcpBackendHandler: Sync { async fn get_jwt_blacklist(&self) -> anyhow::Result>; async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>; async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result; @@ -34,16 +34,22 @@ mockall::mock! { async fn bind(&self, request: BindRequest) -> Result<()>; } #[async_trait] - impl GroupBackendHandler for TestTcpBackendHandler { + impl GroupListerBackendHandler for TestTcpBackendHandler { async fn list_groups(&self, filters: Option) -> Result>; + } + #[async_trait] + impl GroupBackendHandler for TestTcpBackendHandler { async fn get_group_details(&self, group_id: GroupId) -> Result; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; } #[async_trait] - impl UserBackendHandler for TestBackendHandler { + impl UserListerBackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option, get_groups: bool) -> Result>; + } + #[async_trait] + impl UserBackendHandler for TestBackendHandler { async fn get_user_details(&self, user_id: &UserId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; diff --git a/server/src/infra/tcp_server.rs b/server/src/infra/tcp_server.rs index 166b65e..7179e7d 100644 --- a/server/src/infra/tcp_server.rs +++ b/server/src/infra/tcp_server.rs @@ -5,6 +5,7 @@ use crate::{ opaque_handler::OpaqueHandler, }, infra::{ + access_control::{AccessControlledBackendHandler, ReadonlyBackendHandler}, auth_service, configuration::{Configuration, MailOptions}, logging::CustomRootSpanBuilder, @@ -74,11 +75,11 @@ fn http_config( server_url: String, mail_options: MailOptions, ) where - Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static, + Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Clone + 'static, { let enable_password_reset = mail_options.enable_password_reset; cfg.app_data(web::Data::new(AppState:: { - backend_handler, + backend_handler: AccessControlledBackendHandler::new(backend_handler), jwt_key: Hmac::new_varkey(jwt_secret.unsecure().as_bytes()).unwrap(), jwt_blacklist: RwLock::new(jwt_blacklist), server_url, @@ -110,20 +111,41 @@ fn http_config( } pub(crate) struct AppState { - pub backend_handler: Backend, + pub backend_handler: AccessControlledBackendHandler, pub jwt_key: Hmac, pub jwt_blacklist: RwLock>, pub server_url: String, pub mail_options: MailOptions, } +impl AppState { + pub fn get_readonly_handler(&self) -> &impl ReadonlyBackendHandler { + self.backend_handler.unsafe_get_handler() + } +} +impl AppState { + pub fn get_tcp_handler(&self) -> &impl TcpBackendHandler { + self.backend_handler.unsafe_get_handler() + } +} +impl AppState { + pub fn get_opaque_handler(&self) -> &impl OpaqueHandler { + self.backend_handler.unsafe_get_handler() + } +} +impl AppState { + pub fn get_login_handler(&self) -> &impl LoginHandler { + self.backend_handler.unsafe_get_handler() + } +} + pub async fn build_tcp_server( config: &Configuration, backend_handler: Backend, server_builder: ServerBuilder, ) -> Result where - Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static, + Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Clone + 'static, { let jwt_secret = config.jwt_secret.clone(); let jwt_blacklist = backend_handler diff --git a/server/src/main.rs b/server/src/main.rs index 712d0a9..904487c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -7,7 +7,10 @@ use std::time::Duration; use crate::{ domain::{ - handler::{CreateUserRequest, GroupBackendHandler, GroupRequestFilter, UserBackendHandler}, + handler::{ + CreateUserRequest, GroupBackendHandler, GroupListerBackendHandler, GroupRequestFilter, + UserBackendHandler, + }, sql_backend_handler::SqlBackendHandler, sql_opaque_handler::register_password, },