Add get_user_groups handler method

This commit is contained in:
Valentin Tolmer 2021-05-12 20:41:51 +02:00
parent 5615ef8e1f
commit ccaa610b3c
2 changed files with 68 additions and 1 deletions

View File

@ -4,16 +4,19 @@ use crate::infra::configuration::Configuration;
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::StreamExt; use futures_util::StreamExt;
use futures_util::TryStreamExt; use futures_util::TryStreamExt;
pub use lldap_model::*;
use log::*; use log::*;
use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder}; use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder};
use sqlx::Row; use sqlx::Row;
use std::collections::HashSet;
pub use lldap_model::*;
#[async_trait] #[async_trait]
pub trait BackendHandler: Clone + Send { pub trait BackendHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -155,6 +158,33 @@ impl BackendHandler for SqlBackendHandler {
Ok(groups) Ok(groups)
} }
async fn get_user_groups(&self, user: String) -> Result<HashSet<String>> {
let query: String = Query::select()
.column(Groups::DisplayName)
.from(Groups::Table)
.inner_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.and_where(Expr::col(Memberships::UserId).eq(user))
.to_string(SqliteQueryBuilder);
sqlx::query(&query)
// Extract the group id from the row.
.map(|row: DbRow| row.get::<String, _>("display_name"))
.fetch(&self.sql_pool)
// Collect the vector of rows, each potentially an error.
.collect::<Vec<sqlx::Result<String>>>()
.await
.into_iter()
// Transform it into a single result (the first error if any), and group the group_ids
// into a HashSet.
.collect::<sqlx::Result<HashSet<_>>>()
// Map the sqlx::Error into a domain::Error.
.map_err(Error::DatabaseError)
}
} }
#[cfg(test)] #[cfg(test)]
@ -168,6 +198,7 @@ mockall::mock! {
async fn bind(&self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>;
} }
} }
@ -369,4 +400,39 @@ mod tests {
] ]
); );
} }
#[tokio::test]
async fn test_get_user_groups() {
let sql_pool = get_initialized_db().await;
insert_user(&sql_pool, "bob", "bob00").await;
insert_user(&sql_pool, "patrick", "pass").await;
insert_user(&sql_pool, "John", "Pa33w0rd!").await;
insert_group(&sql_pool, 1, "Group1").await;
insert_group(&sql_pool, 2, "Group2").await;
insert_membership(&sql_pool, 1, "bob").await;
insert_membership(&sql_pool, 1, "patrick").await;
insert_membership(&sql_pool, 2, "patrick").await;
let config = Configuration::default();
let handler = SqlBackendHandler::new(config, sql_pool);
let mut bob_groups = HashSet::new();
bob_groups.insert("Group1".to_string());
let mut patrick_groups = HashSet::new();
patrick_groups.insert("Group1".to_string());
patrick_groups.insert("Group2".to_string());
assert_eq!(
handler.get_user_groups("bob".to_string()).await.unwrap(),
bob_groups
);
assert_eq!(
handler
.get_user_groups("patrick".to_string())
.await
.unwrap(),
patrick_groups
);
assert_eq!(
handler.get_user_groups("John".to_string()).await.unwrap(),
HashSet::new()
);
}
} }

View File

@ -2,6 +2,7 @@ use sea_query::*;
pub type Pool = sqlx::sqlite::SqlitePool; pub type Pool = sqlx::sqlite::SqlitePool;
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions; pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
pub type DbRow = sqlx::sqlite::SqliteRow;
#[derive(Iden)] #[derive(Iden)]
pub enum Users { pub enum Users {