mirror of
				https://github.com/nitnelave/lldap.git
				synced 2023-04-12 14:25:13 +00:00 
			
		
		
		
	
							parent
							
								
									82f6292927
								
							
						
					
					
						commit
						539dd46ec8
					
				
							
								
								
									
										50
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										50
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							@ -1998,6 +1998,8 @@ dependencies = [
 | 
			
		||||
 "tokio-util",
 | 
			
		||||
 "tracing",
 | 
			
		||||
 "tracing-actix-web",
 | 
			
		||||
 "tracing-attributes",
 | 
			
		||||
 "tracing-forest",
 | 
			
		||||
 "tracing-log",
 | 
			
		||||
 "tracing-subscriber",
 | 
			
		||||
]
 | 
			
		||||
@ -2092,6 +2094,15 @@ version = "0.1.0"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "matchers"
 | 
			
		||||
version = "0.1.0"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "regex-automata",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "matches"
 | 
			
		||||
version = "0.1.9"
 | 
			
		||||
@ -2838,6 +2849,15 @@ dependencies = [
 | 
			
		||||
 "regex-syntax",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "regex-automata"
 | 
			
		||||
version = "0.1.10"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "regex-syntax",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "regex-syntax"
 | 
			
		||||
version = "0.6.25"
 | 
			
		||||
@ -3546,18 +3566,18 @@ checksum = "b1141d4d61095b28419e22cb0bbf02755f5e54e0526f97f1e3d1d160e60885fb"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "thiserror"
 | 
			
		||||
version = "1.0.30"
 | 
			
		||||
version = "1.0.31"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
 | 
			
		||||
checksum = "bd829fe32373d27f76265620b5309d0340cb8550f523c1dda251d6298069069a"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "thiserror-impl",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "thiserror-impl"
 | 
			
		||||
version = "1.0.30"
 | 
			
		||||
version = "1.0.31"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
 | 
			
		||||
checksum = "0396bc89e626244658bef819e22d0cc459e795a5ebe878e6ec336d1674a8d79a"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "proc-macro2",
 | 
			
		||||
 "quote",
 | 
			
		||||
@ -3744,9 +3764,9 @@ dependencies = [
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "tracing-attributes"
 | 
			
		||||
version = "0.1.15"
 | 
			
		||||
version = "0.1.21"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "c42e6fa53307c8a17e4ccd4dc81cf5ec38db9209f59b222210375b54ee40d1e2"
 | 
			
		||||
checksum = "cc6b8ad3567499f98a1db7a752b07a7c8c7c7c34c332ec00effb2b0027974b7c"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "proc-macro2",
 | 
			
		||||
 "quote",
 | 
			
		||||
@ -3762,6 +3782,20 @@ dependencies = [
 | 
			
		||||
 "lazy_static",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "tracing-forest"
 | 
			
		||||
version = "0.1.4"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "5db74d83f3fcda3ca1355dd91294098df02cc03d54e6cce81e40a18671c3fd7a"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "chrono",
 | 
			
		||||
 "smallvec",
 | 
			
		||||
 "thiserror",
 | 
			
		||||
 "tokio",
 | 
			
		||||
 "tracing",
 | 
			
		||||
 "tracing-subscriber",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "tracing-futures"
 | 
			
		||||
version = "0.2.5"
 | 
			
		||||
@ -3790,9 +3824,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "80a4ddde70311d8da398062ecf6fc2c309337de6b0f77d6c27aff8d53f6fca52"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "ansi_term",
 | 
			
		||||
 "lazy_static",
 | 
			
		||||
 "matchers",
 | 
			
		||||
 "regex",
 | 
			
		||||
 "sharded-slab",
 | 
			
		||||
 "smallvec",
 | 
			
		||||
 "thread_local",
 | 
			
		||||
 "tracing",
 | 
			
		||||
 "tracing-core",
 | 
			
		||||
 "tracing-log",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
@ -41,10 +41,9 @@ tokio = { version = "1.13.1", features = ["full"] }
 | 
			
		||||
tokio-native-tls = "0.3"
 | 
			
		||||
tokio-util = "0.6.3"
 | 
			
		||||
tokio-stream = "*"
 | 
			
		||||
tracing = "*"
 | 
			
		||||
tracing-actix-web = "0.4.0-beta.7"
 | 
			
		||||
tracing-attributes = "^0.1.21"
 | 
			
		||||
tracing-log = "*"
 | 
			
		||||
tracing-subscriber = "0.3"
 | 
			
		||||
rand = { version = "0.8", features = ["small_rng", "getrandom"] }
 | 
			
		||||
juniper_actix = "0.4.0"
 | 
			
		||||
juniper = "0.15.6"
 | 
			
		||||
@ -53,6 +52,10 @@ itertools = "0.10.1"
 | 
			
		||||
[dependencies.opaque-ke]
 | 
			
		||||
version = "0.6"
 | 
			
		||||
 | 
			
		||||
[dependencies.tracing-subscriber]
 | 
			
		||||
version = "0.3"
 | 
			
		||||
features = ["env-filter", "tracing-log"]
 | 
			
		||||
 | 
			
		||||
[dependencies.lettre]
 | 
			
		||||
version = "0.10.0-rc.3"
 | 
			
		||||
features = [
 | 
			
		||||
@ -95,5 +98,12 @@ version = "*"
 | 
			
		||||
features = ["vendored"]
 | 
			
		||||
version = "*"
 | 
			
		||||
 | 
			
		||||
[dependencies.tracing-forest]
 | 
			
		||||
features = ["smallvec", "chrono", "tokio"]
 | 
			
		||||
version = "^0.1.4"
 | 
			
		||||
 | 
			
		||||
[dependencies.tracing]
 | 
			
		||||
version = "*"
 | 
			
		||||
 | 
			
		||||
[dev-dependencies]
 | 
			
		||||
mockall = "0.9.1"
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@ use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr};
 | 
			
		||||
use sea_query_binder::SqlxBinder;
 | 
			
		||||
use sqlx::{query_as_with, query_with, FromRow, Row};
 | 
			
		||||
use std::collections::HashSet;
 | 
			
		||||
use tracing::{debug, instrument};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct SqlBackendHandler {
 | 
			
		||||
@ -110,11 +111,13 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> SimpleExpr {
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret, err)]
 | 
			
		||||
    async fn list_users(
 | 
			
		||||
        &self,
 | 
			
		||||
        filters: Option<UserRequestFilter>,
 | 
			
		||||
        get_groups: bool,
 | 
			
		||||
    ) -> Result<Vec<UserAndGroups>> {
 | 
			
		||||
        debug!(?filters, get_groups);
 | 
			
		||||
        let (query, values) = {
 | 
			
		||||
            let mut query_builder = Query::select()
 | 
			
		||||
                .column((Users::Table, Users::UserId))
 | 
			
		||||
@ -167,7 +170,8 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
 | 
			
		||||
            query_builder.build_sqlx(DbQueryBuilder {})
 | 
			
		||||
        };
 | 
			
		||||
        log::error!("query: {}", &query);
 | 
			
		||||
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
 | 
			
		||||
        // For group_by.
 | 
			
		||||
        use itertools::Itertools;
 | 
			
		||||
@ -199,11 +203,12 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
                },
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(users)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret, err)]
 | 
			
		||||
    async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
 | 
			
		||||
        debug!(?filters);
 | 
			
		||||
        let (query, values) = {
 | 
			
		||||
            let mut query_builder = Query::select()
 | 
			
		||||
                .column((Groups::Table, Groups::GroupId))
 | 
			
		||||
@ -233,6 +238,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
 | 
			
		||||
            query_builder.build_sqlx(DbQueryBuilder {})
 | 
			
		||||
        };
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
 | 
			
		||||
        // For group_by.
 | 
			
		||||
        use itertools::Itertools;
 | 
			
		||||
@ -264,7 +270,9 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(groups)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret, err)]
 | 
			
		||||
    async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
 | 
			
		||||
        debug!(?user_id);
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column(Users::UserId)
 | 
			
		||||
            .column(Users::Email)
 | 
			
		||||
