From ca19e61f5031bcebd9caf3323e2019f59fd628dd Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Sat, 26 Mar 2022 18:00:37 +0100 Subject: [PATCH] domain: introduce UserId to make uid case insensitive Note that if there was a non-lowercase user already in the DB, it cannot be found again. To fix this, run in the DB: sqlite> UPDATE users SET user_id = LOWER(user_id); --- server/src/domain/error.rs | 2 +- server/src/domain/handler.rs | 66 +++++++++--- server/src/domain/opaque_handler.rs | 6 +- server/src/domain/sql_backend_handler.rs | 132 ++++++++++++++--------- server/src/domain/sql_opaque_handler.rs | 39 ++++--- server/src/domain/sql_tables.rs | 39 ++++++- server/src/infra/auth_service.rs | 20 ++-- server/src/infra/configuration.rs | 9 +- server/src/infra/graphql/mutation.rs | 15 +-- server/src/infra/graphql/query.rs | 21 ++-- server/src/infra/ldap_handler.rs | 87 ++++++++------- server/src/infra/sql_backend_handler.rs | 12 +-- server/src/infra/tcp_backend_handler.rs | 32 +++--- 13 files changed, 299 insertions(+), 181 deletions(-) diff --git a/server/src/domain/error.rs b/server/src/domain/error.rs index 103bca1..5c9e38d 100644 --- a/server/src/domain/error.rs +++ b/server/src/domain/error.rs @@ -3,7 +3,7 @@ use thiserror::Error; #[allow(clippy::enum_variant_names)] #[derive(Error, Debug)] pub enum DomainError { - #[error("Authentication error for `{0}`")] + #[error("Authentication error: `{0}`")] AuthenticationError(String), #[error("Database error: `{0}`")] DatabaseError(#[from] sqlx::Error), diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 8862dc3..bec8186 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -3,10 +3,41 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::collections::HashSet; +#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)] +#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))] +#[serde(from = "String")] +pub struct UserId(String); + +impl UserId { + pub fn new(user_id: &str) -> Self { + Self(user_id.to_lowercase()) + } + + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + pub fn into_string(self) -> String { + self.0 + } +} + +impl std::fmt::Display for UserId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for UserId { + fn from(s: String) -> Self { + Self::new(&s) + } +} + #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] #[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))] pub struct User { - pub user_id: String, + pub user_id: UserId, pub email: String, pub display_name: String, pub first_name: String, @@ -19,7 +50,7 @@ impl Default for User { fn default() -> Self { use chrono::TimeZone; User { - user_id: String::new(), + user_id: UserId::default(), email: String::new(), display_name: String::new(), first_name: String::new(), @@ -33,12 +64,12 @@ impl Default for User { pub struct Group { pub id: GroupId, pub display_name: String, - pub users: Vec, + pub users: Vec, } #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] pub struct BindRequest { - pub name: String, + pub name: UserId, pub password: String, } @@ -47,6 +78,7 @@ pub enum UserRequestFilter { And(Vec), Or(Vec), Not(Box), + UserId(UserId), Equality(String, String), // Check if a user belongs to a group identified by name. MemberOf(String), @@ -62,13 +94,13 @@ pub enum GroupRequestFilter { DisplayName(String), GroupId(GroupId), // Check if the group contains a user identified by uid. - Member(String), + Member(UserId), } #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] pub struct CreateUserRequest { // Same fields as User, but no creation_date, and with password. - pub user_id: String, + pub user_id: UserId, pub email: String, pub display_name: Option, pub first_name: Option, @@ -78,7 +110,7 @@ pub struct CreateUserRequest { #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] pub struct UpdateUserRequest { // Same fields as CreateUserRequest, but no with an extra layer of Option. - pub user_id: String, + pub user_id: UserId, pub email: Option, pub display_name: Option, pub first_name: Option, @@ -106,17 +138,17 @@ pub struct GroupIdAndName(pub GroupId, pub String); pub trait BackendHandler: Clone + Send { async fn list_users(&self, filters: Option) -> Result>; async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &str) -> Result; + async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &str) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; - async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; - async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; - async fn get_user_groups(&self, user: &str) -> Result>; + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn get_user_groups(&self, user_id: &UserId) -> Result>; } #[cfg(test)] @@ -129,17 +161,17 @@ mockall::mock! { impl BackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option) -> Result>; async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &str) -> Result; + async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &str) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; - async fn get_user_groups(&self, user: &str) -> Result>; - async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; - async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; + async fn get_user_groups(&self, user_id: &UserId) -> Result>; + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; } #[async_trait] impl LoginHandler for TestBackendHandler { diff --git a/server/src/domain/opaque_handler.rs b/server/src/domain/opaque_handler.rs index dcdf083..7e78299 100644 --- a/server/src/domain/opaque_handler.rs +++ b/server/src/domain/opaque_handler.rs @@ -1,4 +1,4 @@ -use super::error::*; +use crate::domain::{error::*, handler::UserId}; use async_trait::async_trait; pub use lldap_auth::{login, registration}; @@ -9,7 +9,7 @@ pub trait OpaqueHandler: Clone + Send { &self, request: login::ClientLoginStartRequest, ) -> Result; - async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result; + async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result; async fn registration_start( &self, request: registration::ClientRegistrationStartRequest, @@ -32,7 +32,7 @@ mockall::mock! { &self, request: login::ClientLoginStartRequest ) -> Result; - async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result; + async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result; async fn registration_start( &self, request: registration::ClientRegistrationStartRequest diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index 2211434..e8ee667 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -51,12 +51,16 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr let (requires_group, filters) = get_user_filter_expr(*f); (requires_group, Expr::not(Expr::expr(filters))) } + UserId(user_id) => ( + RequiresGroup(false), + Expr::col((Users::Table, Users::UserId)).eq(user_id), + ), Equality(s1, s2) => ( RequiresGroup(false), if s1 == Users::DisplayName.to_string() { Expr::col((Users::Table, Users::DisplayName)).eq(s2) } else if s1 == Users::UserId.to_string() { - Expr::col((Users::Table, Users::UserId)).eq(s2) + panic!("User id should be wrapped") } else { Expr::expr(Expr::cust(&s1)).eq(s2) }, @@ -205,17 +209,17 @@ impl BackendHandler for SqlBackendHandler { id: group_id, display_name, users: rows - .map(|row| row.get::(&*Memberships::UserId.to_string())) + .map(|row| row.get::(&*Memberships::UserId.to_string())) // If a group has no users, an empty string is returned because of the left // join. - .filter(|s| !s.is_empty()) + .filter(|s| !s.as_str().is_empty()) .collect(), }); } Ok(groups) } - async fn get_user_details(&self, user_id: &str) -> Result { + async fn get_user_details(&self, user_id: &UserId) -> Result { let query = Query::select() .column(Users::UserId) .column(Users::Email) @@ -246,8 +250,8 @@ impl BackendHandler for SqlBackendHandler { .await?) } - async fn get_user_groups(&self, user: &str) -> Result> { - if user == self.config.ldap_user_dn { + async fn get_user_groups(&self, user_id: &UserId) -> Result> { + if *user_id == self.config.ldap_user_dn { let mut groups = HashSet::new(); groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string())); return Ok(groups); @@ -261,7 +265,7 @@ impl BackendHandler for SqlBackendHandler { Expr::tbl(Groups::Table, Groups::GroupId) .equals(Memberships::Table, Memberships::GroupId), ) - .and_where(Expr::col(Memberships::UserId).eq(user)) + .and_where(Expr::col(Memberships::UserId).eq(user_id)) .to_string(DbQueryBuilder {}); sqlx::query(&query) @@ -294,7 +298,7 @@ impl BackendHandler for SqlBackendHandler { Users::CreationDate, ]; let values = vec![ - request.user_id.clone().into(), + request.user_id.into(), request.email.into(), request.display_name.unwrap_or_default().into(), request.first_name.unwrap_or_default().into(), @@ -353,7 +357,7 @@ impl BackendHandler for SqlBackendHandler { Ok(()) } - async fn delete_user(&self, user_id: &str) -> Result<()> { + async fn delete_user(&self, user_id: &UserId) -> Result<()> { let delete_query = Query::delete() .from_table(Users::Table) .and_where(Expr::col(Users::UserId).eq(user_id)) @@ -387,7 +391,7 @@ impl BackendHandler for SqlBackendHandler { Ok(()) } - async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> { + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { let query = Query::insert() .into_table(Memberships::Table) .columns(vec![Memberships::UserId, Memberships::GroupId]) @@ -397,7 +401,7 @@ impl BackendHandler for SqlBackendHandler { Ok(()) } - async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()> { + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { let query = Query::delete() .from_table(Memberships::Table) .and_where(Expr::col(Memberships::GroupId).eq(group_id)) @@ -463,7 +467,7 @@ mod tests { async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { handler .create_user(CreateUserRequest { - user_id: name.to_string(), + user_id: UserId::new(name), email: "bob@bob.bob".to_string(), ..Default::default() }) @@ -476,21 +480,24 @@ mod tests { } async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) { - handler.add_user_to_group(user_id, group_id).await.unwrap(); + handler + .add_user_to_group(&UserId::new(user_id), group_id) + .await + .unwrap(); } #[tokio::test] async fn test_bind_admin() { let sql_pool = get_in_memory_db().await; let config = ConfigurationBuilder::default() - .ldap_user_dn("admin".to_string()) + .ldap_user_dn(UserId::new("admin")) .ldap_user_pass(secstr::SecUtf8::from("test")) .build() .unwrap(); let handler = SqlBackendHandler::new(config, sql_pool); handler .bind(BindRequest { - name: "admin".to_string(), + name: UserId::new("admin"), password: "test".to_string(), }) .await @@ -506,21 +513,21 @@ mod tests { handler .bind(BindRequest { - name: "bob".to_string(), + name: UserId::new("bob"), password: "bob00".to_string(), }) .await .unwrap(); handler .bind(BindRequest { - name: "andrew".to_string(), + name: UserId::new("andrew"), password: "bob00".to_string(), }) .await .unwrap_err(); handler .bind(BindRequest { - name: "bob".to_string(), + name: UserId::new("bob"), password: "wrong_password".to_string(), }) .await @@ -536,7 +543,7 @@ mod tests { handler .bind(BindRequest { - name: "bob".to_string(), + name: UserId::new("bob"), password: "bob00".to_string(), }) .await @@ -557,47 +564,44 @@ mod tests { .await .unwrap() .into_iter() - .map(|u| u.user_id) + .map(|u| u.user_id.to_string()) .collect::>(); - assert_eq!(users, vec!["John", "bob", "patrick"]); + assert_eq!(users, vec!["bob", "john", "patrick"]); } { let users = handler - .list_users(Some(UserRequestFilter::Equality( - "user_id".to_string(), - "bob".to_string(), - ))) + .list_users(Some(UserRequestFilter::UserId(UserId::new("bob")))) .await .unwrap() .into_iter() - .map(|u| u.user_id) + .map(|u| u.user_id.to_string()) .collect::>(); assert_eq!(users, vec!["bob"]); } { let users = handler .list_users(Some(UserRequestFilter::Or(vec![ - UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()), - UserRequestFilter::Equality("user_id".to_string(), "John".to_string()), + UserRequestFilter::UserId(UserId::new("bob")), + UserRequestFilter::UserId(UserId::new("John")), ]))) .await .unwrap() .into_iter() - .map(|u| u.user_id) + .map(|u| u.user_id.to_string()) .collect::>(); - assert_eq!(users, vec!["John", "bob"]); + assert_eq!(users, vec!["bob", "john"]); } { let users = handler .list_users(Some(UserRequestFilter::Not(Box::new( - UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()), + UserRequestFilter::UserId(UserId::new("bob")), )))) .await .unwrap() .into_iter() - .map(|u| u.user_id) + .map(|u| u.user_id.to_string()) .collect::>(); - assert_eq!(users, vec!["John", "patrick"]); + assert_eq!(users, vec!["john", "patrick"]); } } @@ -622,7 +626,7 @@ mod tests { Group { id: group_1, display_name: "Best Group".to_string(), - users: vec!["bob".to_string(), "patrick".to_string()] + users: vec![UserId::new("bob"), UserId::new("patrick")] }, Group { id: group_3, @@ -632,7 +636,7 @@ mod tests { Group { id: group_2, display_name: "Worst Group".to_string(), - users: vec!["John".to_string(), "patrick".to_string()] + users: vec![UserId::new("john"), UserId::new("patrick")] }, ] ); @@ -640,7 +644,7 @@ mod tests { handler .list_groups(Some(GroupRequestFilter::Or(vec![ GroupRequestFilter::DisplayName("Empty Group".to_string()), - GroupRequestFilter::Member("bob".to_string()), + GroupRequestFilter::Member(UserId::new("bob")), ]))) .await .unwrap(), @@ -648,7 +652,7 @@ mod tests { Group { id: group_1, display_name: "Best Group".to_string(), - users: vec!["bob".to_string(), "patrick".to_string()] + users: vec![UserId::new("bob"), UserId::new("patrick")] }, Group { id: group_3, @@ -670,7 +674,7 @@ mod tests { vec![Group { id: group_1, display_name: "Best Group".to_string(), - users: vec!["bob".to_string(), "patrick".to_string()] + users: vec![UserId::new("bob"), UserId::new("patrick")] }] ); } @@ -682,13 +686,35 @@ mod tests { let handler = SqlBackendHandler::new(config, sql_pool); insert_user(&handler, "bob", "bob00").await; { - let user = handler.get_user_details("bob").await.unwrap(); - assert_eq!(user.user_id, "bob".to_string()); + let user = handler.get_user_details(&UserId::new("bob")).await.unwrap(); + assert_eq!(user.user_id.as_str(), "bob"); } { - handler.get_user_details("John").await.unwrap_err(); + handler + .get_user_details(&UserId::new("John")) + .await + .unwrap_err(); } } + + #[tokio::test] + async fn test_user_lowercase() { + let sql_pool = get_initialized_db().await; + let config = get_default_config(); + let handler = SqlBackendHandler::new(config, sql_pool); + insert_user(&handler, "Bob", "bob00").await; + { + let user = handler.get_user_details(&UserId::new("bOb")).await.unwrap(); + assert_eq!(user.user_id.as_str(), "bob"); + } + { + handler + .get_user_details(&UserId::new("John")) + .await + .unwrap_err(); + } + } + #[tokio::test] async fn test_get_user_groups() { let sql_pool = get_initialized_db().await; @@ -707,13 +733,19 @@ mod tests { let mut patrick_groups = HashSet::new(); patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string())); patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string())); - assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups); assert_eq!( - handler.get_user_groups("patrick").await.unwrap(), + handler.get_user_groups(&UserId::new("bob")).await.unwrap(), + bob_groups + ); + assert_eq!( + handler + .get_user_groups(&UserId::new("patrick")) + .await + .unwrap(), patrick_groups ); assert_eq!( - handler.get_user_groups("John").await.unwrap(), + handler.get_user_groups(&UserId::new("John")).await.unwrap(), HashSet::new() ); } @@ -729,29 +761,29 @@ mod tests { insert_user(&handler, "Jennz", "boupBoup").await; // Remove a user - let _request_result = handler.delete_user("Jennz").await.unwrap(); + let _request_result = handler.delete_user(&UserId::new("Jennz")).await.unwrap(); let users = handler .list_users(None) .await .unwrap() .into_iter() - .map(|u| u.user_id) + .map(|u| u.user_id.to_string()) .collect::>(); - assert_eq!(users, vec!["Hector", "val"]); + assert_eq!(users, vec!["hector", "val"]); // Insert new user and remove two insert_user(&handler, "NewBoi", "Joni").await; - let _request_result = handler.delete_user("Hector").await.unwrap(); - let _request_result = handler.delete_user("NewBoi").await.unwrap(); + let _request_result = handler.delete_user(&UserId::new("Hector")).await.unwrap(); + let _request_result = handler.delete_user(&UserId::new("NewBoi")).await.unwrap(); let users = handler .list_users(None) .await .unwrap() .into_iter() - .map(|u| u.user_id) + .map(|u| u.user_id.to_string()) .collect::>(); assert_eq!(users, vec!["val"]); diff --git a/server/src/domain/sql_opaque_handler.rs b/server/src/domain/sql_opaque_handler.rs index 39c07e4..6208601 100644 --- a/server/src/domain/sql_opaque_handler.rs +++ b/server/src/domain/sql_opaque_handler.rs @@ -1,6 +1,6 @@ use super::{ error::*, - handler::{BindRequest, LoginHandler}, + handler::{BindRequest, LoginHandler, UserId}, opaque_handler::*, sql_backend_handler::SqlBackendHandler, sql_tables::*, @@ -18,7 +18,7 @@ fn passwords_match( password_file_bytes: &[u8], clear_password: &str, server_setup: &opaque::server::ServerSetup, - username: &str, + username: &UserId, ) -> Result<()> { use opaque::{client, server}; let mut rng = rand::rngs::OsRng; @@ -31,7 +31,7 @@ fn passwords_match( server_setup, Some(password_file), client_login_start_result.message, - username, + username.as_str(), )?; client::login::finish_login( client_login_start_result.state, @@ -88,13 +88,16 @@ impl LoginHandler for SqlBackendHandler { return Ok(()); } else { debug!(r#"Invalid password for LDAP bind user"#); - return Err(DomainError::AuthenticationError(request.name)); + return Err(DomainError::AuthenticationError(format!( + " for user '{}'", + request.name + ))); } } let query = Query::select() .column(Users::PasswordHash) .from(Users::Table) - .and_where(Expr::col(Users::UserId).eq(request.name.as_str())) + .and_where(Expr::col(Users::UserId).eq(&request.name)) .to_string(DbQueryBuilder {}); if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { if let Some(password_hash) = @@ -106,17 +109,20 @@ impl LoginHandler for SqlBackendHandler { self.config.get_server_setup(), &request.name, ) { - debug!(r#"Invalid password for "{}": {}"#, request.name, e); + debug!(r#"Invalid password for "{}": {}"#, &request.name, e); } else { return Ok(()); } } else { - debug!(r#"User "{}" has no password"#, request.name); + debug!(r#"User "{}" has no password"#, &request.name); } } else { - debug!(r#"No user found for "{}""#, request.name); + debug!(r#"No user found for "{}""#, &request.name); } - Err(DomainError::AuthenticationError(request.name)) + Err(DomainError::AuthenticationError(format!( + " for user '{}'", + request.name + ))) } } @@ -150,7 +156,7 @@ impl OpaqueHandler for SqlOpaqueHandler { }) } - async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result { + async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result { let secret_key = self.get_orion_secret_key()?; let login::ServerData { username, @@ -165,7 +171,7 @@ impl OpaqueHandler for SqlOpaqueHandler { opaque::server::login::finish_login(server_login, request.credential_finalization)? .session_key; - Ok(username) + Ok(UserId::new(&username)) } async fn registration_start( @@ -220,7 +226,7 @@ impl OpaqueHandler for SqlOpaqueHandler { /// Convenience function to set a user's password. pub(crate) async fn register_password( opaque_handler: &SqlOpaqueHandler, - username: &str, + username: &UserId, password: &SecUtf8, ) -> Result<()> { let mut rng = rand::rngs::OsRng; @@ -278,7 +284,7 @@ mod tests { async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { handler .create_user(CreateUserRequest { - user_id: name.to_string(), + user_id: UserId::new(name), email: "bob@bob.bob".to_string(), ..Default::default() }) @@ -323,7 +329,12 @@ mod tests { attempt_login(&opaque_handler, "bob", "bob00") .await .unwrap_err(); - register_password(&opaque_handler, "bob", &secstr::SecUtf8::from("bob00")).await?; + register_password( + &opaque_handler, + &UserId::new("bob"), + &secstr::SecUtf8::from("bob00"), + ) + .await?; attempt_login(&opaque_handler, "bob", "wrong_password") .await .unwrap_err(); diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs index 67a5115..a48fad0 100644 --- a/server/src/domain/sql_tables.rs +++ b/server/src/domain/sql_tables.rs @@ -1,4 +1,4 @@ -use super::handler::GroupId; +use super::handler::{GroupId, UserId}; use sea_query::*; pub type Pool = sqlx::sqlite::SqlitePool; @@ -37,6 +37,43 @@ where } } +impl sqlx::Type for UserId +where + DB: sqlx::Database, + String: sqlx::Type, +{ + fn type_info() -> ::TypeInfo { + >::type_info() + } + fn compatible(ty: &::TypeInfo) -> bool { + >::compatible(ty) + } +} + +impl<'r, DB> sqlx::Decode<'r, DB> for UserId +where + DB: sqlx::Database, + String: sqlx::Decode<'r, DB>, +{ + fn decode( + value: >::ValueRef, + ) -> Result> { + >::decode(value).map(|s| UserId::new(&s)) + } +} + +impl From for sea_query::Value { + fn from(user_id: UserId) -> Self { + user_id.into_string().into() + } +} + +impl From<&UserId> for sea_query::Value { + fn from(user_id: &UserId) -> Self { + user_id.as_str().into() + } +} + #[derive(Iden)] pub enum Users { Table, diff --git a/server/src/infra/auth_service.rs b/server/src/infra/auth_service.rs index c3cf09f..a1d3fca 100644 --- a/server/src/infra/auth_service.rs +++ b/server/src/infra/auth_service.rs @@ -25,7 +25,7 @@ use lldap_auth::{login, opaque, password_reset, registration, JWTClaims}; use crate::{ domain::{ error::DomainError, - handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler}, + handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler, UserId}, opaque_handler::OpaqueHandler, }, infra::{ @@ -51,7 +51,7 @@ fn create_jwt(key: &Hmac, user: String, groups: HashSet) jwt::Token::new(header, claims).sign_with_key(key).unwrap() } -fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpResponse> { +fn parse_refresh_token(token: &str) -> std::result::Result<(u64, UserId), HttpResponse> { match token.split_once('+') { None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")), Some((token, u)) => { @@ -60,12 +60,12 @@ fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpRe token.hash(&mut s); s.finish() }; - Ok((refresh_token_hash, u.to_string())) + Ok((refresh_token_hash, UserId::new(u))) } } } -fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, String), HttpResponse> { +fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, UserId), HttpResponse> { match ( request.cookie("refresh_token"), request.headers().get("refresh-token"), @@ -134,14 +134,14 @@ where { let user_id = match request.match_info().get("user_id") { None => return HttpResponse::BadRequest().body("Missing user ID"), - Some(id) => id, + Some(id) => UserId::new(id), }; - let token = match data.backend_handler.start_password_reset(user_id).await { + let token = match data.backend_handler.start_password_reset(&user_id).await { Err(e) => return HttpResponse::InternalServerError().body(e.to_string()), Ok(None) => return HttpResponse::Ok().finish(), Ok(Some(token)) => token, }; - let user = match data.backend_handler.get_user_details(user_id).await { + let user = match data.backend_handler.get_user_details(&user_id).await { Err(e) => { warn!("Error getting used details: {:#?}", e); return HttpResponse::Ok().finish(); @@ -196,7 +196,7 @@ where .finish(), ) .json(&password_reset::ServerPasswordResetResponse { - user_id, + user_id: user_id.to_string(), token: token.as_str().to_owned(), }) } @@ -276,7 +276,7 @@ where async fn get_login_successful_response( data: &web::Data>, - name: &str, + name: &UserId, ) -> HttpResponse where Backend: TcpBackendHandler + BackendHandler, @@ -289,7 +289,7 @@ where .await .map(|(groups, (refresh_token, max_age))| { let token = create_jwt(&data.jwt_key, name.to_string(), groups); - let refresh_token_plus_name = refresh_token + "+" + name; + let refresh_token_plus_name = refresh_token + "+" + name.as_str(); HttpResponse::Ok() .cookie( diff --git a/server/src/infra/configuration.rs b/server/src/infra/configuration.rs index 39b7f0c..ad9cc45 100644 --- a/server/src/infra/configuration.rs +++ b/server/src/infra/configuration.rs @@ -1,4 +1,7 @@ -use crate::infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts}; +use crate::{ + domain::handler::UserId, + infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts}, +}; use anyhow::{Context, Result}; use figment::{ providers::{Env, Format, Serialized, Toml}, @@ -49,8 +52,8 @@ pub struct Configuration { pub jwt_secret: SecUtf8, #[builder(default = r#"String::from("dc=example,dc=com")"#)] pub ldap_base_dn: String, - #[builder(default = r#"String::from("admin")"#)] - pub ldap_user_dn: String, + #[builder(default = r#"UserId::new("admin")"#)] + pub ldap_user_dn: UserId, #[builder(default = r#"SecUtf8::from("password")"#)] pub ldap_user_pass: SecUtf8, #[builder(default = r#"String::from("sqlite://users.db?mode=rwc")"#)] diff --git a/server/src/infra/graphql/mutation.rs b/server/src/infra/graphql/mutation.rs index f5089af..e2c1557 100644 --- a/server/src/infra/graphql/mutation.rs +++ b/server/src/infra/graphql/mutation.rs @@ -1,5 +1,5 @@ use crate::domain::handler::{ - BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, + BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, UserId, }; use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject}; @@ -66,10 +66,11 @@ impl Mutation { if !context.validation_result.is_admin { return Err("Unauthorized user creation".into()); } + let user_id = UserId::new(&user.id); context .handler .create_user(CreateUserRequest { - user_id: user.id.clone(), + user_id: user_id.clone(), email: user.email, display_name: user.display_name, first_name: user.first_name, @@ -78,7 +79,7 @@ impl Mutation { .await?; Ok(context .handler - .get_user_details(&user.id) + .get_user_details(&user_id) .await .map(Into::into)?) } @@ -108,7 +109,7 @@ impl Mutation { context .handler .update_user(UpdateUserRequest { - user_id: user.id, + user_id: UserId::new(&user.id), email: user.email, display_name: user.display_name, first_name: user.first_name, @@ -148,7 +149,7 @@ impl Mutation { } context .handler - .add_user_to_group(&user_id, GroupId(group_id)) + .add_user_to_group(&UserId::new(&user_id), GroupId(group_id)) .await?; Ok(Success::new()) } @@ -166,7 +167,7 @@ impl Mutation { } context .handler - .remove_user_from_group(&user_id, GroupId(group_id)) + .remove_user_from_group(&UserId::new(&user_id), GroupId(group_id)) .await?; Ok(Success::new()) } @@ -178,7 +179,7 @@ impl Mutation { if context.validation_result.user == user_id { return Err("Cannot delete current user".into()); } - context.handler.delete_user(&user_id).await?; + context.handler.delete_user(&UserId::new(&user_id)).await?; Ok(Success::new()) } diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index 3804127..8c592cb 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,4 +1,4 @@ -use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName}; +use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName, UserId}; use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use serde::{Deserialize, Serialize}; @@ -48,6 +48,9 @@ impl TryInto for RequestFilter { return Err("Multiple fields specified in request filter".to_string()); } if let Some(e) = self.eq { + if e.field.to_lowercase() == "uid" { + return Ok(DomainRequestFilter::UserId(UserId::new(&e.value))); + } return Ok(DomainRequestFilter::Equality(e.field, e.value)); } if let Some(c) = self.any { @@ -109,7 +112,7 @@ impl Query { } Ok(context .handler - .get_user_details(&user_id) + .get_user_details(&UserId::new(&user_id)) .await .map(Into::into)?) } @@ -170,7 +173,7 @@ impl Default for User { #[graphql_object(context = Context)] impl User { fn id(&self) -> &str { - &self.user.user_id + self.user.user_id.as_str() } fn email(&self) -> &str { @@ -260,7 +263,7 @@ impl From for Group { Self { group_id: group.id.0, display_name: group.display_name, - members: Some(group.users.into_iter().map(Into::into).collect()), + members: Some(group.users.into_iter().map(UserId::into_string).collect()), _phantom: std::marker::PhantomData, } } @@ -305,10 +308,10 @@ mod tests { let mut mock = MockTestBackendHandler::new(); mock.expect_get_user_details() - .with(eq("bob")) + .with(eq(UserId::new("bob"))) .return_once(|_| { Ok(DomainUser { - user_id: "bob".to_string(), + user_id: UserId::new("bob"), email: "bob@bobbers.on".to_string(), ..Default::default() }) @@ -316,7 +319,7 @@ mod tests { let mut groups = HashSet::new(); groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string())); mock.expect_get_user_groups() - .with(eq("bob")) + .with(eq(UserId::new("bob"))) .return_once(|_| Ok(groups)); let context = Context:: { @@ -369,12 +372,12 @@ mod tests { .return_once(|_| { Ok(vec![ DomainUser { - user_id: "bob".to_string(), + user_id: UserId::new("bob"), email: "bob@bobbers.on".to_string(), ..Default::default() }, DomainUser { - user_id: "robert".to_string(), + user_id: UserId::new("robert"), email: "robert@bobbers.on".to_string(), ..Default::default() }, diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index 9e47659..9c5729f 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -1,6 +1,6 @@ use crate::domain::{ handler::{ - BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, + BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, UserId, UserRequestFilter, }, opaque_handler::OpaqueHandler, @@ -71,7 +71,7 @@ fn get_user_id_from_distinguished_name( dn: &str, base_tree: &[(String, String)], base_dn_str: &str, -) -> Result { +) -> Result { let parts = parse_distinguished_name(dn).context("while parsing a user ID")?; if !is_subtree(&parts, base_tree) { bail!("Not a subtree of the base tree"); @@ -84,7 +84,7 @@ fn get_user_id_from_distinguished_name( base_dn_str ); } - Ok(parts[0].1.to_string()) + Ok(UserId::new(&parts[0].1)) } else { bail!( r#"Unexpected user DN format. Got "{}", expected: "cn=username,ou=people,{}""#, @@ -103,7 +103,7 @@ fn get_user_attribute(user: &User, attribute: &str, dn: &str) -> Result Ok(vec![dn.to_string()]), - "uid" => Ok(vec![user.user_id.clone()]), + "uid" => Ok(vec![user.user_id.to_string()]), "mail" => Ok(vec![user.email.clone()]), "givenname" => Ok(vec![user.first_name.clone()]), "sn" => Ok(vec![user.last_name.clone()]), @@ -118,7 +118,7 @@ fn make_ldap_search_user_result_entry( base_dn_str: &str, attributes: &[String], ) -> Result { - let dn = format!("cn={},ou=people,{}", user.user_id, base_dn_str); + let dn = format!("cn={},ou=people,{}", user.user_id.as_str(), base_dn_str); Ok(LdapSearchResultEntry { dn: dn.clone(), attributes: attributes @@ -264,17 +264,17 @@ fn root_dse_response(base_dn: &str) -> LdapOp { } pub struct LdapHandler { - dn: String, + dn: UserId, backend_handler: Backend, pub base_dn: Vec<(String, String)>, base_dn_str: String, - ldap_user_dn: String, + ldap_user_dn: UserId, } impl LdapHandler { - pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self { + pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: UserId) -> Self { Self { - dn: "Unauthenticated".to_string(), + dn: UserId::new("unauthenticated"), backend_handler, base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| { panic!( @@ -282,7 +282,7 @@ impl LdapHandler LdapHandler { - self.dn = request.dn.clone(); + self.dn = UserId::new(&request.dn); (LdapResultCode::Success, "".to_string()) } Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()), } } - async fn change_password(&mut self, user: &str, password: &str) -> Result<()> { + async fn change_password(&mut self, user: &UserId, password: &str) -> Result<()> { use lldap_auth::*; let mut rng = rand::rngs::OsRng; let registration_start_request = @@ -527,7 +527,7 @@ impl LdapHandler self.do_search(&request).await, LdapOp::UnbindRequest => { - self.dn = "Unauthenticated".to_string(); + self.dn = UserId::new("unauthenticated"); // No need to notify on unbind (per rfc4511) return None; } @@ -617,10 +617,12 @@ impl LdapHandler { @@ -661,17 +663,17 @@ mod tests { impl BackendHandler for TestBackendHandler { async fn list_users(&self, filters: Option) -> Result>; async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &str) -> Result; + async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; - async fn get_user_groups(&self, user: &str) -> Result>; + async fn get_user_groups(&self, user: &UserId) -> Result>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &str) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; - async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; - async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; } #[async_trait] impl OpaqueHandler for TestBackendHandler { @@ -679,7 +681,7 @@ mod tests { &self, request: login::ClientLoginStartRequest ) -> Result; - async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result; + async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result; async fn registration_start( &self, request: registration::ClientRegistrationStartRequest @@ -720,12 +722,12 @@ mod tests { ) -> LdapHandler { mock.expect_bind() .with(eq(BindRequest { - name: "test".to_string(), + name: UserId::new("test"), password: "pass".to_string(), })) .return_once(|_| Ok(())); let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); + LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test")); let request = LdapBindRequest { dn: "cn=test,ou=people,dc=example,dc=com".to_string(), cred: LdapBindCred::Simple("pass".to_string()), @@ -742,13 +744,13 @@ mod tests { let mut mock = MockTestBackendHandler::new(); mock.expect_bind() .with(eq(crate::domain::handler::BindRequest { - name: "bob".to_string(), + name: UserId::new("bob"), password: "pass".to_string(), })) .times(1) .return_once(|_| Ok(())); let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); + LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test")); let request = LdapOp::BindRequest(LdapBindRequest { dn: "cn=bob,ou=people,dc=example,dc=com".to_string(), @@ -773,13 +775,13 @@ mod tests { let mut mock = MockTestBackendHandler::new(); mock.expect_bind() .with(eq(crate::domain::handler::BindRequest { - name: "test".to_string(), + name: UserId::new("test"), password: "pass".to_string(), })) .times(1) .return_once(|_| Ok(())); let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); + LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test")); let request = LdapBindRequest { dn: "cn=test,ou=people,dc=example,dc=com".to_string(), @@ -796,13 +798,13 @@ mod tests { let mut mock = MockTestBackendHandler::new(); mock.expect_bind() .with(eq(crate::domain::handler::BindRequest { - name: "test".to_string(), + name: UserId::new("test"), password: "pass".to_string(), })) .times(1) .return_once(|_| Ok(())); let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string()); + LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin")); let request = LdapBindRequest { dn: "cn=test,ou=people,dc=example,dc=com".to_string(), @@ -827,7 +829,7 @@ mod tests { async fn test_bind_invalid_dn() { let mock = MockTestBackendHandler::new(); let mut ldap_handler = - LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string()); + LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin")); let request = LdapBindRequest { dn: "cn=bob,dc=example,dc=com".to_string(), @@ -903,7 +905,7 @@ mod tests { mock.expect_list_users().times(1).return_once(|_| { Ok(vec![ User { - user_id: "bob_1".to_string(), + user_id: UserId::new("bob_1"), email: "bob@bobmail.bob".to_string(), display_name: "Bôb Böbberson".to_string(), first_name: "Bôb".to_string(), @@ -911,7 +913,7 @@ mod tests { ..Default::default() }, User { - user_id: "jim".to_string(), + user_id: UserId::new("jim"), email: "jim@cricket.jim".to_string(), display_name: "Jimminy Cricket".to_string(), first_name: "Jim".to_string(), @@ -1037,12 +1039,12 @@ mod tests { Group { id: GroupId(1), display_name: "group_1".to_string(), - users: vec!["bob".to_string(), "john".to_string()], + users: vec![UserId::new("bob"), UserId::new("john")], }, Group { id: GroupId(3), display_name: "bestgroup".to_string(), - users: vec!["john".to_string()], + users: vec![UserId::new("john")], }, ]) }); @@ -1111,7 +1113,7 @@ mod tests { mock.expect_list_groups() .with(eq(Some(GroupRequestFilter::And(vec![ GroupRequestFilter::DisplayName("group_1".to_string()), - GroupRequestFilter::Member("bob".to_string()), + GroupRequestFilter::Member(UserId::new("bob")), GroupRequestFilter::And(vec![]), ])))) .times(1) @@ -1250,10 +1252,7 @@ mod tests { mock.expect_list_users() .with(eq(Some(UserRequestFilter::And(vec![ UserRequestFilter::Or(vec![ - UserRequestFilter::Not(Box::new(UserRequestFilter::Equality( - "user_id".to_string(), - "bob".to_string(), - ))), + UserRequestFilter::Not(Box::new(UserRequestFilter::UserId(UserId::new("bob")))), UserRequestFilter::And(vec![]), UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))), UserRequestFilter::And(vec![]), @@ -1342,7 +1341,7 @@ mod tests { .times(1) .return_once(|_| { Ok(vec![User { - user_id: "bob_1".to_string(), + user_id: UserId::new("bob_1"), ..Default::default() }]) }); @@ -1378,7 +1377,7 @@ mod tests { let mut mock = MockTestBackendHandler::new(); mock.expect_list_users().times(1).return_once(|_| { Ok(vec![User { - user_id: "bob_1".to_string(), + user_id: UserId::new("bob_1"), email: "bob@bobmail.bob".to_string(), display_name: "Bôb Böbberson".to_string(), first_name: "Bôb".to_string(), @@ -1393,7 +1392,7 @@ mod tests { Ok(vec![Group { id: GroupId(1), display_name: "group_1".to_string(), - users: vec!["bob".to_string(), "john".to_string()], + users: vec![UserId::new("bob"), UserId::new("john")], }]) }); let mut ldap_handler = setup_bound_handler(mock).await; diff --git a/server/src/infra/sql_backend_handler.rs b/server/src/infra/sql_backend_handler.rs index 91f8fb1..3a76467 100644 --- a/server/src/infra/sql_backend_handler.rs +++ b/server/src/infra/sql_backend_handler.rs @@ -1,5 +1,5 @@ use super::{jwt_sql_tables::*, tcp_backend_handler::*}; -use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler}; +use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler}; use async_trait::async_trait; use futures_util::StreamExt; use sea_query::{Expr, Iden, Query, SimpleExpr}; @@ -34,7 +34,7 @@ impl TcpBackendHandler for SqlBackendHandler { .map_err(|e| anyhow::anyhow!(e)) } - async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { + async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; // TODO: Initialize the rng only once. Maybe Arc? @@ -62,7 +62,7 @@ impl TcpBackendHandler for SqlBackendHandler { Ok((refresh_token, duration)) } - async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result { + async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result { let query = Query::select() .expr(SimpleExpr::Value(1.into())) .from(JwtRefreshStorage::Table) @@ -74,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler { .await? .is_some()) } - async fn blacklist_jwts(&self, user: &str) -> Result> { + async fn blacklist_jwts(&self, user: &UserId) -> Result> { use sqlx::Result; let query = Query::select() .column(JwtStorage::JwtHash) @@ -106,7 +106,7 @@ impl TcpBackendHandler for SqlBackendHandler { Ok(()) } - async fn start_password_reset(&self, user: &str) -> Result> { + async fn start_password_reset(&self, user: &UserId) -> Result> { let query = Query::select() .column(Users::UserId) .from(Users::Table) @@ -138,7 +138,7 @@ impl TcpBackendHandler for SqlBackendHandler { Ok(Some(token)) } - async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result { + async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result { let query = Query::select() .column(PasswordResetTokens::UserId) .from(PasswordResetTokens::Table) diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index 79fa102..0dee16c 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -1,22 +1,22 @@ use async_trait::async_trait; use std::collections::HashSet; -use crate::domain::error::Result; +use crate::domain::{error::Result, handler::UserId}; #[async_trait] pub trait TcpBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result>; - async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>; - async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result; - async fn blacklist_jwts(&self, user: &str) -> Result>; + async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>; + async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result; + async fn blacklist_jwts(&self, user: &UserId) -> Result>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; /// Request a token to reset a user's password. /// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`. - async fn start_password_reset(&self, user: &str) -> Result>; + async fn start_password_reset(&self, user: &UserId) -> Result>; /// Get the user ID associated with a password reset token. - async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; + async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; async fn delete_password_reset_token(&self, token: &str) -> Result<()>; } @@ -37,27 +37,27 @@ mockall::mock! { impl BackendHandler for TestTcpBackendHandler { async fn list_users(&self, filters: Option) -> Result>; async fn list_groups(&self, filters: Option) -> Result>; - async fn get_user_details(&self, user_id: &str) -> Result; + async fn get_user_details(&self, user_id: &UserId) -> Result; async fn get_group_details(&self, group_id: GroupId) -> Result; - async fn get_user_groups(&self, user: &str) -> Result>; + async fn get_user_groups(&self, user: &UserId) -> Result>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; - async fn delete_user(&self, user_id: &str) -> Result<()>; + async fn delete_user(&self, user_id: &UserId) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; async fn delete_group(&self, group_id: GroupId) -> Result<()>; - async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; - async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; + async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; + async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>; } #[async_trait] impl TcpBackendHandler for TestTcpBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result>; - async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>; - async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result; - async fn blacklist_jwts(&self, user: &str) -> Result>; + async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>; + async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result; + async fn blacklist_jwts(&self, user: &UserId) -> Result>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; - async fn start_password_reset(&self, user: &str) -> Result>; - async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; + async fn start_password_reset(&self, user: &UserId) -> Result>; + async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; async fn delete_password_reset_token(&self, token: &str) -> Result<()>; } }