mirror of
				https://github.com/nitnelave/lldap.git
				synced 2023-04-12 14:25:13 +00:00 
			
		
		
		
	domain: introduce UserId to make uid case insensitive
Note that if there was a non-lowercase user already in the DB, it cannot be found again. To fix this, run in the DB: sqlite> UPDATE users SET user_id = LOWER(user_id);
This commit is contained in:
		
							parent
							
								
									26cedcb621
								
							
						
					
					
						commit
						3e675be838
					
				@ -3,7 +3,7 @@ use thiserror::Error;
 | 
			
		||||
#[allow(clippy::enum_variant_names)]
 | 
			
		||||
#[derive(Error, Debug)]
 | 
			
		||||
pub enum DomainError {
 | 
			
		||||
    #[error("Authentication error for `{0}`")]
 | 
			
		||||
    #[error("Authentication error: `{0}`")]
 | 
			
		||||
    AuthenticationError(String),
 | 
			
		||||
    #[error("Database error: `{0}`")]
 | 
			
		||||
    DatabaseError(#[from] sqlx::Error),
 | 
			
		||||
 | 
			
		||||
@ -3,10 +3,41 @@ use async_trait::async_trait;
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use std::collections::HashSet;
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
 | 
			
		||||
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
 | 
			
		||||
#[serde(from = "String")]
 | 
			
		||||
pub struct UserId(String);
 | 
			
		||||
 | 
			
		||||
impl UserId {
 | 
			
		||||
    pub fn new(user_id: &str) -> Self {
 | 
			
		||||
        Self(user_id.to_lowercase())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn as_str(&self) -> &str {
 | 
			
		||||
        self.0.as_str()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn into_string(self) -> String {
 | 
			
		||||
        self.0
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl std::fmt::Display for UserId {
 | 
			
		||||
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 | 
			
		||||
        write!(f, "{}", self.0)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<String> for UserId {
 | 
			
		||||
    fn from(s: String) -> Self {
 | 
			
		||||
        Self::new(&s)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
 | 
			
		||||
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
 | 
			
		||||
pub struct User {
 | 
			
		||||
    pub user_id: String,
 | 
			
		||||
    pub user_id: UserId,
 | 
			
		||||
    pub email: String,
 | 
			
		||||
    pub display_name: String,
 | 
			
		||||
    pub first_name: String,
 | 
			
		||||
@ -19,7 +50,7 @@ impl Default for User {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        use chrono::TimeZone;
 | 
			
		||||
        User {
 | 
			
		||||
            user_id: String::new(),
 | 
			
		||||
            user_id: UserId::default(),
 | 
			
		||||
            email: String::new(),
 | 
			
		||||
            display_name: String::new(),
 | 
			
		||||
            first_name: String::new(),
 | 
			
		||||
@ -33,12 +64,12 @@ impl Default for User {
 | 
			
		||||
pub struct Group {
 | 
			
		||||
    pub id: GroupId,
 | 
			
		||||
    pub display_name: String,
 | 
			
		||||
    pub users: Vec<String>,
 | 
			
		||||
    pub users: Vec<UserId>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
 | 
			
		||||
pub struct BindRequest {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub name: UserId,
 | 
			
		||||
    pub password: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -47,6 +78,7 @@ pub enum UserRequestFilter {
 | 
			
		||||
    And(Vec<UserRequestFilter>),
 | 
			
		||||
    Or(Vec<UserRequestFilter>),
 | 
			
		||||
    Not(Box<UserRequestFilter>),
 | 
			
		||||
    UserId(UserId),
 | 
			
		||||
    Equality(String, String),
 | 
			
		||||
    // Check if a user belongs to a group identified by name.
 | 
			
		||||
    MemberOf(String),
 | 
			
		||||
@ -62,13 +94,13 @@ pub enum GroupRequestFilter {
 | 
			
		||||
    DisplayName(String),
 | 
			
		||||
    GroupId(GroupId),
 | 
			
		||||
    // Check if the group contains a user identified by uid.
 | 
			
		||||
    Member(String),
 | 
			
		||||
    Member(UserId),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
 | 
			
		||||
pub struct CreateUserRequest {
 | 
			
		||||
    // Same fields as User, but no creation_date, and with password.
 | 
			
		||||
    pub user_id: String,
 | 
			
		||||
    pub user_id: UserId,
 | 
			
		||||
    pub email: String,
 | 
			
		||||
    pub display_name: Option<String>,
 | 
			
		||||
    pub first_name: Option<String>,
 | 
			
		||||
@ -78,7 +110,7 @@ pub struct CreateUserRequest {
 | 
			
		||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
 | 
			
		||||
pub struct UpdateUserRequest {
 | 
			
		||||
    // Same fields as CreateUserRequest, but no with an extra layer of Option.
 | 
			
		||||
    pub user_id: String,
 | 
			
		||||
    pub user_id: UserId,
 | 
			
		||||
    pub email: Option<String>,
 | 
			
		||||
    pub display_name: Option<String>,
 | 
			
		||||
    pub first_name: Option<String>,
 | 
			
		||||
@ -106,17 +138,17 @@ pub struct GroupIdAndName(pub GroupId, pub String);
 | 
			
		||||
pub trait BackendHandler: Clone + Send {
 | 
			
		||||
    async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
 | 
			
		||||
    async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
 | 
			
		||||
    async fn get_user_details(&self, user_id: &str) -> Result<User>;
 | 
			
		||||
    async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
 | 
			
		||||
    async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
 | 
			
		||||
    async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
 | 
			
		||||
    async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
 | 
			
		||||
    async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
 | 
			
		||||
    async fn delete_user(&self, user_id: &str) -> Result<()>;
 | 
			
		||||
    async fn delete_user(&self, user_id: &UserId) -> Result<()>;
 | 
			
		||||
    async fn create_group(&self, group_name: &str) -> Result<GroupId>;
 | 
			
		||||
    async fn delete_group(&self, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
    async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
@ -129,17 +161,17 @@ mockall::mock! {
 | 
			
		||||
    impl BackendHandler for TestBackendHandler {
 | 
			
		||||
        async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
 | 
			
		||||
        async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
 | 
			
		||||
        async fn get_user_details(&self, user_id: &str) -> Result<User>;
 | 
			
		||||
        async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
 | 
			
		||||
        async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
 | 
			
		||||
        async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
 | 
			
		||||
        async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
 | 
			
		||||
        async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
 | 
			
		||||
        async fn delete_user(&self, user_id: &str) -> Result<()>;
 | 
			
		||||
        async fn delete_user(&self, user_id: &UserId) -> Result<()>;
 | 
			
		||||
        async fn create_group(&self, group_name: &str) -> Result<GroupId>;
 | 
			
		||||
        async fn delete_group(&self, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
        async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
        async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    }
 | 
			
		||||
    #[async_trait]
 | 
			
		||||
    impl LoginHandler for TestBackendHandler {
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
use super::error::*;
 | 
			
		||||
use crate::domain::{error::*, handler::UserId};
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
 | 
			
		||||
pub use lldap_auth::{login, registration};
 | 
			
		||||
@ -9,7 +9,7 @@ pub trait OpaqueHandler: Clone + Send {
 | 
			
		||||
        &self,
 | 
			
		||||
        request: login::ClientLoginStartRequest,
 | 
			
		||||
    ) -> Result<login::ServerLoginStartResponse>;
 | 
			
		||||
    async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
 | 
			
		||||
    async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId>;
 | 
			
		||||
    async fn registration_start(
 | 
			
		||||
        &self,
 | 
			
		||||
        request: registration::ClientRegistrationStartRequest,
 | 
			
		||||
@ -32,7 +32,7 @@ mockall::mock! {
 | 
			
		||||
            &self,
 | 
			
		||||
            request: login::ClientLoginStartRequest
 | 
			
		||||
        ) -> Result<login::ServerLoginStartResponse>;
 | 
			
		||||
        async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>;
 | 
			
		||||
        async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<UserId>;
 | 
			
		||||
        async fn registration_start(
 | 
			
		||||
            &self,
 | 
			
		||||
            request: registration::ClientRegistrationStartRequest
 | 
			
		||||
 | 
			
		||||
@ -51,12 +51,16 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr
 | 
			
		||||
            let (requires_group, filters) = get_user_filter_expr(*f);
 | 
			
		||||
            (requires_group, Expr::not(Expr::expr(filters)))
 | 
			
		||||
        }
 | 
			
		||||
        UserId(user_id) => (
 | 
			
		||||
            RequiresGroup(false),
 | 
			
		||||
            Expr::col((Users::Table, Users::UserId)).eq(user_id),
 | 
			
		||||
        ),
 | 
			
		||||
        Equality(s1, s2) => (
 | 
			
		||||
            RequiresGroup(false),
 | 
			
		||||
            if s1 == Users::DisplayName.to_string() {
 | 
			
		||||
                Expr::col((Users::Table, Users::DisplayName)).eq(s2)
 | 
			
		||||
            } else if s1 == Users::UserId.to_string() {
 | 
			
		||||
                Expr::col((Users::Table, Users::UserId)).eq(s2)
 | 
			
		||||
                panic!("User id should be wrapped")
 | 
			
		||||
            } else {
 | 
			
		||||
                Expr::expr(Expr::cust(&s1)).eq(s2)
 | 
			
		||||
            },
 | 
			
		||||
@ -205,17 +209,17 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
                id: group_id,
 | 
			
		||||
                display_name,
 | 
			
		||||
                users: rows
 | 
			
		||||
                    .map(|row| row.get::<String, _>(&*Memberships::UserId.to_string()))
 | 
			
		||||
                    .map(|row| row.get::<UserId, _>(&*Memberships::UserId.to_string()))
 | 
			
		||||
                    // If a group has no users, an empty string is returned because of the left
 | 
			
		||||
                    // join.
 | 
			
		||||
                    .filter(|s| !s.is_empty())
 | 
			
		||||
                    .filter(|s| !s.as_str().is_empty())
 | 
			
		||||
                    .collect(),
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
        Ok(groups)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn get_user_details(&self, user_id: &str) -> Result<User> {
 | 
			
		||||
    async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
 | 
			
		||||
        let query = Query::select()
 | 
			
		||||
            .column(Users::UserId)
 | 
			
		||||
            .column(Users::Email)
 | 
			
		||||
@ -246,8 +250,8 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            .await?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>> {
 | 
			
		||||
        if user == self.config.ldap_user_dn {
 | 
			
		||||
    async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
 | 
			
		||||
        if *user_id == self.config.ldap_user_dn {
 | 
			
		||||
            let mut groups = HashSet::new();
 | 
			
		||||
            groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
 | 
			
		||||
            return Ok(groups);
 | 
			
		||||
@ -261,7 +265,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
                Expr::tbl(Groups::Table, Groups::GroupId)
 | 
			
		||||
                    .equals(Memberships::Table, Memberships::GroupId),
 | 
			
		||||
            )
 | 
			
		||||
            .and_where(Expr::col(Memberships::UserId).eq(user))
 | 
			
		||||
            .and_where(Expr::col(Memberships::UserId).eq(user_id))
 | 
			
		||||
            .to_string(DbQueryBuilder {});
 | 
			
		||||
 | 
			
		||||
        sqlx::query(&query)
 | 
			
		||||
@ -294,7 +298,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            Users::CreationDate,
 | 
			
		||||
        ];
 | 
			
		||||
        let values = vec![
 | 
			
		||||
            request.user_id.clone().into(),
 | 
			
		||||
            request.user_id.into(),
 | 
			
		||||
            request.email.into(),
 | 
			
		||||
            request.display_name.unwrap_or_default().into(),
 | 
			
		||||
            request.first_name.unwrap_or_default().into(),
 | 
			
		||||
@ -353,7 +357,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn delete_user(&self, user_id: &str) -> Result<()> {
 | 
			
		||||
    async fn delete_user(&self, user_id: &UserId) -> Result<()> {
 | 
			
		||||
        let delete_query = Query::delete()
 | 
			
		||||
            .from_table(Users::Table)
 | 
			
		||||
            .and_where(Expr::col(Users::UserId).eq(user_id))
 | 
			
		||||
@ -387,7 +391,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
 | 
			
		||||
    async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
 | 
			
		||||
        let query = Query::insert()
 | 
			
		||||
            .into_table(Memberships::Table)
 | 
			
		||||
            .columns(vec![Memberships::UserId, Memberships::GroupId])
 | 
			
		||||
@ -397,7 +401,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
 | 
			
		||||
    async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
 | 
			
		||||
        let query = Query::delete()
 | 
			
		||||
            .from_table(Memberships::Table)
 | 
			
		||||
            .and_where(Expr::col(Memberships::GroupId).eq(group_id))
 | 
			
		||||
@ -463,7 +467,7 @@ mod tests {
 | 
			
		||||
    async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
 | 
			
		||||
        handler
 | 
			
		||||
            .create_user(CreateUserRequest {
 | 
			
		||||
                user_id: name.to_string(),
 | 
			
		||||
                user_id: UserId::new(name),
 | 
			
		||||
                email: "bob@bob.bob".to_string(),
 | 
			
		||||
                ..Default::default()
 | 
			
		||||
            })
 | 
			
		||||
@ -476,21 +480,24 @@ mod tests {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
 | 
			
		||||
        handler.add_user_to_group(user_id, group_id).await.unwrap();
 | 
			
		||||
        handler
 | 
			
		||||
            .add_user_to_group(&UserId::new(user_id), group_id)
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[tokio::test]
 | 
			
		||||
    async fn test_bind_admin() {
 | 
			
		||||
        let sql_pool = get_in_memory_db().await;
 | 
			
		||||
        let config = ConfigurationBuilder::default()
 | 
			
		||||
            .ldap_user_dn("admin".to_string())
 | 
			
		||||
            .ldap_user_dn(UserId::new("admin"))
 | 
			
		||||
            .ldap_user_pass(secstr::SecUtf8::from("test"))
 | 
			
		||||
            .build()
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        let handler = SqlBackendHandler::new(config, sql_pool);
 | 
			
		||||
        handler
 | 
			
		||||
            .bind(BindRequest {
 | 
			
		||||
                name: "admin".to_string(),
 | 
			
		||||
                name: UserId::new("admin"),
 | 
			
		||||
                password: "test".to_string(),
 | 
			
		||||
            })
 | 
			
		||||
            .await
 | 
			
		||||
@ -506,21 +513,21 @@ mod tests {
 | 
			
		||||
 | 
			
		||||
        handler
 | 
			
		||||
            .bind(BindRequest {
 | 
			
		||||
                name: "bob".to_string(),
 | 
			
		||||
                name: UserId::new("bob"),
 | 
			
		||||
                password: "bob00".to_string(),
 | 
			
		||||
            })
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        handler
 | 
			
		||||
            .bind(BindRequest {
 | 
			
		||||
                name: "andrew".to_string(),
 | 
			
		||||
                name: UserId::new("andrew"),
 | 
			
		||||
                password: "bob00".to_string(),
 | 
			
		||||
            })
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap_err();
 | 
			
		||||
        handler
 | 
			
		||||
            .bind(BindRequest {
 | 
			
		||||
                name: "bob".to_string(),
 | 
			
		||||
                name: UserId::new("bob"),
 | 
			
		||||
                password: "wrong_password".to_string(),
 | 
			
		||||
            })
 | 
			
		||||
            .await
 | 
			
		||||
@ -536,7 +543,7 @@ mod tests {
 | 
			
		||||
 | 
			
		||||
        handler
 | 
			
		||||
            .bind(BindRequest {
 | 
			
		||||
                name: "bob".to_string(),
 | 
			
		||||
                name: UserId::new("bob"),
 | 
			
		||||
                password: "bob00".to_string(),
 | 
			
		||||
            })
 | 
			
		||||
            .await
 | 
			
		||||
@ -557,47 +564,44 @@ mod tests {
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .map(|u| u.user_id)
 | 
			
		||||
                .map(|u| u.user_id.to_string())
 | 
			
		||||
                .collect::<Vec<_>>();
 | 
			
		||||
            assert_eq!(users, vec!["John", "bob", "patrick"]);
 | 
			
		||||
            assert_eq!(users, vec!["bob", "john", "patrick"]);
 | 
			
		||||
        }
 | 
			
		||||
        {
 | 
			
		||||
            let users = handler
 | 
			
		||||
                .list_users(Some(UserRequestFilter::Equality(
 | 
			
		||||
                    "user_id".to_string(),
 | 
			
		||||
                    "bob".to_string(),
 | 
			
		||||
                )))
 | 
			
		||||
                .list_users(Some(UserRequestFilter::UserId(UserId::new("bob"))))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .map(|u| u.user_id)
 | 
			
		||||
                .map(|u| u.user_id.to_string())
 | 
			
		||||
                .collect::<Vec<_>>();
 | 
			
		||||
            assert_eq!(users, vec!["bob"]);
 | 
			
		||||
        }
 | 
			
		||||
        {
 | 
			
		||||
            let users = handler
 | 
			
		||||
                .list_users(Some(UserRequestFilter::Or(vec![
 | 
			
		||||
                    UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()),
 | 
			
		||||
                    UserRequestFilter::Equality("user_id".to_string(), "John".to_string()),
 | 
			
		||||
                    UserRequestFilter::UserId(UserId::new("bob")),
 | 
			
		||||
                    UserRequestFilter::UserId(UserId::new("John")),
 | 
			
		||||
                ])))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .map(|u| u.user_id)
 | 
			
		||||
                .map(|u| u.user_id.to_string())
 | 
			
		||||
                .collect::<Vec<_>>();
 | 
			
		||||
            assert_eq!(users, vec!["John", "bob"]);
 | 
			
		||||
            assert_eq!(users, vec!["bob", "john"]);
 | 
			
		||||
        }
 | 
			
		||||
        {
 | 
			
		||||
            let users = handler
 | 
			
		||||
                .list_users(Some(UserRequestFilter::Not(Box::new(
 | 
			
		||||
                    UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()),
 | 
			
		||||
                    UserRequestFilter::UserId(UserId::new("bob")),
 | 
			
		||||
                ))))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .map(|u| u.user_id)
 | 
			
		||||
                .map(|u| u.user_id.to_string())
 | 
			
		||||
                .collect::<Vec<_>>();
 | 
			
		||||
            assert_eq!(users, vec!["John", "patrick"]);
 | 
			
		||||
            assert_eq!(users, vec!["john", "patrick"]);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -622,7 +626,7 @@ mod tests {
 | 
			
		||||
                Group {
 | 
			
		||||
                    id: group_1,
 | 
			
		||||
                    display_name: "Best Group".to_string(),
 | 
			
		||||
                    users: vec!["bob".to_string(), "patrick".to_string()]
 | 
			
		||||
                    users: vec![UserId::new("bob"), UserId::new("patrick")]
 | 
			
		||||
                },
 | 
			
		||||
                Group {
 | 
			
		||||
                    id: group_3,
 | 
			
		||||
@ -632,7 +636,7 @@ mod tests {
 | 
			
		||||
                Group {
 | 
			
		||||
                    id: group_2,
 | 
			
		||||
                    display_name: "Worst Group".to_string(),
 | 
			
		||||
                    users: vec!["John".to_string(), "patrick".to_string()]
 | 
			
		||||
                    users: vec![UserId::new("john"), UserId::new("patrick")]
 | 
			
		||||
                },
 | 
			
		||||
            ]
 | 
			
		||||
        );
 | 
			
		||||
@ -640,7 +644,7 @@ mod tests {
 | 
			
		||||
            handler
 | 
			
		||||
                .list_groups(Some(GroupRequestFilter::Or(vec![
 | 
			
		||||
                    GroupRequestFilter::DisplayName("Empty Group".to_string()),
 | 
			
		||||
                    GroupRequestFilter::Member("bob".to_string()),
 | 
			
		||||
                    GroupRequestFilter::Member(UserId::new("bob")),
 | 
			
		||||
                ])))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap(),
 | 
			
		||||
@ -648,7 +652,7 @@ mod tests {
 | 
			
		||||
                Group {
 | 
			
		||||
                    id: group_1,
 | 
			
		||||
                    display_name: "Best Group".to_string(),
 | 
			
		||||
                    users: vec!["bob".to_string(), "patrick".to_string()]
 | 
			
		||||
                    users: vec![UserId::new("bob"), UserId::new("patrick")]
 | 
			
		||||
                },
 | 
			
		||||
                Group {
 | 
			
		||||
                    id: group_3,
 | 
			
		||||
@ -670,7 +674,7 @@ mod tests {
 | 
			
		||||
            vec![Group {
 | 
			
		||||
                id: group_1,
 | 
			
		||||
                display_name: "Best Group".to_string(),
 | 
			
		||||
                users: vec!["bob".to_string(), "patrick".to_string()]
 | 
			
		||||
                users: vec![UserId::new("bob"), UserId::new("patrick")]
 | 
			
		||||
            }]
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -682,13 +686,35 @@ mod tests {
 | 
			
		||||
        let handler = SqlBackendHandler::new(config, sql_pool);
 | 
			
		||||
        insert_user(&handler, "bob", "bob00").await;
 | 
			
		||||
        {
 | 
			
		||||
            let user = handler.get_user_details("bob").await.unwrap();
 | 
			
		||||
            assert_eq!(user.user_id, "bob".to_string());
 | 
			
		||||
            let user = handler.get_user_details(&UserId::new("bob")).await.unwrap();
 | 
			
		||||
            assert_eq!(user.user_id.as_str(), "bob");
 | 
			
		||||
        }
 | 
			
		||||
        {
 | 
			
		||||
            handler.get_user_details("John").await.unwrap_err();
 | 
			
		||||
            handler
 | 
			
		||||
                .get_user_details(&UserId::new("John"))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap_err();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[tokio::test]
 | 
			
		||||
    async fn test_user_lowercase() {
 | 
			
		||||
        let sql_pool = get_initialized_db().await;
 | 
			
		||||
        let config = get_default_config();
 | 
			
		||||
        let handler = SqlBackendHandler::new(config, sql_pool);
 | 
			
		||||
        insert_user(&handler, "Bob", "bob00").await;
 | 
			
		||||
        {
 | 
			
		||||
            let user = handler.get_user_details(&UserId::new("bOb")).await.unwrap();
 | 
			
		||||
            assert_eq!(user.user_id.as_str(), "bob");
 | 
			
		||||
        }
 | 
			
		||||
        {
 | 
			
		||||
            handler
 | 
			
		||||
                .get_user_details(&UserId::new("John"))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap_err();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[tokio::test]
 | 
			
		||||
    async fn test_get_user_groups() {
 | 
			
		||||
        let sql_pool = get_initialized_db().await;
 | 
			
		||||
@ -707,13 +733,19 @@ mod tests {
 | 
			
		||||
        let mut patrick_groups = HashSet::new();
 | 
			
		||||
        patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
 | 
			
		||||
        patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string()));
 | 
			
		||||
        assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            handler.get_user_groups("patrick").await.unwrap(),
 | 
			
		||||
            handler.get_user_groups(&UserId::new("bob")).await.unwrap(),
 | 
			
		||||
            bob_groups
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            handler
 | 
			
		||||
                .get_user_groups(&UserId::new("patrick"))
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap(),
 | 
			
		||||
            patrick_groups
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            handler.get_user_groups("John").await.unwrap(),
 | 
			
		||||
            handler.get_user_groups(&UserId::new("John")).await.unwrap(),
 | 
			
		||||
            HashSet::new()
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -729,29 +761,29 @@ mod tests {
 | 
			
		||||
        insert_user(&handler, "Jennz", "boupBoup").await;
 | 
			
		||||
 | 
			
		||||
        // Remove a user
 | 
			
		||||
        let _request_result = handler.delete_user("Jennz").await.unwrap();
 | 
			
		||||
        let _request_result = handler.delete_user(&UserId::new("Jennz")).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        let users = handler
 | 
			
		||||
            .list_users(None)
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap()
 | 
			
		||||
            .into_iter()
 | 
			
		||||
            .map(|u| u.user_id)
 | 
			
		||||
            .map(|u| u.user_id.to_string())
 | 
			
		||||
            .collect::<Vec<_>>();
 | 
			
		||||
 | 
			
		||||
        assert_eq!(users, vec!["Hector", "val"]);
 | 
			
		||||
        assert_eq!(users, vec!["hector", "val"]);
 | 
			
		||||
 | 
			
		||||
        // Insert new user and remove two
 | 
			
		||||
        insert_user(&handler, "NewBoi", "Joni").await;
 | 
			
		||||
        let _request_result = handler.delete_user("Hector").await.unwrap();
 | 
			
		||||
        let _request_result = handler.delete_user("NewBoi").await.unwrap();
 | 
			
		||||
        let _request_result = handler.delete_user(&UserId::new("Hector")).await.unwrap();
 | 
			
		||||
        let _request_result = handler.delete_user(&UserId::new("NewBoi")).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        let users = handler
 | 
			
		||||
            .list_users(None)
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap()
 | 
			
		||||
            .into_iter()
 | 
			
		||||
            .map(|u| u.user_id)
 | 
			
		||||
            .map(|u| u.user_id.to_string())
 | 
			
		||||
            .collect::<Vec<_>>();
 | 
			
		||||
 | 
			
		||||
        assert_eq!(users, vec!["val"]);
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
use super::{
 | 
			
		||||
    error::*,
 | 
			
		||||
    handler::{BindRequest, LoginHandler},
 | 
			
		||||
    handler::{BindRequest, LoginHandler, UserId},
 | 
			
		||||
    opaque_handler::*,
 | 
			
		||||
    sql_backend_handler::SqlBackendHandler,
 | 
			
		||||
    sql_tables::*,
 | 
			
		||||
@ -18,7 +18,7 @@ fn passwords_match(
 | 
			
		||||
    password_file_bytes: &[u8],
 | 
			
		||||
    clear_password: &str,
 | 
			
		||||
    server_setup: &opaque::server::ServerSetup,
 | 
			
		||||
    username: &str,
 | 
			
		||||
    username: &UserId,
 | 
			
		||||
) -> Result<()> {
 | 
			
		||||
    use opaque::{client, server};
 | 
			
		||||
    let mut rng = rand::rngs::OsRng;
 | 
			
		||||
@ -31,7 +31,7 @@ fn passwords_match(
 | 
			
		||||
        server_setup,
 | 
			
		||||
        Some(password_file),
 | 
			
		||||
        client_login_start_result.message,
 | 
			
		||||
        username,
 | 
			
		||||
        username.as_str(),
 | 
			
		||||
    )?;
 | 
			
		||||
    client::login::finish_login(
 | 
			
		||||
        client_login_start_result.state,
 | 
			
		||||
@ -88,13 +88,16 @@ impl LoginHandler for SqlBackendHandler {
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            } else {
 | 
			
		||||
                debug!(r#"Invalid password for LDAP bind user"#);
 | 
			
		||||
                return Err(DomainError::AuthenticationError(request.name));
 | 
			
		||||
                return Err(DomainError::AuthenticationError(format!(
 | 
			
		||||
                    " for user '{}'",
 | 
			
		||||
                    request.name
 | 
			
		||||
                )));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        let query = Query::select()
 | 
			
		||||
            .column(Users::PasswordHash)
 | 
			
		||||
            .from(Users::Table)
 | 
			
		||||
            .and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
 | 
			
		||||
            .and_where(Expr::col(Users::UserId).eq(&request.name))
 | 
			
		||||
            .to_string(DbQueryBuilder {});
 | 
			
		||||
        if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
 | 
			
		||||
            if let Some(password_hash) =
 | 
			
		||||
@ -106,17 +109,20 @@ impl LoginHandler for SqlBackendHandler {
 | 
			
		||||
                    self.config.get_server_setup(),
 | 
			
		||||
                    &request.name,
 | 
			
		||||
                ) {
 | 
			
		||||
                    debug!(r#"Invalid password for "{}": {}"#, request.name, e);
 | 
			
		||||
                    debug!(r#"Invalid password for "{}": {}"#, &request.name, e);
 | 
			
		||||
                } else {
 | 
			
		||||
                    return Ok(());
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                debug!(r#"User "{}" has no password"#, request.name);
 | 
			
		||||
                debug!(r#"User "{}" has no password"#, &request.name);
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            debug!(r#"No user found for "{}""#, request.name);
 | 
			
		||||
            debug!(r#"No user found for "{}""#, &request.name);
 | 
			
		||||
        }
 | 
			
		||||
        Err(DomainError::AuthenticationError(request.name))
 | 
			
		||||
        Err(DomainError::AuthenticationError(format!(
 | 
			
		||||
            " for user '{}'",
 | 
			
		||||
            request.name
 | 
			
		||||
        )))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -150,7 +156,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> {
 | 
			
		||||
    async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
 | 
			
		||||
        let secret_key = self.get_orion_secret_key()?;
 | 
			
		||||
        let login::ServerData {
 | 
			
		||||
            username,
 | 
			
		||||
@ -165,7 +171,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
            opaque::server::login::finish_login(server_login, request.credential_finalization)?
 | 
			
		||||
                .session_key;
 | 
			
		||||
 | 
			
		||||
        Ok(username)
 | 
			
		||||
        Ok(UserId::new(&username))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn registration_start(
 | 
			
		||||
@ -220,7 +226,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
/// Convenience function to set a user's password.
 | 
			
		||||
pub(crate) async fn register_password(
 | 
			
		||||
    opaque_handler: &SqlOpaqueHandler,
 | 
			
		||||
    username: &str,
 | 
			
		||||
    username: &UserId,
 | 
			
		||||
    password: &SecUtf8,
 | 
			
		||||
) -> Result<()> {
 | 
			
		||||
    let mut rng = rand::rngs::OsRng;
 | 
			
		||||
@ -278,7 +284,7 @@ mod tests {
 | 
			
		||||
    async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
 | 
			
		||||
        handler
 | 
			
		||||
            .create_user(CreateUserRequest {
 | 
			
		||||
                user_id: name.to_string(),
 | 
			
		||||
                user_id: UserId::new(name),
 | 
			
		||||
                email: "bob@bob.bob".to_string(),
 | 
			
		||||
                ..Default::default()
 | 
			
		||||
            })
 | 
			
		||||
@ -323,7 +329,12 @@ mod tests {
 | 
			
		||||
        attempt_login(&opaque_handler, "bob", "bob00")
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap_err();
 | 
			
		||||
        register_password(&opaque_handler, "bob", &secstr::SecUtf8::from("bob00")).await?;
 | 
			
		||||
        register_password(
 | 
			
		||||
            &opaque_handler,
 | 
			
		||||
            &UserId::new("bob"),
 | 
			
		||||
            &secstr::SecUtf8::from("bob00"),
 | 
			
		||||
        )
 | 
			
		||||
        .await?;
 | 
			
		||||
        attempt_login(&opaque_handler, "bob", "wrong_password")
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap_err();
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
use super::handler::GroupId;
 | 
			
		||||
use super::handler::{GroupId, UserId};
 | 
			
		||||
use sea_query::*;
 | 
			
		||||
 | 
			
		||||
pub type Pool = sqlx::sqlite::SqlitePool;
 | 
			
		||||
@ -37,6 +37,43 @@ where
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<DB> sqlx::Type<DB> for UserId
 | 
			
		||||
where
 | 
			
		||||
    DB: sqlx::Database,
 | 
			
		||||
    String: sqlx::Type<DB>,
 | 
			
		||||
{
 | 
			
		||||
    fn type_info() -> <DB as sqlx::Database>::TypeInfo {
 | 
			
		||||
        <String as sqlx::Type<DB>>::type_info()
 | 
			
		||||
    }
 | 
			
		||||
    fn compatible(ty: &<DB as sqlx::Database>::TypeInfo) -> bool {
 | 
			
		||||
        <String as sqlx::Type<DB>>::compatible(ty)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'r, DB> sqlx::Decode<'r, DB> for UserId
 | 
			
		||||
where
 | 
			
		||||
    DB: sqlx::Database,
 | 
			
		||||
    String: sqlx::Decode<'r, DB>,
 | 
			
		||||
{
 | 
			
		||||
    fn decode(
 | 
			
		||||
        value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
 | 
			
		||||
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send + 'static>> {
 | 
			
		||||
        <String as sqlx::Decode<'r, DB>>::decode(value).map(|s| UserId::new(&s))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<UserId> for sea_query::Value {
 | 
			
		||||
    fn from(user_id: UserId) -> Self {
 | 
			
		||||
        user_id.into_string().into()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<&UserId> for sea_query::Value {
 | 
			
		||||
    fn from(user_id: &UserId) -> Self {
 | 
			
		||||
        user_id.as_str().into()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Iden)]
 | 
			
		||||
pub enum Users {
 | 
			
		||||
    Table,
 | 
			
		||||
 | 
			
		||||
@ -25,7 +25,7 @@ use lldap_auth::{login, opaque, password_reset, registration, JWTClaims};
 | 
			
		||||
use crate::{
 | 
			
		||||
    domain::{
 | 
			
		||||
        error::DomainError,
 | 
			
		||||
        handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler},
 | 
			
		||||
        handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler, UserId},
 | 
			
		||||
        opaque_handler::OpaqueHandler,
 | 
			
		||||
    },
 | 
			
		||||
    infra::{
 | 
			
		||||
@ -51,7 +51,7 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<GroupIdAndName>)
 | 
			
		||||
    jwt::Token::new(header, claims).sign_with_key(key).unwrap()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpResponse> {
 | 
			
		||||
fn parse_refresh_token(token: &str) -> std::result::Result<(u64, UserId), HttpResponse> {
 | 
			
		||||
    match token.split_once('+') {
 | 
			
		||||
        None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
 | 
			
		||||
        Some((token, u)) => {
 | 
			
		||||
@ -60,12 +60,12 @@ fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpRe
 | 
			
		||||
                token.hash(&mut s);
 | 
			
		||||
                s.finish()
 | 
			
		||||
            };
 | 
			
		||||
            Ok((refresh_token_hash, u.to_string()))
 | 
			
		||||
            Ok((refresh_token_hash, UserId::new(u)))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, String), HttpResponse> {
 | 
			
		||||
fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, UserId), HttpResponse> {
 | 
			
		||||
    match (
 | 
			
		||||
        request.cookie("refresh_token"),
 | 
			
		||||
        request.headers().get("refresh-token"),
 | 
			
		||||
@ -134,14 +134,14 @@ where
 | 
			
		||||
{
 | 
			
		||||
    let user_id = match request.match_info().get("user_id") {
 | 
			
		||||
        None => return HttpResponse::BadRequest().body("Missing user ID"),
 | 
			
		||||
        Some(id) => id,
 | 
			
		||||
        Some(id) => UserId::new(id),
 | 
			
		||||
    };
 | 
			
		||||
    let token = match data.backend_handler.start_password_reset(user_id).await {
 | 
			
		||||
    let token = match data.backend_handler.start_password_reset(&user_id).await {
 | 
			
		||||
        Err(e) => return HttpResponse::InternalServerError().body(e.to_string()),
 | 
			
		||||
        Ok(None) => return HttpResponse::Ok().finish(),
 | 
			
		||||
        Ok(Some(token)) => token,
 | 
			
		||||
    };
 | 
			
		||||
    let user = match data.backend_handler.get_user_details(user_id).await {
 | 
			
		||||
    let user = match data.backend_handler.get_user_details(&user_id).await {
 | 
			
		||||
        Err(e) => {
 | 
			
		||||
            warn!("Error getting used details: {:#?}", e);
 | 
			
		||||
            return HttpResponse::Ok().finish();
 | 
			
		||||
@ -196,7 +196,7 @@ where
 | 
			
		||||
                .finish(),
 | 
			
		||||
        )
 | 
			
		||||
        .json(&password_reset::ServerPasswordResetResponse {
 | 
			
		||||
            user_id,
 | 
			
		||||
            user_id: user_id.to_string(),
 | 
			
		||||
            token: token.as_str().to_owned(),
 | 
			
		||||
        })
 | 
			
		||||
}
 | 
			
		||||
@ -276,7 +276,7 @@ where
 | 
			
		||||
 | 
			
		||||
async fn get_login_successful_response<Backend>(
 | 
			
		||||
    data: &web::Data<AppState<Backend>>,
 | 
			
		||||
    name: &str,
 | 
			
		||||
    name: &UserId,
 | 
			
		||||
) -> HttpResponse
 | 
			
		||||
where
 | 
			
		||||
    Backend: TcpBackendHandler + BackendHandler,
 | 
			
		||||
@ -289,7 +289,7 @@ where
 | 
			
		||||
        .await
 | 
			
		||||
        .map(|(groups, (refresh_token, max_age))| {
 | 
			
		||||
            let token = create_jwt(&data.jwt_key, name.to_string(), groups);
 | 
			
		||||
            let refresh_token_plus_name = refresh_token + "+" + name;
 | 
			
		||||
            let refresh_token_plus_name = refresh_token + "+" + name.as_str();
 | 
			
		||||
 | 
			
		||||
            HttpResponse::Ok()
 | 
			
		||||
                .cookie(
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,7 @@
 | 
			
		||||
use crate::infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts};
 | 
			
		||||
use crate::{
 | 
			
		||||
    domain::handler::UserId,
 | 
			
		||||
    infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts},
 | 
			
		||||
};
 | 
			
		||||
use anyhow::{Context, Result};
 | 
			
		||||
use figment::{
 | 
			
		||||
    providers::{Env, Format, Serialized, Toml},
 | 
			
		||||
@ -49,8 +52,8 @@ pub struct Configuration {
 | 
			
		||||
    pub jwt_secret: SecUtf8,
 | 
			
		||||
    #[builder(default = r#"String::from("dc=example,dc=com")"#)]
 | 
			
		||||
    pub ldap_base_dn: String,
 | 
			
		||||
    #[builder(default = r#"String::from("admin")"#)]
 | 
			
		||||
    pub ldap_user_dn: String,
 | 
			
		||||
    #[builder(default = r#"UserId::new("admin")"#)]
 | 
			
		||||
    pub ldap_user_dn: UserId,
 | 
			
		||||
    #[builder(default = r#"SecUtf8::from("password")"#)]
 | 
			
		||||
    pub ldap_user_pass: SecUtf8,
 | 
			
		||||
    #[builder(default = r#"String::from("sqlite://users.db?mode=rwc")"#)]
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,5 @@
 | 
			
		||||
use crate::domain::handler::{
 | 
			
		||||
    BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest,
 | 
			
		||||
    BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, UserId,
 | 
			
		||||
};
 | 
			
		||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject};
 | 
			
		||||
 | 
			
		||||
@ -66,10 +66,11 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        if !context.validation_result.is_admin {
 | 
			
		||||
            return Err("Unauthorized user creation".into());
 | 
			
		||||
        }
 | 
			
		||||
        let user_id = UserId::new(&user.id);
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .create_user(CreateUserRequest {
 | 
			
		||||
                user_id: user.id.clone(),
 | 
			
		||||
                user_id: user_id.clone(),
 | 
			
		||||
                email: user.email,
 | 
			
		||||
                display_name: user.display_name,
 | 
			
		||||
                first_name: user.first_name,
 | 
			
		||||
@ -78,7 +79,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_user_details(&user.id)
 | 
			
		||||
            .get_user_details(&user_id)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(Into::into)?)
 | 
			
		||||
    }
 | 
			
		||||
@ -108,7 +109,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .update_user(UpdateUserRequest {
 | 
			
		||||
                user_id: user.id,
 | 
			
		||||
                user_id: UserId::new(&user.id),
 | 
			
		||||
                email: user.email,
 | 
			
		||||
                display_name: user.display_name,
 | 
			
		||||
                first_name: user.first_name,
 | 
			
		||||
@ -148,7 +149,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        }
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .add_user_to_group(&user_id, GroupId(group_id))
 | 
			
		||||
            .add_user_to_group(&UserId::new(&user_id), GroupId(group_id))
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
@ -166,7 +167,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        }
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .remove_user_from_group(&user_id, GroupId(group_id))
 | 
			
		||||
            .remove_user_from_group(&UserId::new(&user_id), GroupId(group_id))
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
@ -178,7 +179,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        if context.validation_result.user == user_id {
 | 
			
		||||
            return Err("Cannot delete current user".into());
 | 
			
		||||
        }
 | 
			
		||||
        context.handler.delete_user(&user_id).await?;
 | 
			
		||||
        context.handler.delete_user(&UserId::new(&user_id)).await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName};
 | 
			
		||||
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName, UserId};
 | 
			
		||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
 | 
			
		||||
@ -48,6 +48,9 @@ impl TryInto<DomainRequestFilter> for RequestFilter {
 | 
			
		||||
            return Err("Multiple fields specified in request filter".to_string());
 | 
			
		||||
        }
 | 
			
		||||
        if let Some(e) = self.eq {
 | 
			
		||||
            if e.field.to_lowercase() == "uid" {
 | 
			
		||||
                return Ok(DomainRequestFilter::UserId(UserId::new(&e.value)));
 | 
			
		||||
            }
 | 
			
		||||
            return Ok(DomainRequestFilter::Equality(e.field, e.value));
 | 
			
		||||
        }
 | 
			
		||||
        if let Some(c) = self.any {
 | 
			
		||||
@ -109,7 +112,7 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
 | 
			
		||||
        }
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_user_details(&user_id)
 | 
			
		||||
            .get_user_details(&UserId::new(&user_id))
 | 
			
		||||
            .await
 | 
			
		||||
            .map(Into::into)?)
 | 
			
		||||
    }
 | 
			
		||||
@ -170,7 +173,7 @@ impl<Handler: BackendHandler> Default for User<Handler> {
 | 
			
		||||
#[graphql_object(context = Context<Handler>)]
 | 
			
		||||
impl<Handler: BackendHandler + Sync> User<Handler> {
 | 
			
		||||
    fn id(&self) -> &str {
 | 
			
		||||
        &self.user.user_id
 | 
			
		||||
        self.user.user_id.as_str()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn email(&self) -> &str {
 | 
			
		||||
@ -260,7 +263,7 @@ impl<Handler: BackendHandler> From<DomainGroup> for Group<Handler> {
 | 
			
		||||
        Self {
 | 
			
		||||
            group_id: group.id.0,
 | 
			
		||||
            display_name: group.display_name,
 | 
			
		||||
            members: Some(group.users.into_iter().map(Into::into).collect()),
 | 
			
		||||
            members: Some(group.users.into_iter().map(UserId::into_string).collect()),
 | 
			
		||||
            _phantom: std::marker::PhantomData,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@ -305,10 +308,10 @@ mod tests {
 | 
			
		||||
 | 
			
		||||
        let mut mock = MockTestBackendHandler::new();
 | 
			
		||||
        mock.expect_get_user_details()
 | 
			
		||||
            .with(eq("bob"))
 | 
			
		||||
            .with(eq(UserId::new("bob")))
 | 
			
		||||
            .return_once(|_| {
 | 
			
		||||
                Ok(DomainUser {
 | 
			
		||||
                    user_id: "bob".to_string(),
 | 
			
		||||
                    user_id: UserId::new("bob"),
 | 
			
		||||
                    email: "bob@bobbers.on".to_string(),
 | 
			
		||||
                    ..Default::default()
 | 
			
		||||
                })
 | 
			
		||||
@ -316,7 +319,7 @@ mod tests {
 | 
			
		||||
        let mut groups = HashSet::new();
 | 
			
		||||
        groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string()));
 | 
			
		||||
        mock.expect_get_user_groups()
 | 
			
		||||
            .with(eq("bob"))
 | 
			
		||||
            .with(eq(UserId::new("bob")))
 | 
			
		||||
            .return_once(|_| Ok(groups));
 | 
			
		||||
 | 
			
		||||
        let context = Context::<MockTestBackendHandler> {
 | 
			
		||||
@ -369,12 +372,12 @@ mod tests {
 | 
			
		||||
            .return_once(|_| {
 | 
			
		||||
                Ok(vec![
 | 
			
		||||
                    DomainUser {
 | 
			
		||||
                        user_id: "bob".to_string(),
 | 
			
		||||
                        user_id: UserId::new("bob"),
 | 
			
		||||
                        email: "bob@bobbers.on".to_string(),
 | 
			
		||||
                        ..Default::default()
 | 
			
		||||
                    },
 | 
			
		||||
                    DomainUser {
 | 
			
		||||
                        user_id: "robert".to_string(),
 | 
			
		||||
                        user_id: UserId::new("robert"),
 | 
			
		||||
                        email: "robert@bobbers.on".to_string(),
 | 
			
		||||
                        ..Default::default()
 | 
			
		||||
                    },
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
use crate::domain::{
 | 
			
		||||
    handler::{
 | 
			
		||||
        BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User,
 | 
			
		||||
        BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, UserId,
 | 
			
		||||
        UserRequestFilter,
 | 
			
		||||
    },
 | 
			
		||||
    opaque_handler::OpaqueHandler,
 | 
			
		||||
@ -71,7 +71,7 @@ fn get_user_id_from_distinguished_name(
 | 
			
		||||
    dn: &str,
 | 
			
		||||
    base_tree: &[(String, String)],
 | 
			
		||||
    base_dn_str: &str,
 | 
			
		||||
) -> Result<String> {
 | 
			
		||||
) -> Result<UserId> {
 | 
			
		||||
    let parts = parse_distinguished_name(dn).context("while parsing a user ID")?;
 | 
			
		||||
    if !is_subtree(&parts, base_tree) {
 | 
			
		||||
        bail!("Not a subtree of the base tree");
 | 
			
		||||
@ -84,7 +84,7 @@ fn get_user_id_from_distinguished_name(
 | 
			
		||||
                base_dn_str
 | 
			
		||||
            );
 | 
			
		||||
        }
 | 
			
		||||
        Ok(parts[0].1.to_string())
 | 
			
		||||
        Ok(UserId::new(&parts[0].1))
 | 
			
		||||
    } else {
 | 
			
		||||
        bail!(
 | 
			
		||||
            r#"Unexpected user DN format. Got "{}", expected: "cn=username,ou=people,{}""#,
 | 
			
		||||
@ -103,7 +103,7 @@ fn get_user_attribute(user: &User, attribute: &str, dn: &str) -> Result<Vec<Stri
 | 
			
		||||
            "person".to_string(),
 | 
			
		||||
        ]),
 | 
			
		||||
        "dn" => Ok(vec![dn.to_string()]),
 | 
			
		||||
        "uid" => Ok(vec![user.user_id.clone()]),
 | 
			
		||||
        "uid" => Ok(vec![user.user_id.to_string()]),
 | 
			
		||||
        "mail" => Ok(vec![user.email.clone()]),
 | 
			
		||||
        "givenname" => Ok(vec![user.first_name.clone()]),
 | 
			
		||||
        "sn" => Ok(vec![user.last_name.clone()]),
 | 
			
		||||
@ -118,7 +118,7 @@ fn make_ldap_search_user_result_entry(
 | 
			
		||||
    base_dn_str: &str,
 | 
			
		||||
    attributes: &[String],
 | 
			
		||||
) -> Result<LdapSearchResultEntry> {
 | 
			
		||||
    let dn = format!("cn={},ou=people,{}", user.user_id, base_dn_str);
 | 
			
		||||
    let dn = format!("cn={},ou=people,{}", user.user_id.as_str(), base_dn_str);
 | 
			
		||||
    Ok(LdapSearchResultEntry {
 | 
			
		||||
        dn: dn.clone(),
 | 
			
		||||
        attributes: attributes
 | 
			
		||||
@ -264,17 +264,17 @@ fn root_dse_response(base_dn: &str) -> LdapOp {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct LdapHandler<Backend: BackendHandler + LoginHandler + OpaqueHandler> {
 | 
			
		||||
    dn: String,
 | 
			
		||||
    dn: UserId,
 | 
			
		||||
    backend_handler: Backend,
 | 
			
		||||
    pub base_dn: Vec<(String, String)>,
 | 
			
		||||
    base_dn_str: String,
 | 
			
		||||
    ldap_user_dn: String,
 | 
			
		||||
    ldap_user_dn: UserId,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend> {
 | 
			
		||||
    pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
 | 
			
		||||
    pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: UserId) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            dn: "Unauthenticated".to_string(),
 | 
			
		||||
            dn: UserId::new("unauthenticated"),
 | 
			
		||||
            backend_handler,
 | 
			
		||||
            base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
 | 
			
		||||
                panic!(
 | 
			
		||||
@ -282,7 +282,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
                    ldap_base_dn
 | 
			
		||||
                )
 | 
			
		||||
            }),
 | 
			
		||||
            ldap_user_dn: format!("cn={},ou=people,{}", ldap_user_dn, &ldap_base_dn),
 | 
			
		||||
            ldap_user_dn: UserId::new(&format!("cn={},ou=people,{}", ldap_user_dn, &ldap_base_dn)),
 | 
			
		||||
            base_dn_str: ldap_base_dn,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@ -307,14 +307,14 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
            .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(()) => {
 | 
			
		||||
                self.dn = request.dn.clone();
 | 
			
		||||
                self.dn = UserId::new(&request.dn);
 | 
			
		||||
                (LdapResultCode::Success, "".to_string())
 | 
			
		||||
            }
 | 
			
		||||
            Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn change_password(&mut self, user: &str, password: &str) -> Result<()> {
 | 
			
		||||
    async fn change_password(&mut self, user: &UserId, password: &str) -> Result<()> {
 | 
			
		||||
        use lldap_auth::*;
 | 
			
		||||
        let mut rng = rand::rngs::OsRng;
 | 
			
		||||
        let registration_start_request =
 | 
			
		||||
@ -527,7 +527,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
            }
 | 
			
		||||
            LdapOp::SearchRequest(request) => self.do_search(&request).await,
 | 
			
		||||
            LdapOp::UnbindRequest => {
 | 
			
		||||
                self.dn = "Unauthenticated".to_string();
 | 
			
		||||
                self.dn = UserId::new("unauthenticated");
 | 
			
		||||
                // No need to notify on unbind (per rfc4511)
 | 
			
		||||
                return None;
 | 
			
		||||
            }
 | 
			
		||||
@ -617,10 +617,12 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
                        ))))
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    Ok(UserRequestFilter::Equality(
 | 
			
		||||
                        map_field(field)?,
 | 
			
		||||
                        value.clone(),
 | 
			
		||||
                    ))
 | 
			
		||||
                    let field = map_field(field)?;
 | 
			
		||||
                    if field == "user_id" {
 | 
			
		||||
                        Ok(UserRequestFilter::UserId(UserId::new(value)))
 | 
			
		||||
                    } else {
 | 
			
		||||
                        Ok(UserRequestFilter::Equality(field, value.clone()))
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            LdapFilter::Present(field) => {
 | 
			
		||||
@ -661,17 +663,17 @@ mod tests {
 | 
			
		||||
        impl BackendHandler for TestBackendHandler {
 | 
			
		||||
            async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
 | 
			
		||||
            async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
 | 
			
		||||
            async fn get_user_details(&self, user_id: &str) -> Result<User>;
 | 
			
		||||
            async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
 | 
			
		||||
            async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
 | 
			
		||||
            async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
            async fn get_user_groups(&self, user: &UserId) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
            async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
 | 
			
		||||
            async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
 | 
			
		||||
            async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
 | 
			
		||||
            async fn delete_user(&self, user_id: &str) -> Result<()>;
 | 
			
		||||
            async fn delete_user(&self, user_id: &UserId) -> Result<()>;
 | 
			
		||||
            async fn create_group(&self, group_name: &str) -> Result<GroupId>;
 | 
			
		||||
            async fn delete_group(&self, group_id: GroupId) -> Result<()>;
 | 
			
		||||
            async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
            async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
            async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
            async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        }
 | 
			
		||||
        #[async_trait]
 | 
			
		||||
        impl OpaqueHandler for TestBackendHandler {
 | 
			
		||||
@ -679,7 +681,7 @@ mod tests {
 | 
			
		||||
                &self,
 | 
			
		||||
                request: login::ClientLoginStartRequest
 | 
			
		||||
            ) -> Result<login::ServerLoginStartResponse>;
 | 
			
		||||
            async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
 | 
			
		||||
            async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId>;
 | 
			
		||||
            async fn registration_start(
 | 
			
		||||
                &self,
 | 
			
		||||
                request: registration::ClientRegistrationStartRequest
 | 
			
		||||
@ -720,12 +722,12 @@ mod tests {
 | 
			
		||||
    ) -> LdapHandler<MockTestBackendHandler> {
 | 
			
		||||
        mock.expect_bind()
 | 
			
		||||
            .with(eq(BindRequest {
 | 
			
		||||
                name: "test".to_string(),
 | 
			
		||||
                name: UserId::new("test"),
 | 
			
		||||
                password: "pass".to_string(),
 | 
			
		||||
            }))
 | 
			
		||||
            .return_once(|_| Ok(()));
 | 
			
		||||
        let mut ldap_handler =
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
 | 
			
		||||
        let request = LdapBindRequest {
 | 
			
		||||
            dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
            cred: LdapBindCred::Simple("pass".to_string()),
 | 
			
		||||
@ -742,13 +744,13 @@ mod tests {
 | 
			
		||||
        let mut mock = MockTestBackendHandler::new();
 | 
			
		||||
        mock.expect_bind()
 | 
			
		||||
            .with(eq(crate::domain::handler::BindRequest {
 | 
			
		||||
                name: "bob".to_string(),
 | 
			
		||||
                name: UserId::new("bob"),
 | 
			
		||||
                password: "pass".to_string(),
 | 
			
		||||
            }))
 | 
			
		||||
            .times(1)
 | 
			
		||||
            .return_once(|_| Ok(()));
 | 
			
		||||
        let mut ldap_handler =
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
 | 
			
		||||
 | 
			
		||||
        let request = LdapOp::BindRequest(LdapBindRequest {
 | 
			
		||||
            dn: "cn=bob,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -773,13 +775,13 @@ mod tests {
 | 
			
		||||
        let mut mock = MockTestBackendHandler::new();
 | 
			
		||||
        mock.expect_bind()
 | 
			
		||||
            .with(eq(crate::domain::handler::BindRequest {
 | 
			
		||||
                name: "test".to_string(),
 | 
			
		||||
                name: UserId::new("test"),
 | 
			
		||||
                password: "pass".to_string(),
 | 
			
		||||
            }))
 | 
			
		||||
            .times(1)
 | 
			
		||||
            .return_once(|_| Ok(()));
 | 
			
		||||
        let mut ldap_handler =
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
 | 
			
		||||
 | 
			
		||||
        let request = LdapBindRequest {
 | 
			
		||||
            dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -796,13 +798,13 @@ mod tests {
 | 
			
		||||
        let mut mock = MockTestBackendHandler::new();
 | 
			
		||||
        mock.expect_bind()
 | 
			
		||||
            .with(eq(crate::domain::handler::BindRequest {
 | 
			
		||||
                name: "test".to_string(),
 | 
			
		||||
                name: UserId::new("test"),
 | 
			
		||||
                password: "pass".to_string(),
 | 
			
		||||
            }))
 | 
			
		||||
            .times(1)
 | 
			
		||||
            .return_once(|_| Ok(()));
 | 
			
		||||
        let mut ldap_handler =
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin"));
 | 
			
		||||
 | 
			
		||||
        let request = LdapBindRequest {
 | 
			
		||||
            dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -827,7 +829,7 @@ mod tests {
 | 
			
		||||
    async fn test_bind_invalid_dn() {
 | 
			
		||||
        let mock = MockTestBackendHandler::new();
 | 
			
		||||
        let mut ldap_handler =
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
 | 
			
		||||
            LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin"));
 | 
			
		||||
 | 
			
		||||
        let request = LdapBindRequest {
 | 
			
		||||
            dn: "cn=bob,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -903,7 +905,7 @@ mod tests {
 | 
			
		||||
        mock.expect_list_users().times(1).return_once(|_| {
 | 
			
		||||
            Ok(vec![
 | 
			
		||||
                User {
 | 
			
		||||
                    user_id: "bob_1".to_string(),
 | 
			
		||||
                    user_id: UserId::new("bob_1"),
 | 
			
		||||
                    email: "bob@bobmail.bob".to_string(),
 | 
			
		||||
                    display_name: "Bôb Böbberson".to_string(),
 | 
			
		||||
                    first_name: "Bôb".to_string(),
 | 
			
		||||
@ -911,7 +913,7 @@ mod tests {
 | 
			
		||||
                    ..Default::default()
 | 
			
		||||
                },
 | 
			
		||||
                User {
 | 
			
		||||
                    user_id: "jim".to_string(),
 | 
			
		||||
                    user_id: UserId::new("jim"),
 | 
			
		||||
                    email: "jim@cricket.jim".to_string(),
 | 
			
		||||
                    display_name: "Jimminy Cricket".to_string(),
 | 
			
		||||
                    first_name: "Jim".to_string(),
 | 
			
		||||
@ -1037,12 +1039,12 @@ mod tests {
 | 
			
		||||
                    Group {
 | 
			
		||||
                        id: GroupId(1),
 | 
			
		||||
                        display_name: "group_1".to_string(),
 | 
			
		||||
                        users: vec!["bob".to_string(), "john".to_string()],
 | 
			
		||||
                        users: vec![UserId::new("bob"), UserId::new("john")],
 | 
			
		||||
                    },
 | 
			
		||||
                    Group {
 | 
			
		||||
                        id: GroupId(3),
 | 
			
		||||
                        display_name: "bestgroup".to_string(),
 | 
			
		||||
                        users: vec!["john".to_string()],
 | 
			
		||||
                        users: vec![UserId::new("john")],
 | 
			
		||||
                    },
 | 
			
		||||
                ])
 | 
			
		||||
            });
 | 
			
		||||
@ -1111,7 +1113,7 @@ mod tests {
 | 
			
		||||
        mock.expect_list_groups()
 | 
			
		||||
            .with(eq(Some(GroupRequestFilter::And(vec![
 | 
			
		||||
                GroupRequestFilter::DisplayName("group_1".to_string()),
 | 
			
		||||
                GroupRequestFilter::Member("bob".to_string()),
 | 
			
		||||
                GroupRequestFilter::Member(UserId::new("bob")),
 | 
			
		||||
                GroupRequestFilter::And(vec![]),
 | 
			
		||||
            ]))))
 | 
			
		||||
            .times(1)
 | 
			
		||||
@ -1250,10 +1252,7 @@ mod tests {
 | 
			
		||||
        mock.expect_list_users()
 | 
			
		||||
            .with(eq(Some(UserRequestFilter::And(vec![
 | 
			
		||||
                UserRequestFilter::Or(vec![
 | 
			
		||||
                    UserRequestFilter::Not(Box::new(UserRequestFilter::Equality(
 | 
			
		||||
                        "user_id".to_string(),
 | 
			
		||||
                        "bob".to_string(),
 | 
			
		||||
                    ))),
 | 
			
		||||
                    UserRequestFilter::Not(Box::new(UserRequestFilter::UserId(UserId::new("bob")))),
 | 
			
		||||
                    UserRequestFilter::And(vec![]),
 | 
			
		||||
                    UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))),
 | 
			
		||||
                    UserRequestFilter::And(vec![]),
 | 
			
		||||
@ -1342,7 +1341,7 @@ mod tests {
 | 
			
		||||
            .times(1)
 | 
			
		||||
            .return_once(|_| {
 | 
			
		||||
                Ok(vec![User {
 | 
			
		||||
                    user_id: "bob_1".to_string(),
 | 
			
		||||
                    user_id: UserId::new("bob_1"),
 | 
			
		||||
                    ..Default::default()
 | 
			
		||||
                }])
 | 
			
		||||
            });
 | 
			
		||||
@ -1378,7 +1377,7 @@ mod tests {
 | 
			
		||||
        let mut mock = MockTestBackendHandler::new();
 | 
			
		||||
        mock.expect_list_users().times(1).return_once(|_| {
 | 
			
		||||
            Ok(vec![User {
 | 
			
		||||
                user_id: "bob_1".to_string(),
 | 
			
		||||
                user_id: UserId::new("bob_1"),
 | 
			
		||||
                email: "bob@bobmail.bob".to_string(),
 | 
			
		||||
                display_name: "Bôb Böbberson".to_string(),
 | 
			
		||||
                first_name: "Bôb".to_string(),
 | 
			
		||||
@ -1393,7 +1392,7 @@ mod tests {
 | 
			
		||||
                Ok(vec![Group {
 | 
			
		||||
                    id: GroupId(1),
 | 
			
		||||
                    display_name: "group_1".to_string(),
 | 
			
		||||
                    users: vec!["bob".to_string(), "john".to_string()],
 | 
			
		||||
                    users: vec![UserId::new("bob"), UserId::new("john")],
 | 
			
		||||
                }])
 | 
			
		||||
            });
 | 
			
		||||
        let mut ldap_handler = setup_bound_handler(mock).await;
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,5 @@
 | 
			
		||||
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
 | 
			
		||||
use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler};
 | 
			
		||||
use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler};
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use futures_util::StreamExt;
 | 
			
		||||
use sea_query::{Expr, Iden, Query, SimpleExpr};
 | 
			
		||||
@ -34,7 +34,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
            .map_err(|e| anyhow::anyhow!(e))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
 | 
			
		||||
    async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
 | 
			
		||||
        use std::collections::hash_map::DefaultHasher;
 | 
			
		||||
        use std::hash::{Hash, Hasher};
 | 
			
		||||
        // TODO: Initialize the rng only once. Maybe Arc<Cell>?
 | 
			
		||||
@ -62,7 +62,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok((refresh_token, duration))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
 | 
			
		||||
    async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
 | 
			
		||||
        let query = Query::select()
 | 
			
		||||
            .expr(SimpleExpr::Value(1.into()))
 | 
			
		||||
            .from(JwtRefreshStorage::Table)
 | 
			
		||||
@ -74,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
            .await?
 | 
			
		||||
            .is_some())
 | 
			
		||||
    }
 | 
			
		||||
    async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>> {
 | 
			
		||||
    async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
 | 
			
		||||
        use sqlx::Result;
 | 
			
		||||
        let query = Query::select()
 | 
			
		||||
            .column(JwtStorage::JwtHash)
 | 
			
		||||
@ -106,7 +106,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn start_password_reset(&self, user: &str) -> Result<Option<String>> {
 | 
			
		||||
    async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
 | 
			
		||||
        let query = Query::select()
 | 
			
		||||
            .column(Users::UserId)
 | 
			
		||||
            .from(Users::Table)
 | 
			
		||||
@ -138,7 +138,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(Some(token))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String> {
 | 
			
		||||
    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
 | 
			
		||||
        let query = Query::select()
 | 
			
		||||
            .column(PasswordResetTokens::UserId)
 | 
			
		||||
            .from(PasswordResetTokens::Table)
 | 
			
		||||
 | 
			
		||||
@ -1,22 +1,22 @@
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use std::collections::HashSet;
 | 
			
		||||
 | 
			
		||||
use crate::domain::error::Result;
 | 
			
		||||
use crate::domain::{error::Result, handler::UserId};
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait TcpBackendHandler {
 | 
			
		||||
    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
			
		||||
    async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
 | 
			
		||||
    async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
 | 
			
		||||
    async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
 | 
			
		||||
    async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>;
 | 
			
		||||
    async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool>;
 | 
			
		||||
    async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>>;
 | 
			
		||||
    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
 | 
			
		||||
 | 
			
		||||
    /// Request a token to reset a user's password.
 | 
			
		||||
    /// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`.
 | 
			
		||||
    async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
 | 
			
		||||
    async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>>;
 | 
			
		||||
 | 
			
		||||
    /// Get the user ID associated with a password reset token.
 | 
			
		||||
    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
 | 
			
		||||
    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId>;
 | 
			
		||||
 | 
			
		||||
    async fn delete_password_reset_token(&self, token: &str) -> Result<()>;
 | 
			
		||||
}
 | 
			
		||||
@ -37,27 +37,27 @@ mockall::mock! {
 | 
			
		||||
    impl BackendHandler for TestTcpBackendHandler {
 | 
			
		||||
        async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
 | 
			
		||||
        async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
 | 
			
		||||
        async fn get_user_details(&self, user_id: &str) -> Result<User>;
 | 
			
		||||
        async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
 | 
			
		||||
        async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
 | 
			
		||||
        async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
        async fn get_user_groups(&self, user: &UserId) -> Result<HashSet<GroupIdAndName>>;
 | 
			
		||||
        async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
 | 
			
		||||
        async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
 | 
			
		||||
        async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
 | 
			
		||||
        async fn delete_user(&self, user_id: &str) -> Result<()>;
 | 
			
		||||
        async fn delete_user(&self, user_id: &UserId) -> Result<()>;
 | 
			
		||||
        async fn create_group(&self, group_name: &str) -> Result<GroupId>;
 | 
			
		||||
        async fn delete_group(&self, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
        async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
 | 
			
		||||
    }
 | 
			
		||||
    #[async_trait]
 | 
			
		||||
    impl TcpBackendHandler for TestTcpBackendHandler {
 | 
			
		||||
        async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
			
		||||
        async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
 | 
			
		||||
        async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
 | 
			
		||||
        async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
 | 
			
		||||
        async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>;
 | 
			
		||||
        async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool>;
 | 
			
		||||
        async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>>;
 | 
			
		||||
        async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
 | 
			
		||||
        async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
 | 
			
		||||
        async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
 | 
			
		||||
        async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>>;
 | 
			
		||||
        async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId>;
 | 
			
		||||
        async fn delete_password_reset_token(&self, token: &str) -> Result<()>;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user