@ -276,19 +284,23 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            .from(Users::Table)
 | 
			
		||||
            .cond_where(Expr::col(Users::UserId).eq(user_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
 | 
			
		||||
        Ok(query_as_with::<_, User, _>(query.as_str(), values)
 | 
			
		||||
            .fetch_one(&self.sql_pool)
 | 
			
		||||
            .await?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret, err)]
 | 
			
		||||
    async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName> {
 | 
			
		||||
        debug!(?group_id);
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column(Groups::GroupId)
 | 
			
		||||
            .column(Groups::DisplayName)
 | 
			
		||||
            .from(Groups::Table)
 | 
			
		||||
            .cond_where(Expr::col(Groups::GroupId).eq(group_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
 | 
			
		||||
        Ok(
 | 
			
		||||
            query_as_with::<_, GroupIdAndName, _>(query.as_str(), values)
 | 
			
		||||
@ -297,12 +309,9 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret, err)]
 | 
			
		||||
    async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
 | 
			
		||||
        if *user_id == self.config.ldap_user_dn {
 | 
			
		||||
            let mut groups = HashSet::new();
 | 
			
		||||
            groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
 | 
			
		||||
            return Ok(groups);
 | 
			
		||||
        }
 | 
			
		||||
        debug!(?user_id);
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column((Groups::Table, Groups::GroupId))
 | 
			
		||||
            .column(Groups::DisplayName)
 | 
			
		||||
@ -314,6 +323,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            )
 | 
			
		||||
            .cond_where(Expr::col(Memberships::UserId).eq(user_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            // Extract the group id from the row.
 | 
			
		||||
@ -335,7 +345,9 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            .map_err(DomainError::DatabaseError)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
 | 
			
		||||
        debug!(user_id = ?request.user_id);
 | 
			
		||||
        let columns = vec![
 | 
			
		||||
            Users::UserId,
 | 
			
		||||
            Users::Email,
 | 
			
		||||
@ -356,13 +368,16 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
                chrono::Utc::now().naive_utc().into(),
 | 
			
		||||
            ])
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn update_user(&self, request: UpdateUserRequest) -> Result<()> {
 | 
			
		||||
        debug!(user_id = ?request.user_id);
 | 
			
		||||
        let mut values = Vec::new();
 | 
			
		||||
        if let Some(email) = request.email {
 | 
			
		||||
            values.push((Users::Email, email.into()));
 | 
			
		||||
@ -384,13 +399,16 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            .values(values)
 | 
			
		||||
            .cond_where(Expr::col(Users::UserId).eq(request.user_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> {
 | 
			
		||||
        debug!(?request.group_id);
 | 
			
		||||
        let mut values = Vec::new();
 | 
			
		||||
        if let Some(display_name) = request.display_name {
 | 
			
		||||
            values.push((Groups::DisplayName, display_name.into()));
 | 
			
		||||
@ -403,29 +421,36 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            .values(values)
 | 
			
		||||
            .cond_where(Expr::col(Groups::GroupId).eq(request.group_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn delete_user(&self, user_id: &UserId) -> Result<()> {
 | 
			
		||||
        let (delete_query, values) = Query::delete()
 | 
			
		||||
        debug!(?user_id);
 | 
			
		||||
        let (query, values) = Query::delete()
 | 
			
		||||
            .from_table(Users::Table)
 | 
			
		||||
            .cond_where(Expr::col(Users::UserId).eq(user_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        query_with(delete_query.as_str(), values)
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret, err)]
 | 
			
		||||
    async fn create_group(&self, group_name: &str) -> Result<GroupId> {
 | 
			
		||||
        debug!(?group_name);
 | 
			
		||||
        let (query, values) = Query::insert()
 | 
			
		||||
            .into_table(Groups::Table)
 | 
			
		||||
            .columns(vec![Groups::DisplayName])
 | 
			
		||||
            .values_panic(vec![group_name.into()])
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
@ -434,36 +459,45 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
            .from(Groups::Table)
 | 
			
		||||
            .cond_where(Expr::col(Groups::DisplayName).eq(group_name))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        let row = query_with(query.as_str(), values)
 | 
			
		||||
            .fetch_one(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn delete_group(&self, group_id: GroupId) -> Result<()> {
 | 
			
		||||
        let (delete_query, values) = Query::delete()
 | 
			
		||||
        debug!(?group_id);
 | 
			
		||||
        let (query, values) = Query::delete()
 | 
			
		||||
            .from_table(Groups::Table)
 | 
			
		||||
            .cond_where(Expr::col(Groups::GroupId).eq(group_id))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        query_with(delete_query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
 | 
			
		||||
        let (query, values) = Query::insert()
 | 
			
		||||
            .into_table(Memberships::Table)
 | 
			
		||||
            .columns(vec![Memberships::UserId, Memberships::GroupId])
 | 
			
		||||
            .values_panic(vec![user_id.into(), group_id.into()])
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
 | 
			
		||||
        debug!(?user_id, ?group_id);
 | 
			
		||||
        let (query, values) = Query::insert()
 | 
			
		||||
            .into_table(Memberships::Table)
 | 
			
		||||
            .columns(vec![Memberships::UserId, Memberships::GroupId])
 | 
			
		||||
            .values_panic(vec![user_id.into(), group_id.into()])
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
 | 
			
		||||
        debug!(?user_id, ?group_id);
 | 
			
		||||
        let (query, values) = Query::delete()
 | 
			
		||||
            .from_table(Memberships::Table)
 | 
			
		||||
            .cond_where(
 | 
			
		||||
@ -472,6 +506,7 @@ impl BackendHandler for SqlBackendHandler {
 | 
			
		||||
                    .add(Expr::col(Memberships::UserId).eq(user_id)),
 | 
			
		||||
            )
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(query.as_str(), values)
 | 
			
		||||
            .execute(&self.sql_pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
@ -7,14 +7,15 @@ use super::{
 | 
			
		||||
};
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use lldap_auth::opaque;
 | 
			
		||||
use log::*;
 | 
			
		||||
use sea_query::{Expr, Iden, Query};
 | 
			
		||||
use sea_query_binder::SqlxBinder;
 | 
			
		||||
use secstr::SecUtf8;
 | 
			
		||||
use sqlx::Row;
 | 
			
		||||
use tracing::{debug, instrument};
 | 
			
		||||
 | 
			
		||||
type SqlOpaqueHandler = SqlBackendHandler;
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
fn passwords_match(
 | 
			
		||||
    password_file_bytes: &[u8],
 | 
			
		||||
    clear_password: &str,
 | 
			
		||||
@ -48,6 +49,7 @@ impl SqlBackendHandler {
 | 
			
		||||
        )?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn get_password_file_for_user(
 | 
			
		||||
        &self,
 | 
			
		||||
        username: &str,
 | 
			
		||||
@ -86,6 +88,7 @@ impl SqlBackendHandler {
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl LoginHandler for SqlBackendHandler {
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn bind(&self, request: BindRequest) -> Result<()> {
 | 
			
		||||
        if request.name == self.config.ldap_user_dn {
 | 
			
		||||
            if SecUtf8::from(request.password) == self.config.ldap_user_pass {
 | 
			
		||||
@ -135,6 +138,7 @@ impl LoginHandler for SqlBackendHandler {
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn login_start(
 | 
			
		||||
        &self,
 | 
			
		||||
        request: login::ClientLoginStartRequest,
 | 
			
		||||
@ -163,6 +167,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
 | 
			
		||||
        let secret_key = self.get_orion_secret_key()?;
 | 
			
		||||
        let login::ServerData {
 | 
			
		||||
@ -181,6 +186,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
        Ok(UserId::new(&username))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn registration_start(
 | 
			
		||||
        &self,
 | 
			
		||||
        request: registration::ClientRegistrationStartRequest,
 | 
			
		||||
@ -202,6 +208,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
    async fn registration_finish(
 | 
			
		||||
        &self,
 | 
			
		||||
        request: registration::ClientRegistrationFinishRequest,
 | 
			
		||||
@ -230,6 +237,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Convenience function to set a user's password.
 | 
			
		||||
#[instrument(skip_all, level = "debug", err)]
 | 
			
		||||
pub(crate) async fn register_password(
 | 
			
		||||
    opaque_handler: &SqlOpaqueHandler,
 | 
			
		||||
    username: &UserId,
 | 
			
		||||
 | 
			
		||||
@ -16,9 +16,9 @@ use futures::future::{ok, Ready};
 | 
			
		||||
use futures_util::{FutureExt, TryFutureExt};
 | 
			
		||||
use hmac::Hmac;
 | 
			
		||||
use jwt::{SignWithKey, VerifyWithKey};
 | 
			
		||||
use log::*;
 | 
			
		||||
use sha2::Sha512;
 | 
			
		||||
use time::ext::NumericalDuration;
 | 
			
		||||
use tracing::{debug, instrument, warn};
 | 
			
		||||
 | 
			
		||||
use lldap_auth::{login, opaque, password_reset, registration, JWTClaims};
 | 
			
		||||
 | 
			
		||||
@ -76,6 +76,7 @@ fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, UserId),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn get_refresh<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: HttpRequest,
 | 
			
		||||
@ -95,15 +96,10 @@ where
 | 
			
		||||
        .await;
 | 
			
		||||
    // Async closures are not supported yet.
 | 
			
		||||
    match res_found {
 | 
			
		||||
        Ok(found) => {
 | 
			
		||||
            if found {
 | 
			
		||||
                backend_handler.get_user_groups(&user).await
 | 
			
		||||
            } else {
 | 
			
		||||
                Err(DomainError::AuthenticationError(
 | 
			
		||||
                    "Invalid refresh token".to_string(),
 | 
			
		||||
                ))
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(true) => backend_handler.get_user_groups(&user).await,
 | 
			
		||||
        Ok(false) => Err(DomainError::AuthenticationError(
 | 
			
		||||
            "Invalid refresh token".to_string(),
 | 
			
		||||
        )),
 | 
			
		||||
        Err(e) => Err(e),
 | 
			
		||||
    }
 | 
			
		||||
    .map(|groups| create_jwt(jwt_key, user.to_string(), groups))
 | 
			
		||||
@ -125,6 +121,7 @@ where
 | 
			
		||||
    .unwrap_or_else(error_to_http_response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn get_password_reset_step1<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: HttpRequest,
 | 
			
		||||
@ -161,6 +158,7 @@ where
 | 
			
		||||
    HttpResponse::Ok().finish()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn get_password_reset_step2<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: HttpRequest,
 | 
			
		||||
@ -202,6 +200,7 @@ where
 | 
			
		||||
        })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn get_logout<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: HttpRequest,
 | 
			
		||||
@ -261,6 +260,7 @@ pub(crate) fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
 | 
			
		||||
 | 
			
		||||
pub type ApiResult<M> = actix_web::Either<web::Json<M>, HttpResponse>;
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn opaque_login_start<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: web::Json<login::ClientLoginStartRequest>,
 | 
			
		||||
@ -275,6 +275,7 @@ where
 | 
			
		||||
        .unwrap_or_else(error_to_api_response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn get_login_successful_response<Backend>(
 | 
			
		||||
    data: &web::Data<AppState<Backend>>,
 | 
			
		||||
    name: &UserId,
 | 
			
		||||
@ -317,6 +318,7 @@ where
 | 
			
		||||
        .unwrap_or_else(error_to_http_response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn opaque_login_finish<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: web::Json<login::ClientLoginFinishRequest>,
 | 
			
		||||
@ -335,6 +337,7 @@ where
 | 
			
		||||
    get_login_successful_response(&data, &name).await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn simple_login<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: web::Json<login::ClientSimpleLoginRequest>,
 | 
			
		||||
@ -387,6 +390,7 @@ where
 | 
			
		||||
    get_login_successful_response(&data, &name).await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn post_authorize<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: web::Json<BindRequest>,
 | 
			
		||||
@ -395,12 +399,14 @@ where
 | 
			
		||||
    Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
 | 
			
		||||
{
 | 
			
		||||
    let name = request.name.clone();
 | 
			
		||||
    debug!(%name);
 | 
			
		||||
    if let Err(e) = data.backend_handler.bind(request.into_inner()).await {
 | 
			
		||||
        return error_to_http_response(e);
 | 
			
		||||
    }
 | 
			
		||||
    get_login_successful_response(&data, &name).await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn opaque_register_start<Backend>(
 | 
			
		||||
    request: actix_web::HttpRequest,
 | 
			
		||||
    mut payload: actix_web::web::Payload,
 | 
			
		||||
@ -450,6 +456,7 @@ where
 | 
			
		||||
        .unwrap_or_else(error_to_api_response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
async fn opaque_register_finish<Backend>(
 | 
			
		||||
    data: web::Data<AppState<Backend>>,
 | 
			
		||||
    request: web::Json<registration::ClientRegistrationFinishRequest>,
 | 
			
		||||
@ -530,6 +537,7 @@ pub enum Permission {
 | 
			
		||||
    Regular,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct ValidationResults {
 | 
			
		||||
    pub user: String,
 | 
			
		||||
    pub permission: Permission,
 | 
			
		||||
@ -567,6 +575,7 @@ impl ValidationResults {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "debug", err, ret)]
 | 
			
		||||
pub(crate) fn check_if_token_is_valid<Backend>(
 | 
			
		||||
    state: &AppState<Backend>,
 | 
			
		||||
    token_str: &str,
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,7 @@ use chrono::Local;
 | 
			
		||||
use cron::Schedule;
 | 
			
		||||
use sea_query::{Expr, Query};
 | 
			
		||||
use std::{str::FromStr, time::Duration};
 | 
			
		||||
use tracing::{debug, error, info, instrument};
 | 
			
		||||
 | 
			
		||||
// Define actor
 | 
			
		||||
pub struct Scheduler {
 | 
			
		||||
@ -19,7 +20,7 @@ impl Actor for Scheduler {
 | 
			
		||||
    type Context = Context<Self>;
 | 
			
		||||
 | 
			
		||||
    fn started(&mut self, context: &mut Context<Self>) {
 | 
			
		||||
        log::info!("DB Cleanup Cron started");
 | 
			
		||||
        info!("DB Cleanup Cron started");
 | 
			
		||||
 | 
			
		||||
        context.run_later(self.duration_until_next(), move |this, ctx| {
 | 
			
		||||
            this.schedule_task(ctx)
 | 
			
		||||
@ -27,7 +28,7 @@ impl Actor for Scheduler {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn stopped(&mut self, _ctx: &mut Context<Self>) {
 | 
			
		||||
        log::info!("DB Cleanup stopped");
 | 
			
		||||
        info!("DB Cleanup stopped");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -38,7 +39,6 @@ impl Scheduler {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn schedule_task(&self, ctx: &mut Context<Self>) {
 | 
			
		||||
        log::info!("Cleaning DB");
 | 
			
		||||
        let future = actix::fut::wrap_future::<_, Self>(Self::cleanup_db(self.sql_pool.clone()));
 | 
			
		||||
        ctx.spawn(future);
 | 
			
		||||
 | 
			
		||||
@ -47,17 +47,16 @@ impl Scheduler {
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all)]
 | 
			
		||||
    async fn cleanup_db(sql_pool: Pool) {
 | 
			
		||||
        if let Err(e) = sqlx::query(
 | 
			
		||||
            &Query::delete()
 | 
			
		||||
                .from_table(JwtRefreshStorage::Table)
 | 
			
		||||
                .and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
 | 
			
		||||
                .to_string(DbQueryBuilder {}),
 | 
			
		||||
        )
 | 
			
		||||
        .execute(&sql_pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            log::error!("DB error while cleaning up JWT refresh tokens: {}", e);
 | 
			
		||||
        info!("Cleaning DB");
 | 
			
		||||
        let query = Query::delete()
 | 
			
		||||
            .from_table(JwtRefreshStorage::Table)
 | 
			
		||||
            .and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
 | 
			
		||||
            .to_string(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        if let Err(e) = sqlx::query(&query).execute(&sql_pool).await {
 | 
			
		||||
            error!("DB error while cleaning up JWT refresh tokens: {}", e);
 | 
			
		||||
        };
 | 
			
		||||
        if let Err(e) = sqlx::query(
 | 
			
		||||
            &Query::delete()
 | 
			
		||||
@ -68,9 +67,9 @@ impl Scheduler {
 | 
			
		||||
        .execute(&sql_pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            log::error!("DB error while cleaning up JWT storage: {}", e);
 | 
			
		||||
            error!("DB error while cleaning up JWT storage: {}", e);
 | 
			
		||||
        };
 | 
			
		||||
        log::info!("DB cleaned!");
 | 
			
		||||
        info!("DB cleaned!");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn duration_until_next(&self) -> Duration {
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ use crate::domain::handler::{
 | 
			
		||||
    BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, UserId,
 | 
			
		||||
};
 | 
			
		||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject};
 | 
			
		||||
use tracing::{debug, debug_span, Instrument};
 | 
			
		||||
 | 
			
		||||
use super::api::Context;
 | 
			
		||||
 | 
			
		||||
@ -63,7 +64,12 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        context: &Context<Handler>,
 | 
			
		||||
        user: CreateUserInput,
 | 
			
		||||
    ) -> FieldResult<super::query::User<Handler>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] create_user");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?user.id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized user creation".into());
 | 
			
		||||
        }
 | 
			
		||||
        let user_id = UserId::new(&user.id);
 | 
			
		||||
@ -76,10 +82,12 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
                first_name: user.first_name,
 | 
			
		||||
                last_name: user.last_name,
 | 
			
		||||
            })
 | 
			
		||||
            .instrument(span.clone())
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_user_details(&user_id)
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(Into::into)?)
 | 
			
		||||
    }
 | 
			
		||||
@ -88,13 +96,19 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        context: &Context<Handler>,
 | 
			
		||||
        name: String,
 | 
			
		||||
    ) -> FieldResult<super::query::Group<Handler>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] create_group");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?name);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized group creation".into());
 | 
			
		||||
        }
 | 
			
		||||
        let group_id = context.handler.create_group(&name).await?;
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_group_details(group_id)
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(Into::into)?)
 | 
			
		||||
    }
 | 
			
		||||
@ -103,7 +117,12 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        context: &Context<Handler>,
 | 
			
		||||
        user: UpdateUserInput,
 | 
			
		||||
    ) -> FieldResult<Success> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] update_user");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?user.id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.can_write(&user.id) {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized user update".into());
 | 
			
		||||
        }
 | 
			
		||||
        context
 | 
			
		||||
@ -115,6 +134,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
                first_name: user.first_name,
 | 
			
		||||
                last_name: user.last_name,
 | 
			
		||||
            })
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
@ -123,10 +143,16 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        context: &Context<Handler>,
 | 
			
		||||
        group: UpdateGroupInput,
 | 
			
		||||
    ) -> FieldResult<Success> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] update_group");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?group.id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized group update".into());
 | 
			
		||||
        }
 | 
			
		||||
        if group.id == 1 {
 | 
			
		||||
            span.in_scope(|| debug!("Cannot change admin group details"));
 | 
			
		||||
            return Err("Cannot change admin group details".into());
 | 
			
		||||
        }
 | 
			
		||||
        context
 | 
			
		||||
@ -135,6 +161,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
                group_id: GroupId(group.id),
 | 
			
		||||
                display_name: group.display_name,
 | 
			
		||||
            })
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
@ -144,12 +171,18 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        user_id: String,
 | 
			
		||||
        group_id: i32,
 | 
			
		||||
    ) -> FieldResult<Success> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] add_user_to_group");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?user_id, ?group_id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized group membership modification".into());
 | 
			
		||||
        }
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .add_user_to_group(&UserId::new(&user_id), GroupId(group_id))
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
@ -159,38 +192,65 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
 | 
			
		||||
        user_id: String,
 | 
			
		||||
        group_id: i32,
 | 
			
		||||
    ) -> FieldResult<Success> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] remove_user_from_group");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?user_id, ?group_id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized group membership modification".into());
 | 
			
		||||
        }
 | 
			
		||||
        if context.validation_result.user == user_id && group_id == 1 {
 | 
			
		||||
            span.in_scope(|| debug!("Cannot remove admin rights for current user"));
 | 
			
		||||
            return Err("Cannot remove admin rights for current user".into());
 | 
			
		||||
        }
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .remove_user_from_group(&UserId::new(&user_id), GroupId(group_id))
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn delete_user(context: &Context<Handler>, user_id: String) -> FieldResult<Success> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] delete_user");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?user_id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized user deletion".into());
 | 
			
		||||
        }
 | 
			
		||||
        if context.validation_result.user == user_id {
 | 
			
		||||
            span.in_scope(|| debug!("Cannot delete current user"));
 | 
			
		||||
            return Err("Cannot delete current user".into());
 | 
			
		||||
        }
 | 
			
		||||
        context.handler.delete_user(&UserId::new(&user_id)).await?;
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .delete_user(&UserId::new(&user_id))
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn delete_group(context: &Context<Handler>, group_id: i32) -> FieldResult<Success> {
 | 
			
		||||
        let span = debug_span!("[GraphQL mutation] delete_group");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?group_id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized group deletion".into());
 | 
			
		||||
        }
 | 
			
		||||
        if group_id == 1 {
 | 
			
		||||
            span.in_scope(|| debug!("Cannot delete admin group"));
 | 
			
		||||
            return Err("Cannot delete admin group".into());
 | 
			
		||||
        }
 | 
			
		||||
        context.handler.delete_group(GroupId(group_id)).await?;
 | 
			
		||||
        context
 | 
			
		||||
            .handler
 | 
			
		||||
            .delete_group(GroupId(group_id))
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(Success::new())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName, UserId};
 | 
			
		||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use tracing::{debug, debug_span, Instrument};
 | 
			
		||||
 | 
			
		||||
type DomainRequestFilter = crate::domain::handler::UserRequestFilter;
 | 
			
		||||
type DomainUser = crate::domain::handler::User;
 | 
			
		||||
@ -108,12 +109,18 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn user(context: &Context<Handler>, user_id: String) -> FieldResult<User<Handler>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL query] user");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?user_id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.can_read(&user_id) {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized access to user data".into());
 | 
			
		||||
        }
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_user_details(&UserId::new(&user_id))
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(Into::into)?)
 | 
			
		||||
    }
 | 
			
		||||
@ -122,34 +129,49 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
 | 
			
		||||
        context: &Context<Handler>,
 | 
			
		||||
        #[graphql(name = "where")] filters: Option<RequestFilter>,
 | 
			
		||||
    ) -> FieldResult<Vec<User<Handler>>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL query] users");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?filters);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin_or_readonly() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized access to user list".into());
 | 
			
		||||
        }
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .list_users(filters.map(TryInto::try_into).transpose()?, false)
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(|v| v.into_iter().map(Into::into).collect())?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn groups(context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL query] groups");
 | 
			
		||||
        if !context.validation_result.is_admin_or_readonly() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized access to group list".into());
 | 
			
		||||
        }
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .list_groups(None)
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(|v| v.into_iter().map(Into::into).collect())?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn group(context: &Context<Handler>, group_id: i32) -> FieldResult<Group<Handler>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL query] group");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(?group_id);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin_or_readonly() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized access to group data".into());
 | 
			
		||||
        }
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_group_details(GroupId(group_id))
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(Into::into)?)
 | 
			
		||||
    }
 | 
			
		||||
