diff --git a/Cargo.toml b/Cargo.toml index ab08ab9..be113f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ tracing-actix-web = "0.3.0-beta.2" tracing-log = "*" tracing-subscriber = "*" async-trait = "0.1.48" -sea-query = { version = "0.9.2", features = [ "with-chrono" ] } +sea-query = { version = "0.9.4", features = [ "with-chrono" ] } [dependencies.figment] features = ["toml", "env"] diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 4beb062..c204888 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -4,6 +4,7 @@ use crate::infra::configuration::Configuration; use anyhow::{bail, Result}; use async_trait::async_trait; use futures_util::StreamExt; +use futures_util::TryStreamExt; use log::*; use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder}; use sqlx::Row; @@ -40,10 +41,17 @@ pub struct User { pub creation_date: chrono::NaiveDateTime, } +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] +pub struct Group { + pub display_name: String, + pub users: Vec, +} + #[async_trait] pub trait BackendHandler: Clone + Send { async fn bind(&self, request: BindRequest) -> Result<()>; async fn list_users(&self, request: ListUsersRequest) -> Result>; + async fn list_groups(&self) -> Result>; } #[derive(Debug, Clone)] @@ -141,6 +149,49 @@ impl BackendHandler for SqlBackendHandler { Ok(results.into_iter().collect::>>()?) } + + async fn list_groups(&self) -> Result> { + let query: String = Query::select() + .column(Groups::DisplayName) + .column(Memberships::UserId) + .from(Groups::Table) + .left_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .order_by(Groups::DisplayName, Order::Asc) + .order_by(Memberships::UserId, Order::Asc) + .to_string(SqliteQueryBuilder); + + let mut results = sqlx::query(&query).fetch(&self.sql_pool); + let mut groups = Vec::new(); + // The rows are ordered by group, user, so we need to group them into vectors. + { + let mut current_group = String::new(); + let mut current_users = Vec::new(); + while let Some(row) = results.try_next().await? { + let display_name = row.get::("display_name"); + if display_name != current_group { + if current_group != "" { + groups.push(Group { + display_name: current_group, + users: current_users, + }); + current_users = Vec::new(); + } + current_group = display_name.clone(); + } + current_users.push(row.get::("user_id")); + } + groups.push(Group { + display_name: current_group, + users: current_users, + }); + } + + Ok(groups) + } } #[cfg(test)] @@ -153,6 +204,7 @@ mockall::mock! { impl BackendHandler for TestBackendHandler { async fn bind(&self, request: BindRequest) -> Result<()>; async fn list_users(&self, request: ListUsersRequest) -> Result>; + async fn list_groups(&self) -> Result>; } } @@ -172,7 +224,6 @@ mod tests { } async fn insert_user(sql_pool: &Pool, name: &str, pass: &str) { - /* let query = Query::insert() .into_table(Users::Table) .columns(vec![ @@ -185,27 +236,34 @@ mod tests { Users::Password, ]) .values_panic(vec![ - "bob".into(), + name.into(), "bob@bob".into(), "Bob Böbberson".into(), "Bob".into(), "Böbberson".into(), chrono::NaiveDateTime::from_timestamp(0, 0).into(), - "bob00".into(), + pass.into(), ]) .to_string(SqliteQueryBuilder); - sqlx::query(&query).execute(&sql_pool).await.unwrap(); - */ - sqlx::query( - r#"INSERT INTO users - (user_id, email, display_name, first_name, last_name, creation_date, password) - VALUES (?, "em@ai.l", "Display Name", "Firstname", "Lastname", "1970-01-01 00:00:00", ?)"#, - ) - .bind(name.to_string()) - .bind(pass.to_string()) - .execute(sql_pool) - .await - .unwrap(); + sqlx::query(&query).execute(sql_pool).await.unwrap(); + } + + async fn insert_group(sql_pool: &Pool, id: u32, name: &str) { + let query = Query::insert() + .into_table(Groups::Table) + .columns(vec![Groups::GroupId, Groups::DisplayName]) + .values_panic(vec![id.into(), name.into()]) + .to_string(SqliteQueryBuilder); + sqlx::query(&query).execute(sql_pool).await.unwrap(); + } + + async fn insert_membership(sql_pool: &Pool, group_id: u32, user_id: &str) { + let query = Query::insert() + .into_table(Memberships::Table) + .columns(vec![Memberships::UserId, Memberships::GroupId]) + .values_panic(vec![user_id.into(), group_id.into()]) + .to_string(SqliteQueryBuilder); + sqlx::query(&query).execute(sql_pool).await.unwrap(); } #[tokio::test] @@ -317,4 +375,33 @@ mod tests { assert_eq!(users, vec!["John", "patrick"]); } } + + #[tokio::test] + async fn test_list_groups() { + let sql_pool = get_initialized_db().await; + insert_user(&sql_pool, "bob", "bob00").await; + insert_user(&sql_pool, "patrick", "pass").await; + insert_user(&sql_pool, "John", "Pa33w0rd!").await; + insert_group(&sql_pool, 1, "Best Group").await; + insert_group(&sql_pool, 2, "Worst Group").await; + insert_membership(&sql_pool, 1, "bob").await; + insert_membership(&sql_pool, 1, "patrick").await; + insert_membership(&sql_pool, 2, "patrick").await; + insert_membership(&sql_pool, 2, "John").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool); + assert_eq!( + handler.list_groups().await.unwrap(), + vec![ + Group { + display_name: "Best Group".to_string(), + users: vec!["bob".to_string(), "patrick".to_string()] + }, + Group { + display_name: "Worst Group".to_string(), + users: vec!["John".to_string(), "patrick".to_string()] + } + ] + ); + } } diff --git a/src/domain/sql_tables.rs b/src/domain/sql_tables.rs index 500914c..83a7e2c 100644 --- a/src/domain/sql_tables.rs +++ b/src/domain/sql_tables.rs @@ -1,4 +1,3 @@ -use chrono::NaiveDateTime; use sea_query::*; pub type Pool = sqlx::sqlite::SqlitePool; @@ -90,8 +89,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { .col( ColumnDef::new(Memberships::UserId) .string_len(255) - .not_null() - .primary_key(), + .not_null(), ) .col(ColumnDef::new(Memberships::GroupId).integer().not_null()) .foreign_key( @@ -121,6 +119,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { #[cfg(test)] mod tests { use super::*; + use chrono::NaiveDateTime; use sqlx::{Column, Row}; #[actix_rt::test]