domain: introduce UserId to make uid case insensitive

Note that if there was a non-lowercase user already in the DB, it cannot
be found again. To fix this, run in the DB:

sqlite> UPDATE users SET user_id = LOWER(user_id);
This commit is contained in:
Valentin Tolmer 2022-03-26 18:00:37 +01:00 committed by nitnelave
parent 26cedcb621
commit ca19e61f50
13 changed files with 299 additions and 181 deletions

View File

@ -3,7 +3,7 @@ use thiserror::Error;
#[allow(clippy::enum_variant_names)] #[allow(clippy::enum_variant_names)]
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum DomainError { pub enum DomainError {
#[error("Authentication error for `{0}`")] #[error("Authentication error: `{0}`")]
AuthenticationError(String), AuthenticationError(String),
#[error("Database error: `{0}`")] #[error("Database error: `{0}`")]
DatabaseError(#[from] sqlx::Error), DatabaseError(#[from] sqlx::Error),

View File

@ -3,10 +3,41 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet; use std::collections::HashSet;
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
#[serde(from = "String")]
pub struct UserId(String);
impl UserId {
pub fn new(user_id: &str) -> Self {
Self(user_id.to_lowercase())
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0
}
}
impl std::fmt::Display for UserId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<String> for UserId {
fn from(s: String) -> Self {
Self::new(&s)
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))] #[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
pub struct User { pub struct User {
pub user_id: String, pub user_id: UserId,
pub email: String, pub email: String,
pub display_name: String, pub display_name: String,
pub first_name: String, pub first_name: String,
@ -19,7 +50,7 @@ impl Default for User {
fn default() -> Self { fn default() -> Self {
use chrono::TimeZone; use chrono::TimeZone;
User { User {
user_id: String::new(), user_id: UserId::default(),
email: String::new(), email: String::new(),
display_name: String::new(), display_name: String::new(),
first_name: String::new(), first_name: String::new(),
@ -33,12 +64,12 @@ impl Default for User {
pub struct Group { pub struct Group {
pub id: GroupId, pub id: GroupId,
pub display_name: String, pub display_name: String,
pub users: Vec<String>, pub users: Vec<UserId>,
} }
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub struct BindRequest { pub struct BindRequest {
pub name: String, pub name: UserId,
pub password: String, pub password: String,
} }
@ -47,6 +78,7 @@ pub enum UserRequestFilter {
And(Vec<UserRequestFilter>), And(Vec<UserRequestFilter>),
Or(Vec<UserRequestFilter>), Or(Vec<UserRequestFilter>),
Not(Box<UserRequestFilter>), Not(Box<UserRequestFilter>),
UserId(UserId),
Equality(String, String), Equality(String, String),
// Check if a user belongs to a group identified by name. // Check if a user belongs to a group identified by name.
MemberOf(String), MemberOf(String),
@ -62,13 +94,13 @@ pub enum GroupRequestFilter {
DisplayName(String), DisplayName(String),
GroupId(GroupId), GroupId(GroupId),
// Check if the group contains a user identified by uid. // Check if the group contains a user identified by uid.
Member(String), Member(UserId),
} }
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
pub struct CreateUserRequest { pub struct CreateUserRequest {
// Same fields as User, but no creation_date, and with password. // Same fields as User, but no creation_date, and with password.
pub user_id: String, pub user_id: UserId,
pub email: String, pub email: String,
pub display_name: Option<String>, pub display_name: Option<String>,
pub first_name: Option<String>, pub first_name: Option<String>,
@ -78,7 +110,7 @@ pub struct CreateUserRequest {
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] #[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
pub struct UpdateUserRequest { pub struct UpdateUserRequest {
// Same fields as CreateUserRequest, but no with an extra layer of Option. // Same fields as CreateUserRequest, but no with an extra layer of Option.
pub user_id: String, pub user_id: UserId,
pub email: Option<String>, pub email: Option<String>,
pub display_name: Option<String>, pub display_name: Option<String>,
pub first_name: Option<String>, pub first_name: Option<String>,
@ -106,17 +138,17 @@ pub struct GroupIdAndName(pub GroupId, pub String);
pub trait BackendHandler: Clone + Send { pub trait BackendHandler: Clone + Send {
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>; async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>; async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>; async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>; async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>; async fn delete_user(&self, user_id: &UserId) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>; async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn delete_group(&self, group_id: GroupId) -> Result<()>; async fn delete_group(&self, group_id: GroupId) -> Result<()>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>; async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>>;
} }
#[cfg(test)] #[cfg(test)]
@ -129,17 +161,17 @@ mockall::mock! {
impl BackendHandler for TestBackendHandler { impl BackendHandler for TestBackendHandler {
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>; async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>; async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>; async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>; async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>; async fn delete_user(&self, user_id: &UserId) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>; async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn delete_group(&self, group_id: GroupId) -> Result<()>; async fn delete_group(&self, group_id: GroupId) -> Result<()>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>; async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
} }
#[async_trait] #[async_trait]
impl LoginHandler for TestBackendHandler { impl LoginHandler for TestBackendHandler {

View File

@ -1,4 +1,4 @@
use super::error::*; use crate::domain::{error::*, handler::UserId};
use async_trait::async_trait; use async_trait::async_trait;
pub use lldap_auth::{login, registration}; pub use lldap_auth::{login, registration};
@ -9,7 +9,7 @@ pub trait OpaqueHandler: Clone + Send {
&self, &self,
request: login::ClientLoginStartRequest, request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse>; ) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>; async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId>;
async fn registration_start( async fn registration_start(
&self, &self,
request: registration::ClientRegistrationStartRequest, request: registration::ClientRegistrationStartRequest,
@ -32,7 +32,7 @@ mockall::mock! {
&self, &self,
request: login::ClientLoginStartRequest request: login::ClientLoginStartRequest
) -> Result<login::ServerLoginStartResponse>; ) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>; async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<UserId>;
async fn registration_start( async fn registration_start(
&self, &self,
request: registration::ClientRegistrationStartRequest request: registration::ClientRegistrationStartRequest

View File

@ -51,12 +51,16 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr
let (requires_group, filters) = get_user_filter_expr(*f); let (requires_group, filters) = get_user_filter_expr(*f);
(requires_group, Expr::not(Expr::expr(filters))) (requires_group, Expr::not(Expr::expr(filters)))
} }
UserId(user_id) => (
RequiresGroup(false),
Expr::col((Users::Table, Users::UserId)).eq(user_id),
),
Equality(s1, s2) => ( Equality(s1, s2) => (
RequiresGroup(false), RequiresGroup(false),
if s1 == Users::DisplayName.to_string() { if s1 == Users::DisplayName.to_string() {
Expr::col((Users::Table, Users::DisplayName)).eq(s2) Expr::col((Users::Table, Users::DisplayName)).eq(s2)
} else if s1 == Users::UserId.to_string() { } else if s1 == Users::UserId.to_string() {
Expr::col((Users::Table, Users::UserId)).eq(s2) panic!("User id should be wrapped")
} else { } else {
Expr::expr(Expr::cust(&s1)).eq(s2) Expr::expr(Expr::cust(&s1)).eq(s2)
}, },
@ -205,17 +209,17 @@ impl BackendHandler for SqlBackendHandler {
id: group_id, id: group_id,
display_name, display_name,
users: rows users: rows
.map(|row| row.get::<String, _>(&*Memberships::UserId.to_string())) .map(|row| row.get::<UserId, _>(&*Memberships::UserId.to_string()))
// If a group has no users, an empty string is returned because of the left // If a group has no users, an empty string is returned because of the left
// join. // join.
.filter(|s| !s.is_empty()) .filter(|s| !s.as_str().is_empty())
.collect(), .collect(),
}); });
} }
Ok(groups) Ok(groups)
} }
async fn get_user_details(&self, user_id: &str) -> Result<User> { async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
let query = Query::select() let query = Query::select()
.column(Users::UserId) .column(Users::UserId)
.column(Users::Email) .column(Users::Email)
@ -246,8 +250,8 @@ impl BackendHandler for SqlBackendHandler {
.await?) .await?)
} }
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>> { async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
if user == self.config.ldap_user_dn { if *user_id == self.config.ldap_user_dn {
let mut groups = HashSet::new(); let mut groups = HashSet::new();
groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string())); groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
return Ok(groups); return Ok(groups);
@ -261,7 +265,7 @@ impl BackendHandler for SqlBackendHandler {
Expr::tbl(Groups::Table, Groups::GroupId) Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId), .equals(Memberships::Table, Memberships::GroupId),
) )
.and_where(Expr::col(Memberships::UserId).eq(user)) .and_where(Expr::col(Memberships::UserId).eq(user_id))
.to_string(DbQueryBuilder {}); .to_string(DbQueryBuilder {});
sqlx::query(&query) sqlx::query(&query)
@ -294,7 +298,7 @@ impl BackendHandler for SqlBackendHandler {
Users::CreationDate, Users::CreationDate,
]; ];
let values = vec![ let values = vec![
request.user_id.clone().into(), request.user_id.into(),
request.email.into(), request.email.into(),
request.display_name.unwrap_or_default().into(), request.display_name.unwrap_or_default().into(),
request.first_name.unwrap_or_default().into(), request.first_name.unwrap_or_default().into(),
@ -353,7 +357,7 @@ impl BackendHandler for SqlBackendHandler {
Ok(()) Ok(())
} }
async fn delete_user(&self, user_id: &str) -> Result<()> { async fn delete_user(&self, user_id: &UserId) -> Result<()> {
let delete_query = Query::delete() let delete_query = Query::delete()
.from_table(Users::Table) .from_table(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id)) .and_where(Expr::col(Users::UserId).eq(user_id))
@ -387,7 +391,7 @@ impl BackendHandler for SqlBackendHandler {
Ok(()) Ok(())
} }
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> { async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
let query = Query::insert() let query = Query::insert()
.into_table(Memberships::Table) .into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId]) .columns(vec![Memberships::UserId, Memberships::GroupId])
@ -397,7 +401,7 @@ impl BackendHandler for SqlBackendHandler {
Ok(()) Ok(())
} }
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()> { async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
let query = Query::delete() let query = Query::delete()
.from_table(Memberships::Table) .from_table(Memberships::Table)
.and_where(Expr::col(Memberships::GroupId).eq(group_id)) .and_where(Expr::col(Memberships::GroupId).eq(group_id))
@ -463,7 +467,7 @@ mod tests {
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler handler
.create_user(CreateUserRequest { .create_user(CreateUserRequest {
user_id: name.to_string(), user_id: UserId::new(name),
email: "bob@bob.bob".to_string(), email: "bob@bob.bob".to_string(),
..Default::default() ..Default::default()
}) })
@ -476,21 +480,24 @@ mod tests {
} }
async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) { async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
handler.add_user_to_group(user_id, group_id).await.unwrap(); handler
.add_user_to_group(&UserId::new(user_id), group_id)
.await
.unwrap();
} }
#[tokio::test] #[tokio::test]
async fn test_bind_admin() { async fn test_bind_admin() {
let sql_pool = get_in_memory_db().await; let sql_pool = get_in_memory_db().await;
let config = ConfigurationBuilder::default() let config = ConfigurationBuilder::default()
.ldap_user_dn("admin".to_string()) .ldap_user_dn(UserId::new("admin"))
.ldap_user_pass(secstr::SecUtf8::from("test")) .ldap_user_pass(secstr::SecUtf8::from("test"))
.build() .build()
.unwrap(); .unwrap();
let handler = SqlBackendHandler::new(config, sql_pool); let handler = SqlBackendHandler::new(config, sql_pool);
handler handler
.bind(BindRequest { .bind(BindRequest {
name: "admin".to_string(), name: UserId::new("admin"),
password: "test".to_string(), password: "test".to_string(),
}) })
.await .await
@ -506,21 +513,21 @@ mod tests {
handler handler
.bind(BindRequest { .bind(BindRequest {
name: "bob".to_string(), name: UserId::new("bob"),
password: "bob00".to_string(), password: "bob00".to_string(),
}) })
.await .await
.unwrap(); .unwrap();
handler handler
.bind(BindRequest { .bind(BindRequest {
name: "andrew".to_string(), name: UserId::new("andrew"),
password: "bob00".to_string(), password: "bob00".to_string(),
}) })
.await .await
.unwrap_err(); .unwrap_err();
handler handler
.bind(BindRequest { .bind(BindRequest {
name: "bob".to_string(), name: UserId::new("bob"),
password: "wrong_password".to_string(), password: "wrong_password".to_string(),
}) })
.await .await
@ -536,7 +543,7 @@ mod tests {
handler handler
.bind(BindRequest { .bind(BindRequest {
name: "bob".to_string(), name: UserId::new("bob"),
password: "bob00".to_string(), password: "bob00".to_string(),
}) })
.await .await
@ -557,47 +564,44 @@ mod tests {
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|u| u.user_id) .map(|u| u.user_id.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(users, vec!["John", "bob", "patrick"]); assert_eq!(users, vec!["bob", "john", "patrick"]);
} }
{ {
let users = handler let users = handler
.list_users(Some(UserRequestFilter::Equality( .list_users(Some(UserRequestFilter::UserId(UserId::new("bob"))))
"user_id".to_string(),
"bob".to_string(),
)))
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|u| u.user_id) .map(|u| u.user_id.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(users, vec!["bob"]); assert_eq!(users, vec!["bob"]);
} }
{ {
let users = handler let users = handler
.list_users(Some(UserRequestFilter::Or(vec![ .list_users(Some(UserRequestFilter::Or(vec![
UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()), UserRequestFilter::UserId(UserId::new("bob")),
UserRequestFilter::Equality("user_id".to_string(), "John".to_string()), UserRequestFilter::UserId(UserId::new("John")),
]))) ])))
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|u| u.user_id) .map(|u| u.user_id.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(users, vec!["John", "bob"]); assert_eq!(users, vec!["bob", "john"]);
} }
{ {
let users = handler let users = handler
.list_users(Some(UserRequestFilter::Not(Box::new( .list_users(Some(UserRequestFilter::Not(Box::new(
UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()), UserRequestFilter::UserId(UserId::new("bob")),
)))) ))))
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|u| u.user_id) .map(|u| u.user_id.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(users, vec!["John", "patrick"]); assert_eq!(users, vec!["john", "patrick"]);
} }
} }
@ -622,7 +626,7 @@ mod tests {
Group { Group {
id: group_1, id: group_1,
display_name: "Best Group".to_string(), display_name: "Best Group".to_string(),
users: vec!["bob".to_string(), "patrick".to_string()] users: vec![UserId::new("bob"), UserId::new("patrick")]
}, },
Group { Group {
id: group_3, id: group_3,
@ -632,7 +636,7 @@ mod tests {
Group { Group {
id: group_2, id: group_2,
display_name: "Worst Group".to_string(), display_name: "Worst Group".to_string(),
users: vec!["John".to_string(), "patrick".to_string()] users: vec![UserId::new("john"), UserId::new("patrick")]
}, },
] ]
); );
@ -640,7 +644,7 @@ mod tests {
handler handler
.list_groups(Some(GroupRequestFilter::Or(vec![ .list_groups(Some(GroupRequestFilter::Or(vec![
GroupRequestFilter::DisplayName("Empty Group".to_string()), GroupRequestFilter::DisplayName("Empty Group".to_string()),
GroupRequestFilter::Member("bob".to_string()), GroupRequestFilter::Member(UserId::new("bob")),
]))) ])))
.await .await
.unwrap(), .unwrap(),
@ -648,7 +652,7 @@ mod tests {
Group { Group {
id: group_1, id: group_1,
display_name: "Best Group".to_string(), display_name: "Best Group".to_string(),
users: vec!["bob".to_string(), "patrick".to_string()] users: vec![UserId::new("bob"), UserId::new("patrick")]
}, },
Group { Group {
id: group_3, id: group_3,
@ -670,7 +674,7 @@ mod tests {
vec![Group { vec![Group {
id: group_1, id: group_1,
display_name: "Best Group".to_string(), display_name: "Best Group".to_string(),
users: vec!["bob".to_string(), "patrick".to_string()] users: vec![UserId::new("bob"), UserId::new("patrick")]
}] }]
); );
} }
@ -682,13 +686,35 @@ mod tests {
let handler = SqlBackendHandler::new(config, sql_pool); let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "bob", "bob00").await; insert_user(&handler, "bob", "bob00").await;
{ {
let user = handler.get_user_details("bob").await.unwrap(); let user = handler.get_user_details(&UserId::new("bob")).await.unwrap();
assert_eq!(user.user_id, "bob".to_string()); assert_eq!(user.user_id.as_str(), "bob");
} }
{ {
handler.get_user_details("John").await.unwrap_err(); handler
.get_user_details(&UserId::new("John"))
.await
.unwrap_err();
} }
} }
#[tokio::test]
async fn test_user_lowercase() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "Bob", "bob00").await;
{
let user = handler.get_user_details(&UserId::new("bOb")).await.unwrap();
assert_eq!(user.user_id.as_str(), "bob");
}
{
handler
.get_user_details(&UserId::new("John"))
.await
.unwrap_err();
}
}
#[tokio::test] #[tokio::test]
async fn test_get_user_groups() { async fn test_get_user_groups() {
let sql_pool = get_initialized_db().await; let sql_pool = get_initialized_db().await;
@ -707,13 +733,19 @@ mod tests {
let mut patrick_groups = HashSet::new(); let mut patrick_groups = HashSet::new();
patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string())); patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string())); patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string()));
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
assert_eq!( assert_eq!(
handler.get_user_groups("patrick").await.unwrap(), handler.get_user_groups(&UserId::new("bob")).await.unwrap(),
bob_groups
);
assert_eq!(
handler
.get_user_groups(&UserId::new("patrick"))
.await
.unwrap(),
patrick_groups patrick_groups
); );
assert_eq!( assert_eq!(
handler.get_user_groups("John").await.unwrap(), handler.get_user_groups(&UserId::new("John")).await.unwrap(),
HashSet::new() HashSet::new()
); );
} }
@ -729,29 +761,29 @@ mod tests {
insert_user(&handler, "Jennz", "boupBoup").await; insert_user(&handler, "Jennz", "boupBoup").await;
// Remove a user // Remove a user
let _request_result = handler.delete_user("Jennz").await.unwrap(); let _request_result = handler.delete_user(&UserId::new("Jennz")).await.unwrap();
let users = handler let users = handler
.list_users(None) .list_users(None)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|u| u.user_id) .map(|u| u.user_id.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(users, vec!["Hector", "val"]); assert_eq!(users, vec!["hector", "val"]);
// Insert new user and remove two // Insert new user and remove two
insert_user(&handler, "NewBoi", "Joni").await; insert_user(&handler, "NewBoi", "Joni").await;
let _request_result = handler.delete_user("Hector").await.unwrap(); let _request_result = handler.delete_user(&UserId::new("Hector")).await.unwrap();
let _request_result = handler.delete_user("NewBoi").await.unwrap(); let _request_result = handler.delete_user(&UserId::new("NewBoi")).await.unwrap();
let users = handler let users = handler
.list_users(None) .list_users(None)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|u| u.user_id) .map(|u| u.user_id.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(users, vec!["val"]); assert_eq!(users, vec!["val"]);

View File

@ -1,6 +1,6 @@
use super::{ use super::{
error::*, error::*,
handler::{BindRequest, LoginHandler}, handler::{BindRequest, LoginHandler, UserId},
opaque_handler::*, opaque_handler::*,
sql_backend_handler::SqlBackendHandler, sql_backend_handler::SqlBackendHandler,
sql_tables::*, sql_tables::*,
@ -18,7 +18,7 @@ fn passwords_match(
password_file_bytes: &[u8], password_file_bytes: &[u8],
clear_password: &str, clear_password: &str,
server_setup: &opaque::server::ServerSetup, server_setup: &opaque::server::ServerSetup,
username: &str, username: &UserId,
) -> Result<()> { ) -> Result<()> {
use opaque::{client, server}; use opaque::{client, server};
let mut rng = rand::rngs::OsRng; let mut rng = rand::rngs::OsRng;
@ -31,7 +31,7 @@ fn passwords_match(
server_setup, server_setup,
Some(password_file), Some(password_file),
client_login_start_result.message, client_login_start_result.message,
username, username.as_str(),
)?; )?;
client::login::finish_login( client::login::finish_login(
client_login_start_result.state, client_login_start_result.state,
@ -88,13 +88,16 @@ impl LoginHandler for SqlBackendHandler {
return Ok(()); return Ok(());
} else { } else {
debug!(r#"Invalid password for LDAP bind user"#); debug!(r#"Invalid password for LDAP bind user"#);
return Err(DomainError::AuthenticationError(request.name)); return Err(DomainError::AuthenticationError(format!(
" for user '{}'",
request.name
)));
} }
} }
let query = Query::select() let query = Query::select()
.column(Users::PasswordHash) .column(Users::PasswordHash)
.from(Users::Table) .from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str())) .and_where(Expr::col(Users::UserId).eq(&request.name))
.to_string(DbQueryBuilder {}); .to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if let Some(password_hash) = if let Some(password_hash) =
@ -106,17 +109,20 @@ impl LoginHandler for SqlBackendHandler {
self.config.get_server_setup(), self.config.get_server_setup(),
&request.name, &request.name,
) { ) {
debug!(r#"Invalid password for "{}": {}"#, request.name, e); debug!(r#"Invalid password for "{}": {}"#, &request.name, e);
} else { } else {
return Ok(()); return Ok(());
} }
} else { } else {
debug!(r#"User "{}" has no password"#, request.name); debug!(r#"User "{}" has no password"#, &request.name);
} }
} else { } else {
debug!(r#"No user found for "{}""#, request.name); debug!(r#"No user found for "{}""#, &request.name);
} }
Err(DomainError::AuthenticationError(request.name)) Err(DomainError::AuthenticationError(format!(
" for user '{}'",
request.name
)))
} }
} }
@ -150,7 +156,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
}) })
} }
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> { async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
let secret_key = self.get_orion_secret_key()?; let secret_key = self.get_orion_secret_key()?;
let login::ServerData { let login::ServerData {
username, username,
@ -165,7 +171,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
opaque::server::login::finish_login(server_login, request.credential_finalization)? opaque::server::login::finish_login(server_login, request.credential_finalization)?
.session_key; .session_key;
Ok(username) Ok(UserId::new(&username))
} }
async fn registration_start( async fn registration_start(
@ -220,7 +226,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
/// Convenience function to set a user's password. /// Convenience function to set a user's password.
pub(crate) async fn register_password( pub(crate) async fn register_password(
opaque_handler: &SqlOpaqueHandler, opaque_handler: &SqlOpaqueHandler,
username: &str, username: &UserId,
password: &SecUtf8, password: &SecUtf8,
) -> Result<()> { ) -> Result<()> {
let mut rng = rand::rngs::OsRng; let mut rng = rand::rngs::OsRng;
@ -278,7 +284,7 @@ mod tests {
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) { async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler handler
.create_user(CreateUserRequest { .create_user(CreateUserRequest {
user_id: name.to_string(), user_id: UserId::new(name),
email: "bob@bob.bob".to_string(), email: "bob@bob.bob".to_string(),
..Default::default() ..Default::default()
}) })
@ -323,7 +329,12 @@ mod tests {
attempt_login(&opaque_handler, "bob", "bob00") attempt_login(&opaque_handler, "bob", "bob00")
.await .await
.unwrap_err(); .unwrap_err();
register_password(&opaque_handler, "bob", &secstr::SecUtf8::from("bob00")).await?; register_password(
&opaque_handler,
&UserId::new("bob"),
&secstr::SecUtf8::from("bob00"),
)
.await?;
attempt_login(&opaque_handler, "bob", "wrong_password") attempt_login(&opaque_handler, "bob", "wrong_password")
.await .await
.unwrap_err(); .unwrap_err();

View File

@ -1,4 +1,4 @@
use super::handler::GroupId; use super::handler::{GroupId, UserId};
use sea_query::*; use sea_query::*;
pub type Pool = sqlx::sqlite::SqlitePool; pub type Pool = sqlx::sqlite::SqlitePool;
@ -37,6 +37,43 @@ where
} }
} }
impl<DB> sqlx::Type<DB> for UserId
where
DB: sqlx::Database,
String: sqlx::Type<DB>,
{
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
<String as sqlx::Type<DB>>::type_info()
}
fn compatible(ty: &<DB as sqlx::Database>::TypeInfo) -> bool {
<String as sqlx::Type<DB>>::compatible(ty)
}
}
impl<'r, DB> sqlx::Decode<'r, DB> for UserId
where
DB: sqlx::Database,
String: sqlx::Decode<'r, DB>,
{
fn decode(
value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send + 'static>> {
<String as sqlx::Decode<'r, DB>>::decode(value).map(|s| UserId::new(&s))
}
}
impl From<UserId> for sea_query::Value {
fn from(user_id: UserId) -> Self {
user_id.into_string().into()
}
}
impl From<&UserId> for sea_query::Value {
fn from(user_id: &UserId) -> Self {
user_id.as_str().into()
}
}
#[derive(Iden)] #[derive(Iden)]
pub enum Users { pub enum Users {
Table, Table,

View File

@ -25,7 +25,7 @@ use lldap_auth::{login, opaque, password_reset, registration, JWTClaims};
use crate::{ use crate::{
domain::{ domain::{
error::DomainError, error::DomainError,
handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler}, handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler, UserId},
opaque_handler::OpaqueHandler, opaque_handler::OpaqueHandler,
}, },
infra::{ infra::{
@ -51,7 +51,7 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<GroupIdAndName>)
jwt::Token::new(header, claims).sign_with_key(key).unwrap() jwt::Token::new(header, claims).sign_with_key(key).unwrap()
} }
fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpResponse> { fn parse_refresh_token(token: &str) -> std::result::Result<(u64, UserId), HttpResponse> {
match token.split_once('+') { match token.split_once('+') {
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")), None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
Some((token, u)) => { Some((token, u)) => {
@ -60,12 +60,12 @@ fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpRe
token.hash(&mut s); token.hash(&mut s);
s.finish() s.finish()
}; };
Ok((refresh_token_hash, u.to_string())) Ok((refresh_token_hash, UserId::new(u)))
} }
} }
} }
fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, String), HttpResponse> { fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, UserId), HttpResponse> {
match ( match (
request.cookie("refresh_token"), request.cookie("refresh_token"),
request.headers().get("refresh-token"), request.headers().get("refresh-token"),
@ -134,14 +134,14 @@ where
{ {
let user_id = match request.match_info().get("user_id") { let user_id = match request.match_info().get("user_id") {
None => return HttpResponse::BadRequest().body("Missing user ID"), None => return HttpResponse::BadRequest().body("Missing user ID"),
Some(id) => id, Some(id) => UserId::new(id),
}; };
let token = match data.backend_handler.start_password_reset(user_id).await { let token = match data.backend_handler.start_password_reset(&user_id).await {
Err(e) => return HttpResponse::InternalServerError().body(e.to_string()), Err(e) => return HttpResponse::InternalServerError().body(e.to_string()),
Ok(None) => return HttpResponse::Ok().finish(), Ok(None) => return HttpResponse::Ok().finish(),
Ok(Some(token)) => token, Ok(Some(token)) => token,
}; };
let user = match data.backend_handler.get_user_details(user_id).await { let user = match data.backend_handler.get_user_details(&user_id).await {
Err(e) => { Err(e) => {
warn!("Error getting used details: {:#?}", e); warn!("Error getting used details: {:#?}", e);
return HttpResponse::Ok().finish(); return HttpResponse::Ok().finish();
@ -196,7 +196,7 @@ where
.finish(), .finish(),
) )
.json(&password_reset::ServerPasswordResetResponse { .json(&password_reset::ServerPasswordResetResponse {
user_id, user_id: user_id.to_string(),
token: token.as_str().to_owned(), token: token.as_str().to_owned(),
}) })
} }
@ -276,7 +276,7 @@ where
async fn get_login_successful_response<Backend>( async fn get_login_successful_response<Backend>(
data: &web::Data<AppState<Backend>>, data: &web::Data<AppState<Backend>>,
name: &str, name: &UserId,
) -> HttpResponse ) -> HttpResponse
where where
Backend: TcpBackendHandler + BackendHandler, Backend: TcpBackendHandler + BackendHandler,
@ -289,7 +289,7 @@ where
.await .await
.map(|(groups, (refresh_token, max_age))| { .map(|(groups, (refresh_token, max_age))| {
let token = create_jwt(&data.jwt_key, name.to_string(), groups); let token = create_jwt(&data.jwt_key, name.to_string(), groups);
let refresh_token_plus_name = refresh_token + "+" + name; let refresh_token_plus_name = refresh_token + "+" + name.as_str();
HttpResponse::Ok() HttpResponse::Ok()
.cookie( .cookie(

View File

@ -1,4 +1,7 @@
use crate::infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts}; use crate::{
domain::handler::UserId,
infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts},
};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use figment::{ use figment::{
providers::{Env, Format, Serialized, Toml}, providers::{Env, Format, Serialized, Toml},
@ -49,8 +52,8 @@ pub struct Configuration {
pub jwt_secret: SecUtf8, pub jwt_secret: SecUtf8,
#[builder(default = r#"String::from("dc=example,dc=com")"#)] #[builder(default = r#"String::from("dc=example,dc=com")"#)]
pub ldap_base_dn: String, pub ldap_base_dn: String,
#[builder(default = r#"String::from("admin")"#)] #[builder(default = r#"UserId::new("admin")"#)]
pub ldap_user_dn: String, pub ldap_user_dn: UserId,
#[builder(default = r#"SecUtf8::from("password")"#)] #[builder(default = r#"SecUtf8::from("password")"#)]
pub ldap_user_pass: SecUtf8, pub ldap_user_pass: SecUtf8,
#[builder(default = r#"String::from("sqlite://users.db?mode=rwc")"#)] #[builder(default = r#"String::from("sqlite://users.db?mode=rwc")"#)]

View File

@ -1,5 +1,5 @@
use crate::domain::handler::{ use crate::domain::handler::{
BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, UserId,
}; };
use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject}; use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject};
@ -66,10 +66,11 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
if !context.validation_result.is_admin { if !context.validation_result.is_admin {
return Err("Unauthorized user creation".into()); return Err("Unauthorized user creation".into());
} }
let user_id = UserId::new(&user.id);
context context
.handler .handler
.create_user(CreateUserRequest { .create_user(CreateUserRequest {
user_id: user.id.clone(), user_id: user_id.clone(),
email: user.email, email: user.email,
display_name: user.display_name, display_name: user.display_name,
first_name: user.first_name, first_name: user.first_name,
@ -78,7 +79,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
.await?; .await?;
Ok(context Ok(context
.handler .handler
.get_user_details(&user.id) .get_user_details(&user_id)
.await .await
.map(Into::into)?) .map(Into::into)?)
} }
@ -108,7 +109,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
context context
.handler .handler
.update_user(UpdateUserRequest { .update_user(UpdateUserRequest {
user_id: user.id, user_id: UserId::new(&user.id),
email: user.email, email: user.email,
display_name: user.display_name, display_name: user.display_name,
first_name: user.first_name, first_name: user.first_name,
@ -148,7 +149,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
} }
context context
.handler .handler
.add_user_to_group(&user_id, GroupId(group_id)) .add_user_to_group(&UserId::new(&user_id), GroupId(group_id))
.await?; .await?;
Ok(Success::new()) Ok(Success::new())
} }
@ -166,7 +167,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
} }
context context
.handler .handler
.remove_user_from_group(&user_id, GroupId(group_id)) .remove_user_from_group(&UserId::new(&user_id), GroupId(group_id))
.await?; .await?;
Ok(Success::new()) Ok(Success::new())
} }
@ -178,7 +179,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
if context.validation_result.user == user_id { if context.validation_result.user == user_id {
return Err("Cannot delete current user".into()); return Err("Cannot delete current user".into());
} }
context.handler.delete_user(&user_id).await?; context.handler.delete_user(&UserId::new(&user_id)).await?;
Ok(Success::new()) Ok(Success::new())
} }

View File

@ -1,4 +1,4 @@
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName}; use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName, UserId};
use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -48,6 +48,9 @@ impl TryInto<DomainRequestFilter> for RequestFilter {
return Err("Multiple fields specified in request filter".to_string()); return Err("Multiple fields specified in request filter".to_string());
} }
if let Some(e) = self.eq { if let Some(e) = self.eq {
if e.field.to_lowercase() == "uid" {
return Ok(DomainRequestFilter::UserId(UserId::new(&e.value)));
}
return Ok(DomainRequestFilter::Equality(e.field, e.value)); return Ok(DomainRequestFilter::Equality(e.field, e.value));
} }
if let Some(c) = self.any { if let Some(c) = self.any {
@ -109,7 +112,7 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
} }
Ok(context Ok(context
.handler .handler
.get_user_details(&user_id) .get_user_details(&UserId::new(&user_id))
.await .await
.map(Into::into)?) .map(Into::into)?)
} }
@ -170,7 +173,7 @@ impl<Handler: BackendHandler> Default for User<Handler> {
#[graphql_object(context = Context<Handler>)] #[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> User<Handler> { impl<Handler: BackendHandler + Sync> User<Handler> {
fn id(&self) -> &str { fn id(&self) -> &str {
&self.user.user_id self.user.user_id.as_str()
} }
fn email(&self) -> &str { fn email(&self) -> &str {
@ -260,7 +263,7 @@ impl<Handler: BackendHandler> From<DomainGroup> for Group<Handler> {
Self { Self {
group_id: group.id.0, group_id: group.id.0,
display_name: group.display_name, display_name: group.display_name,
members: Some(group.users.into_iter().map(Into::into).collect()), members: Some(group.users.into_iter().map(UserId::into_string).collect()),
_phantom: std::marker::PhantomData, _phantom: std::marker::PhantomData,
} }
} }
@ -305,10 +308,10 @@ mod tests {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_get_user_details() mock.expect_get_user_details()
.with(eq("bob")) .with(eq(UserId::new("bob")))
.return_once(|_| { .return_once(|_| {
Ok(DomainUser { Ok(DomainUser {
user_id: "bob".to_string(), user_id: UserId::new("bob"),
email: "bob@bobbers.on".to_string(), email: "bob@bobbers.on".to_string(),
..Default::default() ..Default::default()
}) })
@ -316,7 +319,7 @@ mod tests {
let mut groups = HashSet::new(); let mut groups = HashSet::new();
groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string())); groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string()));
mock.expect_get_user_groups() mock.expect_get_user_groups()
.with(eq("bob")) .with(eq(UserId::new("bob")))
.return_once(|_| Ok(groups)); .return_once(|_| Ok(groups));
let context = Context::<MockTestBackendHandler> { let context = Context::<MockTestBackendHandler> {
@ -369,12 +372,12 @@ mod tests {
.return_once(|_| { .return_once(|_| {
Ok(vec![ Ok(vec![
DomainUser { DomainUser {
user_id: "bob".to_string(), user_id: UserId::new("bob"),
email: "bob@bobbers.on".to_string(), email: "bob@bobbers.on".to_string(),
..Default::default() ..Default::default()
}, },
DomainUser { DomainUser {
user_id: "robert".to_string(), user_id: UserId::new("robert"),
email: "robert@bobbers.on".to_string(), email: "robert@bobbers.on".to_string(),
..Default::default() ..Default::default()
}, },

View File

@ -1,6 +1,6 @@
use crate::domain::{ use crate::domain::{
handler::{ handler::{
BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, UserId,
UserRequestFilter, UserRequestFilter,
}, },
opaque_handler::OpaqueHandler, opaque_handler::OpaqueHandler,
@ -71,7 +71,7 @@ fn get_user_id_from_distinguished_name(
dn: &str, dn: &str,
base_tree: &[(String, String)], base_tree: &[(String, String)],
base_dn_str: &str, base_dn_str: &str,
) -> Result<String> { ) -> Result<UserId> {
let parts = parse_distinguished_name(dn).context("while parsing a user ID")?; let parts = parse_distinguished_name(dn).context("while parsing a user ID")?;
if !is_subtree(&parts, base_tree) { if !is_subtree(&parts, base_tree) {
bail!("Not a subtree of the base tree"); bail!("Not a subtree of the base tree");
@ -84,7 +84,7 @@ fn get_user_id_from_distinguished_name(
base_dn_str base_dn_str
); );
} }
Ok(parts[0].1.to_string()) Ok(UserId::new(&parts[0].1))
} else { } else {
bail!( bail!(
r#"Unexpected user DN format. Got "{}", expected: "cn=username,ou=people,{}""#, r#"Unexpected user DN format. Got "{}", expected: "cn=username,ou=people,{}""#,
@ -103,7 +103,7 @@ fn get_user_attribute(user: &User, attribute: &str, dn: &str) -> Result<Vec<Stri
"person".to_string(), "person".to_string(),
]), ]),
"dn" => Ok(vec![dn.to_string()]), "dn" => Ok(vec![dn.to_string()]),
"uid" => Ok(vec![user.user_id.clone()]), "uid" => Ok(vec![user.user_id.to_string()]),
"mail" => Ok(vec![user.email.clone()]), "mail" => Ok(vec![user.email.clone()]),
"givenname" => Ok(vec![user.first_name.clone()]), "givenname" => Ok(vec![user.first_name.clone()]),
"sn" => Ok(vec![user.last_name.clone()]), "sn" => Ok(vec![user.last_name.clone()]),
@ -118,7 +118,7 @@ fn make_ldap_search_user_result_entry(
base_dn_str: &str, base_dn_str: &str,
attributes: &[String], attributes: &[String],
) -> Result<LdapSearchResultEntry> { ) -> Result<LdapSearchResultEntry> {
let dn = format!("cn={},ou=people,{}", user.user_id, base_dn_str); let dn = format!("cn={},ou=people,{}", user.user_id.as_str(), base_dn_str);
Ok(LdapSearchResultEntry { Ok(LdapSearchResultEntry {
dn: dn.clone(), dn: dn.clone(),
attributes: attributes attributes: attributes
@ -264,17 +264,17 @@ fn root_dse_response(base_dn: &str) -> LdapOp {
} }
pub struct LdapHandler<Backend: BackendHandler + LoginHandler + OpaqueHandler> { pub struct LdapHandler<Backend: BackendHandler + LoginHandler + OpaqueHandler> {
dn: String, dn: UserId,
backend_handler: Backend, backend_handler: Backend,
pub base_dn: Vec<(String, String)>, pub base_dn: Vec<(String, String)>,
base_dn_str: String, base_dn_str: String,
ldap_user_dn: String, ldap_user_dn: UserId,
} }
impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend> { impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend> {
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self { pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: UserId) -> Self {
Self { Self {
dn: "Unauthenticated".to_string(), dn: UserId::new("unauthenticated"),
backend_handler, backend_handler,
base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| { base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
panic!( panic!(
@ -282,7 +282,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
ldap_base_dn ldap_base_dn
) )
}), }),
ldap_user_dn: format!("cn={},ou=people,{}", ldap_user_dn, &ldap_base_dn), ldap_user_dn: UserId::new(&format!("cn={},ou=people,{}", ldap_user_dn, &ldap_base_dn)),
base_dn_str: ldap_base_dn, base_dn_str: ldap_base_dn,
} }
} }
@ -307,14 +307,14 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
.await .await
{ {
Ok(()) => { Ok(()) => {
self.dn = request.dn.clone(); self.dn = UserId::new(&request.dn);
(LdapResultCode::Success, "".to_string()) (LdapResultCode::Success, "".to_string())
} }
Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()), Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()),
} }
} }
async fn change_password(&mut self, user: &str, password: &str) -> Result<()> { async fn change_password(&mut self, user: &UserId, password: &str) -> Result<()> {
use lldap_auth::*; use lldap_auth::*;
let mut rng = rand::rngs::OsRng; let mut rng = rand::rngs::OsRng;
let registration_start_request = let registration_start_request =
@ -527,7 +527,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
} }
LdapOp::SearchRequest(request) => self.do_search(&request).await, LdapOp::SearchRequest(request) => self.do_search(&request).await,
LdapOp::UnbindRequest => { LdapOp::UnbindRequest => {
self.dn = "Unauthenticated".to_string(); self.dn = UserId::new("unauthenticated");
// No need to notify on unbind (per rfc4511) // No need to notify on unbind (per rfc4511)
return None; return None;
} }
@ -617,10 +617,12 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
)))) ))))
} }
} else { } else {
Ok(UserRequestFilter::Equality( let field = map_field(field)?;
map_field(field)?, if field == "user_id" {
value.clone(), Ok(UserRequestFilter::UserId(UserId::new(value)))
)) } else {
Ok(UserRequestFilter::Equality(field, value.clone()))
}
} }
} }
LdapFilter::Present(field) => { LdapFilter::Present(field) => {
@ -661,17 +663,17 @@ mod tests {
impl BackendHandler for TestBackendHandler { impl BackendHandler for TestBackendHandler {
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>; async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>; async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>; async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>; async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>; async fn get_user_groups(&self, user: &UserId) -> Result<HashSet<GroupIdAndName>>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>; async fn delete_user(&self, user_id: &UserId) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>; async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn delete_group(&self, group_id: GroupId) -> Result<()>; async fn delete_group(&self, group_id: GroupId) -> Result<()>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
} }
#[async_trait] #[async_trait]
impl OpaqueHandler for TestBackendHandler { impl OpaqueHandler for TestBackendHandler {
@ -679,7 +681,7 @@ mod tests {
&self, &self,
request: login::ClientLoginStartRequest request: login::ClientLoginStartRequest
) -> Result<login::ServerLoginStartResponse>; ) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>; async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId>;
async fn registration_start( async fn registration_start(
&self, &self,
request: registration::ClientRegistrationStartRequest request: registration::ClientRegistrationStartRequest
@ -720,12 +722,12 @@ mod tests {
) -> LdapHandler<MockTestBackendHandler> { ) -> LdapHandler<MockTestBackendHandler> {
mock.expect_bind() mock.expect_bind()
.with(eq(BindRequest { .with(eq(BindRequest {
name: "test".to_string(), name: UserId::new("test"),
password: "pass".to_string(), password: "pass".to_string(),
})) }))
.return_once(|_| Ok(())); .return_once(|_| Ok(()));
let mut ldap_handler = let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
let request = LdapBindRequest { let request = LdapBindRequest {
dn: "cn=test,ou=people,dc=example,dc=com".to_string(), dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
cred: LdapBindCred::Simple("pass".to_string()), cred: LdapBindCred::Simple("pass".to_string()),
@ -742,13 +744,13 @@ mod tests {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_bind() mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest { .with(eq(crate::domain::handler::BindRequest {
name: "bob".to_string(), name: UserId::new("bob"),
password: "pass".to_string(), password: "pass".to_string(),
})) }))
.times(1) .times(1)
.return_once(|_| Ok(())); .return_once(|_| Ok(()));
let mut ldap_handler = let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
let request = LdapOp::BindRequest(LdapBindRequest { let request = LdapOp::BindRequest(LdapBindRequest {
dn: "cn=bob,ou=people,dc=example,dc=com".to_string(), dn: "cn=bob,ou=people,dc=example,dc=com".to_string(),
@ -773,13 +775,13 @@ mod tests {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_bind() mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest { .with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(), name: UserId::new("test"),
password: "pass".to_string(), password: "pass".to_string(),
})) }))
.times(1) .times(1)
.return_once(|_| Ok(())); .return_once(|_| Ok(()));
let mut ldap_handler = let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string()); LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
let request = LdapBindRequest { let request = LdapBindRequest {
dn: "cn=test,ou=people,dc=example,dc=com".to_string(), dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
@ -796,13 +798,13 @@ mod tests {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_bind() mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest { .with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(), name: UserId::new("test"),
password: "pass".to_string(), password: "pass".to_string(),
})) }))
.times(1) .times(1)
.return_once(|_| Ok(())); .return_once(|_| Ok(()));
let mut ldap_handler = let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string()); LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin"));
let request = LdapBindRequest { let request = LdapBindRequest {
dn: "cn=test,ou=people,dc=example,dc=com".to_string(), dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
@ -827,7 +829,7 @@ mod tests {
async fn test_bind_invalid_dn() { async fn test_bind_invalid_dn() {
let mock = MockTestBackendHandler::new(); let mock = MockTestBackendHandler::new();
let mut ldap_handler = let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string()); LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin"));
let request = LdapBindRequest { let request = LdapBindRequest {
dn: "cn=bob,dc=example,dc=com".to_string(), dn: "cn=bob,dc=example,dc=com".to_string(),
@ -903,7 +905,7 @@ mod tests {
mock.expect_list_users().times(1).return_once(|_| { mock.expect_list_users().times(1).return_once(|_| {
Ok(vec![ Ok(vec![
User { User {
user_id: "bob_1".to_string(), user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(), email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(), display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(), first_name: "Bôb".to_string(),
@ -911,7 +913,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
User { User {
user_id: "jim".to_string(), user_id: UserId::new("jim"),
email: "jim@cricket.jim".to_string(), email: "jim@cricket.jim".to_string(),
display_name: "Jimminy Cricket".to_string(), display_name: "Jimminy Cricket".to_string(),
first_name: "Jim".to_string(), first_name: "Jim".to_string(),
@ -1037,12 +1039,12 @@ mod tests {
Group { Group {
id: GroupId(1), id: GroupId(1),
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
users: vec!["bob".to_string(), "john".to_string()], users: vec![UserId::new("bob"), UserId::new("john")],
}, },
Group { Group {
id: GroupId(3), id: GroupId(3),
display_name: "bestgroup".to_string(), display_name: "bestgroup".to_string(),
users: vec!["john".to_string()], users: vec![UserId::new("john")],
}, },
]) ])
}); });
@ -1111,7 +1113,7 @@ mod tests {
mock.expect_list_groups() mock.expect_list_groups()
.with(eq(Some(GroupRequestFilter::And(vec![ .with(eq(Some(GroupRequestFilter::And(vec![
GroupRequestFilter::DisplayName("group_1".to_string()), GroupRequestFilter::DisplayName("group_1".to_string()),
GroupRequestFilter::Member("bob".to_string()), GroupRequestFilter::Member(UserId::new("bob")),
GroupRequestFilter::And(vec![]), GroupRequestFilter::And(vec![]),
])))) ]))))
.times(1) .times(1)
@ -1250,10 +1252,7 @@ mod tests {
mock.expect_list_users() mock.expect_list_users()
.with(eq(Some(UserRequestFilter::And(vec![ .with(eq(Some(UserRequestFilter::And(vec![
UserRequestFilter::Or(vec![ UserRequestFilter::Or(vec![
UserRequestFilter::Not(Box::new(UserRequestFilter::Equality( UserRequestFilter::Not(Box::new(UserRequestFilter::UserId(UserId::new("bob")))),
"user_id".to_string(),
"bob".to_string(),
))),
UserRequestFilter::And(vec![]), UserRequestFilter::And(vec![]),
UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))), UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))),
UserRequestFilter::And(vec![]), UserRequestFilter::And(vec![]),
@ -1342,7 +1341,7 @@ mod tests {
.times(1) .times(1)
.return_once(|_| { .return_once(|_| {
Ok(vec![User { Ok(vec![User {
user_id: "bob_1".to_string(), user_id: UserId::new("bob_1"),
..Default::default() ..Default::default()
}]) }])
}); });
@ -1378,7 +1377,7 @@ mod tests {
let mut mock = MockTestBackendHandler::new(); let mut mock = MockTestBackendHandler::new();
mock.expect_list_users().times(1).return_once(|_| { mock.expect_list_users().times(1).return_once(|_| {
Ok(vec![User { Ok(vec![User {
user_id: "bob_1".to_string(), user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(), email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(), display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(), first_name: "Bôb".to_string(),
@ -1393,7 +1392,7 @@ mod tests {
Ok(vec![Group { Ok(vec![Group {
id: GroupId(1), id: GroupId(1),
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
users: vec!["bob".to_string(), "john".to_string()], users: vec![UserId::new("bob"), UserId::new("john")],
}]) }])
}); });
let mut ldap_handler = setup_bound_handler(mock).await; let mut ldap_handler = setup_bound_handler(mock).await;

View File

@ -1,5 +1,5 @@
use super::{jwt_sql_tables::*, tcp_backend_handler::*}; use super::{jwt_sql_tables::*, tcp_backend_handler::*};
use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler}; use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::StreamExt; use futures_util::StreamExt;
use sea_query::{Expr, Iden, Query, SimpleExpr}; use sea_query::{Expr, Iden, Query, SimpleExpr};
@ -34,7 +34,7 @@ impl TcpBackendHandler for SqlBackendHandler {
.map_err(|e| anyhow::anyhow!(e)) .map_err(|e| anyhow::anyhow!(e))
} }
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>? // TODO: Initialize the rng only once. Maybe Arc<Cell>?
@ -62,7 +62,7 @@ impl TcpBackendHandler for SqlBackendHandler {
Ok((refresh_token, duration)) Ok((refresh_token, duration))
} }
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> { async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
let query = Query::select() let query = Query::select()
.expr(SimpleExpr::Value(1.into())) .expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table) .from(JwtRefreshStorage::Table)
@ -74,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler {
.await? .await?
.is_some()) .is_some())
} }
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>> { async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
use sqlx::Result; use sqlx::Result;
let query = Query::select() let query = Query::select()
.column(JwtStorage::JwtHash) .column(JwtStorage::JwtHash)
@ -106,7 +106,7 @@ impl TcpBackendHandler for SqlBackendHandler {
Ok(()) Ok(())
} }
async fn start_password_reset(&self, user: &str) -> Result<Option<String>> { async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
let query = Query::select() let query = Query::select()
.column(Users::UserId) .column(Users::UserId)
.from(Users::Table) .from(Users::Table)
@ -138,7 +138,7 @@ impl TcpBackendHandler for SqlBackendHandler {
Ok(Some(token)) Ok(Some(token))
} }
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String> { async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
let query = Query::select() let query = Query::select()
.column(PasswordResetTokens::UserId) .column(PasswordResetTokens::UserId)
.from(PasswordResetTokens::Table) .from(PasswordResetTokens::Table)

View File

@ -1,22 +1,22 @@
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashSet; use std::collections::HashSet;
use crate::domain::error::Result; use crate::domain::{error::Result, handler::UserId};
#[async_trait] #[async_trait]
pub trait TcpBackendHandler { pub trait TcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>; async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>; async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>; async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool>;
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>; async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
/// Request a token to reset a user's password. /// Request a token to reset a user's password.
/// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`. /// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`.
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>; async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>>;
/// Get the user ID associated with a password reset token. /// Get the user ID associated with a password reset token.
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>; async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId>;
async fn delete_password_reset_token(&self, token: &str) -> Result<()>; async fn delete_password_reset_token(&self, token: &str) -> Result<()>;
} }
@ -37,27 +37,27 @@ mockall::mock! {
impl BackendHandler for TestTcpBackendHandler { impl BackendHandler for TestTcpBackendHandler {
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>; async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>; async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>; async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>; async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>; async fn get_user_groups(&self, user: &UserId) -> Result<HashSet<GroupIdAndName>>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>; async fn delete_user(&self, user_id: &UserId) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>; async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn delete_group(&self, group_id: GroupId) -> Result<()>; async fn delete_group(&self, group_id: GroupId) -> Result<()>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
} }
#[async_trait] #[async_trait]
impl TcpBackendHandler for TestTcpBackendHandler { impl TcpBackendHandler for TestTcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>; async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>; async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>; async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool>;
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>; async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>; async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>>;
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>; async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId>;
async fn delete_password_reset_token(&self, token: &str) -> Result<()>; async fn delete_password_reset_token(&self, token: &str) -> Result<()>;
} }
} }