@ -199,9 +221,14 @@ impl<Handler: BackendHandler + Sync> User<Handler> {
 | 
			
		||||
 | 
			
		||||
    /// The groups to which this user belongs.
 | 
			
		||||
    async fn groups(&self, context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL query] user::groups");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(user_id = ?self.user.user_id);
 | 
			
		||||
        });
 | 
			
		||||
        Ok(context
 | 
			
		||||
            .handler
 | 
			
		||||
            .get_user_groups(&self.user.user_id)
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(|set| set.into_iter().map(Into::into).collect())?)
 | 
			
		||||
    }
 | 
			
		||||
@ -244,7 +271,12 @@ impl<Handler: BackendHandler + Sync> Group<Handler> {
 | 
			
		||||
    }
 | 
			
		||||
    /// The groups to which this user belongs.
 | 
			
		||||
    async fn users(&self, context: &Context<Handler>) -> FieldResult<Vec<User<Handler>>> {
 | 
			
		||||
        let span = debug_span!("[GraphQL query] group::users");
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            debug!(name = %self.display_name);
 | 
			
		||||
        });
 | 
			
		||||
        if !context.validation_result.is_admin_or_readonly() {
 | 
			
		||||
            span.in_scope(|| debug!("Unauthorized"));
 | 
			
		||||
            return Err("Unauthorized access to group data".into());
 | 
			
		||||
        }
 | 
			
		||||
        Ok(context
 | 
			
		||||
@ -253,6 +285,7 @@ impl<Handler: BackendHandler + Sync> Group<Handler> {
 | 
			
		||||
                Some(DomainRequestFilter::MemberOfId(GroupId(self.group_id))),
 | 
			
		||||
                false,
 | 
			
		||||
            )
 | 
			
		||||
            .instrument(span)
 | 
			
		||||
            .await
 | 
			
		||||
            .map(|v| v.into_iter().map(Into::into).collect())?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@ use ldap3_server::proto::{
 | 
			
		||||
    LdapFilter, LdapOp, LdapPartialAttribute, LdapPasswordModifyRequest, LdapResult,
 | 
			
		||||
    LdapResultCode, LdapSearchRequest, LdapSearchResultEntry, LdapSearchScope,
 | 
			
		||||
};
 | 
			
		||||
use log::{debug, warn};
 | 
			
		||||
use tracing::{debug, instrument, warn};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, PartialEq, Eq, Clone)]
 | 
			
		||||
struct LdapDn(String);
 | 
			
		||||
@ -198,22 +198,26 @@ fn get_user_attribute(
 | 
			
		||||
    }))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn expand_attribute_wildcards(attributes: &[String], all_attribute_keys: &[&str]) -> Vec<String> {
 | 
			
		||||
    let mut attributes_out = attributes.to_owned();
 | 
			
		||||
#[instrument(skip_all, level = "debug")]
 | 
			
		||||
fn expand_attribute_wildcards(
 | 
			
		||||
    ldap_attributes: &[String],
 | 
			
		||||
    all_attribute_keys: &[&str],
 | 
			
		||||
) -> Vec<String> {
 | 
			
		||||
    let mut attributes_out = ldap_attributes.to_owned();
 | 
			
		||||
 | 
			
		||||
    if attributes_out.iter().any(|x| x == "*") || attributes_out.is_empty() {
 | 
			
		||||
        debug!(r#"Expanding * / empty attrs:"#);
 | 
			
		||||
        // Remove occurrences of '*'
 | 
			
		||||
        attributes_out.retain(|x| x != "*");
 | 
			
		||||
        // Splice in all non-operational attributes
 | 
			
		||||
        attributes_out.extend(all_attribute_keys.iter().map(|s| s.to_string()));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    debug!(r#"Expanded: "{:?}""#, &attributes_out);
 | 
			
		||||
 | 
			
		||||
    // Deduplicate, preserving order
 | 
			
		||||
    attributes_out.into_iter().unique().collect_vec()
 | 
			
		||||
    let resolved_attributes = attributes_out.into_iter().unique().collect_vec();
 | 
			
		||||
    debug!(?ldap_attributes, ?resolved_attributes);
 | 
			
		||||
    resolved_attributes
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const ALL_USER_ATTRIBUTE_KEYS: &[&str] = &[
 | 
			
		||||
    "objectclass",
 | 
			
		||||
    "dn",
 | 
			
		||||
@ -470,8 +474,9 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    pub async fn do_bind(&mut self, request: &LdapBindRequest) -> (LdapResultCode, String) {
 | 
			
		||||
        debug!(r#"Received bind request for "{}""#, &request.dn);
 | 
			
		||||
        debug!("DN: {}", &request.dn);
 | 
			
		||||
        let user_id = match get_user_id_from_distinguished_name(
 | 
			
		||||
            &request.dn.to_ascii_lowercase(),
 | 
			
		||||
            &self.base_dn,
 | 
			
		||||
@ -507,6 +512,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
                        Permission::Regular
 | 
			
		||||
                    },
 | 
			
		||||
                ));
 | 
			
		||||
                debug!("Success!");
 | 
			
		||||
                (LdapResultCode::Success, "".to_string())
 | 
			
		||||
            }
 | 
			
		||||
            Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()),
 | 
			
		||||
@ -597,8 +603,8 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn do_search(&mut self, request: &LdapSearchRequest) -> Vec<LdapOp> {
 | 
			
		||||
        let user_filter = match &self.user_info {
 | 
			
		||||
    pub async fn do_search_or_dse(&mut self, request: &LdapSearchRequest) -> Vec<LdapOp> {
 | 
			
		||||
        let user_filter = match self.user_info.clone() {
 | 
			
		||||
            Some((_, Permission::Admin)) | Some((_, Permission::Readonly)) => None,
 | 
			
		||||
            Some((user_id, Permission::Regular)) => Some(user_id),
 | 
			
		||||
            None => {
 | 
			
		||||
@ -612,10 +618,19 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
            && request.scope == LdapSearchScope::Base
 | 
			
		||||
            && request.filter == LdapFilter::Present("objectClass".to_string())
 | 
			
		||||
        {
 | 
			
		||||
            debug!("Received rootDSE request");
 | 
			
		||||
            debug!("rootDSE request");
 | 
			
		||||
            return vec![root_dse_response(&self.base_dn_str), make_search_success()];
 | 
			
		||||
        }
 | 
			
		||||
        debug!("Received search request: {:?}", &request);
 | 
			
		||||
        self.do_search(request, user_filter).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    pub async fn do_search(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        request: &LdapSearchRequest,
 | 
			
		||||
        user_filter: Option<UserId>,
 | 
			
		||||
    ) -> Vec<LdapOp> {
 | 
			
		||||
        let user_filter = user_filter.as_ref();
 | 
			
		||||
        let dn_parts = match parse_distinguished_name(&request.base.to_ascii_lowercase()) {
 | 
			
		||||
            Ok(dn) => dn,
 | 
			
		||||
            Err(_) => {
 | 
			
		||||
@ -626,6 +641,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        let scope = get_search_scope(&self.base_dn, &dn_parts);
 | 
			
		||||
        debug!(?request.base, ?scope);
 | 
			
		||||
        let get_user_list = || async {
 | 
			
		||||
            self.get_user_list(&request.filter, &request.attrs, &request.base, &user_filter)
 | 
			
		||||
                .await
 | 
			
		||||
@ -676,14 +692,16 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
        results
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn get_user_list(
 | 
			
		||||
        &self,
 | 
			
		||||
        filter: &LdapFilter,
 | 
			
		||||
        ldap_filter: &LdapFilter,
 | 
			
		||||
        attributes: &[String],
 | 
			
		||||
        base: &str,
 | 
			
		||||
        user_filter: &Option<&UserId>,
 | 
			
		||||
    ) -> Vec<LdapOp> {
 | 
			
		||||
        let filters = match self.convert_user_filter(filter) {
 | 
			
		||||
        debug!(?ldap_filter);
 | 
			
		||||
        let filters = match self.convert_user_filter(ldap_filter) {
 | 
			
		||||
            Ok(f) => f,
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                return vec![make_search_error(
 | 
			
		||||
@ -692,19 +710,20 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
                )]
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        let filters = match user_filter {
 | 
			
		||||
        let parsed_filters = match user_filter {
 | 
			
		||||
            None => filters,
 | 
			
		||||
            Some(u) => {
 | 
			
		||||
                UserRequestFilter::And(vec![filters, UserRequestFilter::UserId((*u).clone())])
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        debug!(?parsed_filters);
 | 
			
		||||
        let expanded_attributes = expand_attribute_wildcards(attributes, ALL_USER_ATTRIBUTE_KEYS);
 | 
			
		||||
        let need_groups = expanded_attributes
 | 
			
		||||
            .iter()
 | 
			
		||||
            .any(|s| s.to_ascii_lowercase() == "memberof");
 | 
			
		||||
        let users = match self
 | 
			
		||||
            .backend_handler
 | 
			
		||||
            .list_users(Some(filters), need_groups)
 | 
			
		||||
            .list_users(Some(parsed_filters), need_groups)
 | 
			
		||||
            .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(users) => users,
 | 
			
		||||
@ -737,14 +756,16 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
            })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn get_groups_list(
 | 
			
		||||
        &self,
 | 
			
		||||
        filter: &LdapFilter,
 | 
			
		||||
        ldap_filter: &LdapFilter,
 | 
			
		||||
        attributes: &[String],
 | 
			
		||||
        base: &str,
 | 
			
		||||
        user_filter: &Option<&UserId>,
 | 
			
		||||
    ) -> Vec<LdapOp> {
 | 
			
		||||
        let filter = match self.convert_group_filter(filter) {
 | 
			
		||||
        debug!(?ldap_filter);
 | 
			
		||||
        let filter = match self.convert_group_filter(ldap_filter) {
 | 
			
		||||
            Ok(f) => f,
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                return vec![make_search_error(
 | 
			
		||||
@ -753,14 +774,14 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
                )]
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        let filter = match user_filter {
 | 
			
		||||
        let parsed_filters = match user_filter {
 | 
			
		||||
            None => filter,
 | 
			
		||||
            Some(u) => {
 | 
			
		||||
                GroupRequestFilter::And(vec![filter, GroupRequestFilter::Member((*u).clone())])
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let groups = match self.backend_handler.list_groups(Some(filter)).await {
 | 
			
		||||
        debug!(?parsed_filters);
 | 
			
		||||
        let groups = match self.backend_handler.list_groups(Some(parsed_filters)).await {
 | 
			
		||||
            Ok(groups) => groups,
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                return vec![make_search_error(
 | 
			
		||||
@ -805,7 +826,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
 | 
			
		||||
                    saslcreds: None,
 | 
			
		||||
                })]
 | 
			
		||||
            }
 | 
			
		||||
            LdapOp::SearchRequest(request) => self.do_search(&request).await,
 | 
			
		||||
            LdapOp::SearchRequest(request) => self.do_search_or_dse(&request).await,
 | 
			
		||||
            LdapOp::UnbindRequest => {
 | 
			
		||||
                self.user_info = None;
 | 
			
		||||
                // No need to notify on unbind (per rfc4511)
 | 
			
		||||
@ -1176,7 +1197,7 @@ mod tests {
 | 
			
		||||
        let request =
 | 
			
		||||
            make_user_search_request::<String>(LdapFilter::And(vec![]), vec!["1.1".to_string()]);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "uid=test,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1199,7 +1220,7 @@ mod tests {
 | 
			
		||||
        let request =
 | 
			
		||||
            make_user_search_request::<String>(LdapFilter::And(vec![]), vec!["1.1".to_string()]);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_success()],
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -1226,7 +1247,7 @@ mod tests {
 | 
			
		||||
            vec!["memberOf".to_string()],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "uid=bob,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1266,7 +1287,7 @@ mod tests {
 | 
			
		||||
            attrs: vec!["1.1".to_string()],
 | 
			
		||||
        };
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_success()],
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -1397,7 +1418,7 @@ mod tests {
 | 
			
		||||
            ],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "uid=bob_1,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1515,7 +1536,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass", "dn", "cn", "uniqueMember"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1612,7 +1633,7 @@ mod tests {
 | 
			
		||||
            vec!["1.1"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1650,7 +1671,7 @@ mod tests {
 | 
			
		||||
            vec!["cn"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1687,7 +1708,7 @@ mod tests {
 | 
			
		||||
            attrs: vec!["1.1".to_string()],
 | 
			
		||||
        };
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_success()],
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -1717,7 +1738,7 @@ mod tests {
 | 
			
		||||
            vec!["cn"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_error(
 | 
			
		||||
                LdapResultCode::Other,
 | 
			
		||||
                r#"Error while listing groups "ou=groups,dc=example,dc=com": Internal error: `Error getting groups`"#.to_string()
 | 
			
		||||
@ -1737,7 +1758,7 @@ mod tests {
 | 
			
		||||
            vec!["cn"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_error(
 | 
			
		||||
                LdapResultCode::UnwillingToPerform,
 | 
			
		||||
                r#"Unsupported group filter: Unsupported group filter: Substring("whatever", LdapSubstringFilter { initial: None, any: [], final_: None })"#
 | 
			
		||||
@ -1785,7 +1806,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_success()]
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -1809,7 +1830,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_success()]
 | 
			
		||||
        );
 | 
			
		||||
        let request = make_user_search_request(
 | 
			
		||||
@ -1817,7 +1838,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_error(
 | 
			
		||||
                LdapResultCode::UnwillingToPerform,
 | 
			
		||||
                "Unsupported user filter: while parsing a group ID: Missing DN value".to_string()
 | 
			
		||||
@ -1831,7 +1852,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_error(
 | 
			
		||||
                LdapResultCode::UnwillingToPerform,
 | 
			
		||||
                "Unsupported user filter: Unexpected group DN format. Got \"cn=mygroup,dc=example,dc=com\", expected: \"cn=groupname,ou=groups,dc=example,dc=com\"".to_string()
 | 
			
		||||
@ -1869,7 +1890,7 @@ mod tests {
 | 
			
		||||
            vec!["objectclass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "uid=bob_1,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -1921,7 +1942,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass", "dn", "cn"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                LdapOp::SearchResultEntry(LdapSearchResultEntry {
 | 
			
		||||
                    dn: "uid=bob_1,ou=people,dc=example,dc=com".to_string(),
 | 
			
		||||
@ -2086,12 +2107,18 @@ mod tests {
 | 
			
		||||
            make_search_success(),
 | 
			
		||||
        ];
 | 
			
		||||
 | 
			
		||||
        assert_eq!(ldap_handler.do_search(&request).await, expected_result);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            expected_result
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let request2 =
 | 
			
		||||
            make_search_request("dc=example,dc=com", LdapFilter::And(vec![]), vec!["*", "*"]);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(ldap_handler.do_search(&request2).await, expected_result);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request2).await,
 | 
			
		||||
            expected_result
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let request3 = make_search_request(
 | 
			
		||||
            "dc=example,dc=com",
 | 
			
		||||
@ -2099,12 +2126,18 @@ mod tests {
 | 
			
		||||
            vec!["*", "+", "+"],
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        assert_eq!(ldap_handler.do_search(&request3).await, expected_result);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request3).await,
 | 
			
		||||
            expected_result
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let request4 =
 | 
			
		||||
            make_search_request("dc=example,dc=com", LdapFilter::And(vec![]), vec![""; 0]);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(ldap_handler.do_search(&request4).await, expected_result);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request4).await,
 | 
			
		||||
            expected_result
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let request5 = make_search_request(
 | 
			
		||||
            "dc=example,dc=com",
 | 
			
		||||
@ -2112,7 +2145,10 @@ mod tests {
 | 
			
		||||
            vec!["objectclass", "dn", "uid", "*"],
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        assert_eq!(ldap_handler.do_search(&request5).await, expected_result);
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request5).await,
 | 
			
		||||
            expected_result
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[tokio::test]
 | 
			
		||||
@ -2124,7 +2160,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_success()]
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
@ -2140,7 +2176,7 @@ mod tests {
 | 
			
		||||
            vec!["objectClass"],
 | 
			
		||||
        );
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![make_search_error(
 | 
			
		||||
                LdapResultCode::UnwillingToPerform,
 | 
			
		||||
                "Unsupported user filter: Unsupported user filter: Substring(\"uid\", LdapSubstringFilter { initial: None, any: [], final_: None })".to_string()
 | 
			
		||||
@ -2272,7 +2308,7 @@ mod tests {
 | 
			
		||||
            attrs: vec!["supportedExtension".to_string()],
 | 
			
		||||
        };
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            ldap_handler.do_search(&request).await,
 | 
			
		||||
            ldap_handler.do_search_or_dse(&request).await,
 | 
			
		||||
            vec![
 | 
			
		||||
                root_dse_response("dc=example,dc=com"),
 | 
			
		||||
                make_search_success()
 | 
			
		||||
 | 
			
		||||
@ -10,12 +10,13 @@ use actix_server::ServerBuilder;
 | 
			
		||||
use actix_service::{fn_service, ServiceFactoryExt};
 | 
			
		||||
use anyhow::{Context, Result};
 | 
			
		||||
use ldap3_server::{proto::LdapMsg, LdapCodec};
 | 
			
		||||
use log::*;
 | 
			
		||||
use native_tls::{Identity, TlsAcceptor};
 | 
			
		||||
use tokio_native_tls::TlsAcceptor as NativeTlsAcceptor;
 | 
			
		||||
use tokio_util::codec::{FramedRead, FramedWrite};
 | 
			
		||||
use tracing::{debug, error, info, instrument};
 | 
			
		||||
 | 
			
		||||
async fn handle_incoming_message<Backend, Writer>(
 | 
			
		||||
#[instrument(skip_all, level = "info", name = "LDAP request")]
 | 
			
		||||
async fn handle_ldap_message<Backend, Writer>(
 | 
			
		||||
    msg: Result<LdapMsg, std::io::Error>,
 | 
			
		||||
    resp: &mut Writer,
 | 
			
		||||
    session: &mut LdapHandler<Backend>,
 | 
			
		||||
@ -27,18 +28,18 @@ where
 | 
			
		||||
{
 | 
			
		||||
    use futures_util::SinkExt;
 | 
			
		||||
    let msg = msg.context("while receiving LDAP op")?;
 | 
			
		||||
    debug!("Received LDAP message: {:?}", &msg);
 | 
			
		||||
    debug!(?msg);
 | 
			
		||||
    match session.handle_ldap_message(msg.op).await {
 | 
			
		||||
        None => return Ok(false),
 | 
			
		||||
        Some(result) => {
 | 
			
		||||
            if result.is_empty() {
 | 
			
		||||
                debug!("No response");
 | 
			
		||||
            }
 | 
			
		||||
            for result_op in result.into_iter() {
 | 
			
		||||
                debug!("Replying with LDAP op: {:?}", &result_op);
 | 
			
		||||
            for response in result.into_iter() {
 | 
			
		||||
                debug!(?response);
 | 
			
		||||
                resp.send(LdapMsg {
 | 
			
		||||
                    msgid: msg.msgid,
 | 
			
		||||
                    op: result_op,
 | 
			
		||||
                    op: response,
 | 
			
		||||
                    ctrl: vec![],
 | 
			
		||||
                })
 | 
			
		||||
                .await
 | 
			
		||||
@ -66,6 +67,7 @@ fn get_file_as_byte_vec(filename: &str) -> Result<Vec<u8>> {
 | 
			
		||||
    .context(format!("while reading file {}", filename))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[instrument(skip_all, level = "info", name = "LDAP session")]
 | 
			
		||||
async fn handle_ldap_stream<Stream, Backend>(
 | 
			
		||||
    stream: Stream,
 | 
			
		||||
    backend_handler: Backend,
 | 
			
		||||
@ -91,7 +93,7 @@ where
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    while let Some(msg) = requests.next().await {
 | 
			
		||||
        if !handle_incoming_message(msg, &mut resp, &mut session)
 | 
			
		||||
        if !handle_ldap_message(msg, &mut resp, &mut session)
 | 
			
		||||
            .await
 | 
			
		||||
            .context("while handling incoming messages")?
 | 
			
		||||
        {
 | 
			
		||||
@ -145,6 +147,7 @@ where
 | 
			
		||||
        .map_err(|err: anyhow::Error| error!("[LDAP] Service Error: {:#}", err))
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    info!("Starting the LDAP server on port {}", config.ldap_port);
 | 
			
		||||
    let server_builder = server_builder
 | 
			
		||||
        .bind("ldap", ("0.0.0.0", config.ldap_port), binder)
 | 
			
		||||
        .with_context(|| format!("while binding to the port {}", config.ldap_port));
 | 
			
		||||
@ -176,6 +179,10 @@ where
 | 
			
		||||
            .map_err(|err: anyhow::Error| error!("[LDAPS] Service Error: {:#}", err))
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        info!(
 | 
			
		||||
            "Starting the LDAPS server on port {}",
 | 
			
		||||
            config.ldaps_options.port
 | 
			
		||||
        );
 | 
			
		||||
        server_builder.and_then(|s| {
 | 
			
		||||
            s.bind("ldaps", ("0.0.0.0", config.ldaps_options.port), tls_binder)
 | 
			
		||||
                .with_context(|| format!("while binding to the port {}", config.ldaps_options.port))
 | 
			
		||||
 | 
			
		||||
@ -1,30 +1,50 @@
 | 
			
		||||
use crate::infra::configuration::Configuration;
 | 
			
		||||
use tracing_subscriber::prelude::*;
 | 
			
		||||
use actix_web::{
 | 
			
		||||
    dev::{ServiceRequest, ServiceResponse},
 | 
			
		||||
    Error,
 | 
			
		||||
};
 | 
			
		||||
use tracing::{error, info, Span};
 | 
			
		||||
use tracing_actix_web::{root_span, RootSpanBuilder};
 | 
			
		||||
use tracing_subscriber::{filter::EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
 | 
			
		||||
 | 
			
		||||
/// We will define a custom root span builder to capture additional fields, specific
 | 
			
		||||
/// to our application, on top of the ones provided by `DefaultRootSpanBuilder` out of the box.
 | 
			
		||||
pub struct CustomRootSpanBuilder;
 | 
			
		||||
 | 
			
		||||
impl RootSpanBuilder for CustomRootSpanBuilder {
 | 
			
		||||
    fn on_request_start(request: &ServiceRequest) -> Span {
 | 
			
		||||
        let span = root_span!(request);
 | 
			
		||||
        span.in_scope(|| {
 | 
			
		||||
            info!(uri = %request.uri());
 | 
			
		||||
        });
 | 
			
		||||
        span
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn on_request_end<B>(_: Span, outcome: &Result<ServiceResponse<B>, Error>) {
 | 
			
		||||
        match &outcome {
 | 
			
		||||
            Ok(response) => {
 | 
			
		||||
                if let Some(error) = response.response().error() {
 | 
			
		||||
                    error!(?error);
 | 
			
		||||
                } else {
 | 
			
		||||
                    info!(status_code = &response.response().status().as_u16());
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Err(error) => error!(?error),
 | 
			
		||||
        };
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn init(config: &Configuration) -> anyhow::Result<()> {
 | 
			
		||||
    let max_log_level = log_level_from_config(config);
 | 
			
		||||
    let sqlx_max_log_level = sqlx_log_level_from_config(config);
 | 
			
		||||
    let filter = tracing_subscriber::filter::Targets::new()
 | 
			
		||||
        .with_target("lldap", max_log_level)
 | 
			
		||||
        .with_target("sqlx", sqlx_max_log_level);
 | 
			
		||||
    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
 | 
			
		||||
        EnvFilter::new(if config.verbose {
 | 
			
		||||
            "sqlx=warn,debug"
 | 
			
		||||
        } else {
 | 
			
		||||
            "sqlx=warn,info"
 | 
			
		||||
        })
 | 
			
		||||
    });
 | 
			
		||||
    tracing_subscriber::registry()
 | 
			
		||||
        .with(tracing_subscriber::fmt::layer().with_filter(filter))
 | 
			
		||||
        .with(env_filter)
 | 
			
		||||
        .with(tracing_forest::ForestLayer::default())
 | 
			
		||||
        .init();
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn log_level_from_config(config: &Configuration) -> tracing::Level {
 | 
			
		||||
    if config.verbose {
 | 
			
		||||
        tracing::Level::DEBUG
 | 
			
		||||
    } else {
 | 
			
		||||
        tracing::Level::INFO
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn sqlx_log_level_from_config(config: &Configuration) -> tracing::Level {
 | 
			
		||||
    if config.verbose {
 | 
			
		||||
        tracing::Level::INFO
 | 
			
		||||
    } else {
 | 
			
		||||
        tracing::Level::WARN
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ use lettre::{
 | 
			
		||||
    message::Mailbox, transport::smtp::authentication::Credentials, Message, SmtpTransport,
 | 
			
		||||
    Transport,
 | 
			
		||||
};
 | 
			
		||||
use log::debug;
 | 
			
		||||
use tracing::debug;
 | 
			
		||||
 | 
			
		||||
fn send_email(to: Mailbox, subject: &str, body: String, options: &MailOptions) -> Result<()> {
 | 
			
		||||
    let from = options
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@ use sea_query::{Expr, Iden, Query, SimpleExpr};
 | 
			
		||||
use sea_query_binder::SqlxBinder;
 | 
			
		||||
use sqlx::{query_as_with, query_with, Row};
 | 
			
		||||
use std::collections::HashSet;
 | 
			
		||||
use tracing::{debug, instrument};
 | 
			
		||||
 | 
			
		||||
fn gen_random_string(len: usize) -> String {
 | 
			
		||||
    use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
 | 
			
		||||
@ -19,12 +20,14 @@ fn gen_random_string(len: usize) -> String {
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column(JwtStorage::JwtHash)
 | 
			
		||||
            .from(JwtStorage::Table)
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(&query, values)
 | 
			
		||||
            .map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
 | 
			
		||||
            .fetch(&self.sql_pool)
 | 
			
		||||
@ -35,7 +38,9 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
            .map_err(|e| anyhow::anyhow!(e))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
 | 
			
		||||
        debug!(?user);
 | 
			
		||||
        use std::collections::hash_map::DefaultHasher;
 | 
			
		||||
        use std::hash::{Hash, Hasher};
 | 
			
		||||
        // TODO: Initialize the rng only once. Maybe Arc<Cell>?
 | 
			
		||||
@ -59,23 +64,30 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
                (chrono::Utc::now() + duration).naive_utc().into(),
 | 
			
		||||
            ])
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(&query, values).execute(&self.sql_pool).await?;
 | 
			
		||||
        Ok((refresh_token, duration))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
 | 
			
		||||
        debug!(?user);
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .expr(SimpleExpr::Value(1.into()))
 | 
			
		||||
            .from(JwtRefreshStorage::Table)
 | 
			
		||||
            .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
 | 
			
		||||
            .and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        Ok(query_with(&query, values)
 | 
			
		||||
            .fetch_optional(&self.sql_pool)
 | 
			
		||||
            .await?
 | 
			
		||||
            .is_some())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
 | 
			
		||||
        debug!(?user);
 | 
			
		||||
        use sqlx::Result;
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column(JwtStorage::JwtHash)
 | 
			
		||||
@ -95,31 +107,39 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
            .values(vec![(JwtStorage::Blacklisted, true.into())])
 | 
			
		||||
            .and_where(Expr::col(JwtStorage::UserId).eq(user))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(&query, values).execute(&self.sql_pool).await?;
 | 
			
		||||
        Ok(result?)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
 | 
			
		||||
        let (query, values) = Query::delete()
 | 
			
		||||
            .from_table(JwtRefreshStorage::Table)
 | 
			
		||||
            .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
 | 
			
		||||
            .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(&query, values).execute(&self.sql_pool).await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
 | 
			
		||||
        debug!(?user);
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column(Users::UserId)
 | 
			
		||||
            .from(Users::Table)
 | 
			
		||||
            .and_where(Expr::col(Users::UserId).eq(user))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        // Check that the user exists.
 | 
			
		||||
        if query_with(&query, values)
 | 
			
		||||
            .fetch_one(&self.sql_pool)
 | 
			
		||||
            .await
 | 
			
		||||
            .is_err()
 | 
			
		||||
        {
 | 
			
		||||
            debug!("User not found");
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -139,10 +159,12 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
                (chrono::Utc::now() + duration).naive_utc().into(),
 | 
			
		||||
            ])
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(&query, values).execute(&self.sql_pool).await?;
 | 
			
		||||
        Ok(Some(token))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug", ret)]
 | 
			
		||||
    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
 | 
			
		||||
        let (query, values) = Query::select()
 | 
			
		||||
            .column(PasswordResetTokens::UserId)
 | 
			
		||||
@ -152,6 +174,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
                Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
 | 
			
		||||
            )
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
 | 
			
		||||
        let (user_id,) = query_as_with(&query, values)
 | 
			
		||||
            .fetch_one(&self.sql_pool)
 | 
			
		||||
@ -159,11 +182,13 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
			
		||||
        Ok(user_id)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(skip_all, level = "debug")]
 | 
			
		||||
    async fn delete_password_reset_token(&self, token: &str) -> Result<()> {
 | 
			
		||||
        let (query, values) = Query::delete()
 | 
			
		||||
            .from_table(PasswordResetTokens::Table)
 | 
			
		||||
            .and_where(Expr::col(PasswordResetTokens::Token).eq(token))
 | 
			
		||||
            .build_sqlx(DbQueryBuilder {});
 | 
			
		||||
        debug!(%query);
 | 
			
		||||
        query_with(&query, values).execute(&self.sql_pool).await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,7 @@ use crate::{
 | 
			
		||||
    infra::{
 | 
			
		||||
        auth_service,
 | 
			
		||||
        configuration::{Configuration, MailOptions},
 | 
			
		||||
        logging::CustomRootSpanBuilder,
 | 
			
		||||
        tcp_backend_handler::*,
 | 
			
		||||
    },
 | 
			
		||||
};
 | 
			
		||||
@ -21,6 +22,7 @@ use sha2::Sha512;
 | 
			
		||||
use std::collections::HashSet;
 | 
			
		||||
use std::path::PathBuf;
 | 
			
		||||
use std::sync::RwLock;
 | 
			
		||||
use tracing::info;
 | 
			
		||||
 | 
			
		||||
async fn index() -> actix_web::Result<NamedFile> {
 | 
			
		||||
    let mut path = PathBuf::new();
 | 
			
		||||
@ -105,6 +107,7 @@ where
 | 
			
		||||
        .context("while getting the jwt blacklist")?;
 | 
			
		||||
    let server_url = config.http_url.clone();
 | 
			
		||||
    let mail_options = config.smtp_options.clone();
 | 
			
		||||
    info!("Starting the API/web server on port {}", config.http_port);
 | 
			
		||||
    server_builder
 | 
			
		||||
        .bind("http", ("0.0.0.0", config.http_port), move || {
 | 
			
		||||
            let backend_handler = backend_handler.clone();
 | 
			
		||||
@ -114,16 +117,18 @@ where
 | 
			
		||||
            let mail_options = mail_options.clone();
 | 
			
		||||
            HttpServiceBuilder::new()
 | 
			
		||||
                .finish(map_config(
 | 
			
		||||
                    App::new().configure(move |cfg| {
 | 
			
		||||
                        http_config(
 | 
			
		||||
                            cfg,
 | 
			
		||||
                            backend_handler,
 | 
			
		||||
                            jwt_secret,
 | 
			
		||||
                            jwt_blacklist,
 | 
			
		||||
                            server_url,
 | 
			
		||||
                            mail_options,
 | 
			
		||||
                        )
 | 
			
		||||
                    }),
 | 
			
		||||
                    App::new()
 | 
			
		||||
                        .wrap(tracing_actix_web::TracingLogger::<CustomRootSpanBuilder>::new())
 | 
			
		||||
                        .configure(move |cfg| {
 | 
			
		||||
                            http_config(
 | 
			
		||||
                                cfg,
 | 
			
		||||
                                backend_handler,
 | 
			
		||||
                                jwt_secret,
 | 
			
		||||
                                jwt_blacklist,
 | 
			
		||||
                                server_url,
 | 
			
		||||
                                mail_options,
 | 
			
		||||
                            )
 | 
			
		||||
                        }),
 | 
			
		||||
                    |_| AppConfig::default(),
 | 
			
		||||
                ))
 | 
			
		||||
                .tcp()
 | 
			
		||||
 | 
			
		||||
@ -12,9 +12,10 @@ use crate::{
 | 
			
		||||
    infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, mail},
 | 
			
		||||
};
 | 
			
		||||
use actix::Actor;
 | 
			
		||||
use actix_server::ServerBuilder;
 | 
			
		||||
use anyhow::{anyhow, Context, Result};
 | 
			
		||||
use futures_util::TryFutureExt;
 | 
			
		||||
use log::*;
 | 
			
		||||
use tracing::*;
 | 
			
		||||
 | 
			
		||||
mod domain;
 | 
			
		||||
mod infra;
 | 
			
		||||
@ -45,7 +46,10 @@ async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration)
 | 
			
		||||
        .context("Error adding admin user to group")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn run_server(config: Configuration) -> Result<()> {
 | 
			
		||||
#[instrument(skip_all)]
 | 
			
		||||
async fn set_up_server(config: Configuration) -> Result<ServerBuilder> {
 | 
			
		||||
    info!("Starting LLDAP....");
 | 
			
		||||
 | 
			
		||||
    let sql_pool = PoolOptions::new()
 | 
			
		||||
        .max_connections(5)
 | 
			
		||||
        .connect(&config.database_url)
 | 
			
		||||
@ -89,7 +93,12 @@ async fn run_server(config: Configuration) -> Result<()> {
 | 
			
		||||
    // Run every hour.
 | 
			
		||||
    let scheduler = Scheduler::new("0 0 * * * * *", sql_pool);
 | 
			
		||||
    scheduler.start();
 | 
			
		||||
    server_builder
 | 
			
		||||
    Ok(server_builder)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn run_server(config: Configuration) -> Result<()> {
 | 
			
		||||
    set_up_server(config)
 | 
			
		||||
        .await?
 | 
			
		||||
        .workers(1)
 | 
			
		||||
        .run()
 | 
			
		||||
        .await
 | 
			
		||||
@ -103,8 +112,6 @@ fn run_server_command(opts: RunOpts) -> Result<()> {
 | 
			
		||||
    let config = infra::configuration::init(opts)?;
 | 
			
		||||
    infra::logging::init(&config)?;
 | 
			
		||||
 | 
			
		||||
    info!("Starting LLDAP....");
 | 
			
		||||
 | 
			
		||||
    actix::run(
 | 
			
		||||
        run_server(config).unwrap_or_else(|e| error!("Could not bring up the servers: {:#}", e)),
 | 
			
		||||
    )?;
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user