From 480f48f82086b18ec212dd21f394eb6e90e27d0d Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Thu, 16 Sep 2021 09:26:31 +0200 Subject: [PATCH] graphql: Add a method to list groups --- Cargo.lock | 10 ++++ app/queries/get_user_details.graphql | 2 +- app/src/components/user_details.rs | 2 +- schema.graphql | 16 ++--- server/Cargo.toml | 1 + server/src/domain/handler.rs | 10 +++- server/src/domain/sql_backend_handler.rs | 75 ++++++++++++++---------- server/src/domain/sql_tables.rs | 25 ++++++++ server/src/infra/auth_service.rs | 6 +- server/src/infra/graphql/query.rs | 55 +++++++++++++---- server/src/infra/tcp_backend_handler.rs | 2 +- 11 files changed, 147 insertions(+), 57 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 596c283..ec5f6b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1509,6 +1509,15 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" @@ -1705,6 +1714,7 @@ dependencies = [ "futures-util", "hmac 0.10.1", "http", + "itertools", "juniper", "juniper_actix", "jwt", diff --git a/app/queries/get_user_details.graphql b/app/queries/get_user_details.graphql index 57b20ef..150b80e 100644 --- a/app/queries/get_user_details.graphql +++ b/app/queries/get_user_details.graphql @@ -7,7 +7,7 @@ query GetUserDetails($id: String!) { lastName creationDate groups { - id + displayName } } } diff --git a/app/src/components/user_details.rs b/app/src/components/user_details.rs index 58fae4f..f144d41 100644 --- a/app/src/components/user_details.rs +++ b/app/src/components/user_details.rs @@ -219,7 +219,7 @@ impl Component for UserDetails { html! { - {&group.id} + {&group.display_name} } }; diff --git a/schema.graphql b/schema.graphql index 9060d1d..c167e30 100644 --- a/schema.graphql +++ b/schema.graphql @@ -11,7 +11,8 @@ type Mutation { } type Group { - id: String! + id: Int! + displayName: String! "The groups to which this user belongs." users: [User!]! } @@ -30,6 +31,13 @@ input RequestFilter { "DateTime" scalar DateTimeUtc +type Query { + apiVersion: String! + user(userId: String!): User! + users(filters: RequestFilter): [User!]! + groups: [Group!]! +} + "The details required to create a user." input CreateUserInput { id: String! @@ -39,12 +47,6 @@ input CreateUserInput { lastName: String } -type Query { - apiVersion: String! - user(userId: String!): User! - users(filters: RequestFilter): [User!]! -} - type User { id: String! email: String! diff --git a/server/Cargo.toml b/server/Cargo.toml index 1c8997e..e98cb55 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -45,6 +45,7 @@ tracing-subscriber = "*" rand = { version = "0.8", features = ["small_rng", "getrandom"] } juniper_actix = "0.4.0" juniper = "0.15.6" +itertools = "0.10.1" # TODO: update to 0.6 when out. [dependencies.opaque-ke] diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 19ccb09..fbee794 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -31,6 +31,7 @@ impl Default for User { #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] pub struct Group { + pub id: GroupId, pub display_name: String, pub users: Vec, } @@ -74,9 +75,12 @@ pub trait LoginHandler: Clone + Send { async fn bind(&self, request: BindRequest) -> Result<()>; } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct GroupId(pub i32); +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct GroupIdAndName(pub GroupId, pub String); + #[async_trait] pub trait BackendHandler: Clone + Send { async fn list_users(&self, filters: Option) -> Result>; @@ -88,7 +92,7 @@ pub trait BackendHandler: Clone + Send { async fn create_group(&self, group_name: &str) -> Result; async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; - async fn get_user_groups(&self, user: &str) -> Result>; + async fn get_user_groups(&self, user: &str) -> Result>; } #[cfg(test)] @@ -106,7 +110,7 @@ mockall::mock! { async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; async fn delete_user(&self, user_id: &str) -> Result<()>; async fn create_group(&self, group_name: &str) -> Result; - async fn get_user_groups(&self, user: &str) -> Result>; + async fn get_user_groups(&self, user: &str) -> Result>; async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; } diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index 49afa09..26e7d89 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -2,7 +2,6 @@ use super::{error::*, handler::*, sql_tables::*}; use crate::infra::configuration::Configuration; use async_trait::async_trait; use futures_util::StreamExt; -use futures_util::TryStreamExt; use sea_query::{Expr, Iden, Order, Query, SimpleExpr}; use sqlx::Row; use std::collections::HashSet; @@ -76,6 +75,7 @@ impl BackendHandler for SqlBackendHandler { async fn list_groups(&self) -> Result> { let query: String = Query::select() + .column((Groups::Table, Groups::GroupId)) .column(Groups::DisplayName) .column(Memberships::UserId) .from(Groups::Table) @@ -88,32 +88,33 @@ impl BackendHandler for SqlBackendHandler { .order_by(Memberships::UserId, Order::Asc) .to_string(DbQueryBuilder {}); - let mut results = sqlx::query(&query).fetch(&self.sql_pool); + // For group_by. + use itertools::Itertools; let mut groups = Vec::new(); - // The rows are ordered by group, user, so we need to group them into vectors. + // The rows are returned sorted by display_name, equivalent to group_id. We group them by + // this key which gives us one element (`rows`) per group. + for ((group_id, display_name), rows) in &sqlx::query(&query) + .fetch_all(&self.sql_pool) + .await? + .into_iter() + .group_by(|row| { + ( + GroupId(row.get::(&*Groups::GroupId.to_string())), + row.get::(&*Groups::DisplayName.to_string()), + ) + }) { - let mut current_group = String::new(); - let mut current_users = Vec::new(); - while let Some(row) = results.try_next().await? { - let display_name = row.get::(&*Groups::DisplayName.to_string()); - if display_name != current_group { - if !current_group.is_empty() { - groups.push(Group { - display_name: current_group, - users: current_users, - }); - current_users = Vec::new(); - } - current_group = display_name.clone(); - } - current_users.push(row.get::(&*Memberships::UserId.to_string())); - } groups.push(Group { - display_name: current_group, - users: current_users, + id: group_id, + display_name, + users: rows + .map(|row| row.get::(&*Memberships::UserId.to_string())) + // If a group has no users, an empty string is returned because of the left + // join. + .filter(|s| !s.is_empty()) + .collect(), }); } - Ok(groups) } @@ -135,13 +136,14 @@ impl BackendHandler for SqlBackendHandler { .await?) } - async fn get_user_groups(&self, user: &str) -> Result> { + async fn get_user_groups(&self, user: &str) -> Result> { if user == self.config.ldap_user_dn { let mut groups = HashSet::new(); - groups.insert("lldap_admin".to_string()); + groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string())); return Ok(groups); } let query: String = Query::select() + .column((Groups::Table, Groups::GroupId)) .column(Groups::DisplayName) .from(Groups::Table) .inner_join( @@ -154,10 +156,15 @@ impl BackendHandler for SqlBackendHandler { sqlx::query(&query) // Extract the group id from the row. - .map(|row: DbRow| row.get::(&*Groups::DisplayName.to_string())) + .map(|row: DbRow| { + GroupIdAndName( + row.get::(&*Groups::GroupId.to_string()), + row.get::(&*Groups::DisplayName.to_string()), + ) + }) .fetch(&self.sql_pool) // Collect the vector of rows, each potentially an error. - .collect::>>() + .collect::>>() .await .into_iter() // Transform it into a single result (the first error if any), and group the group_ids @@ -468,6 +475,7 @@ mod tests { insert_user(&handler, "John", "Pa33w0rd!").await; let group_1 = insert_group(&handler, "Best Group").await; let group_2 = insert_group(&handler, "Worst Group").await; + let group_3 = insert_group(&handler, "Empty Group").await; insert_membership(&handler, group_1, "bob").await; insert_membership(&handler, group_1, "patrick").await; insert_membership(&handler, group_2, "patrick").await; @@ -476,13 +484,20 @@ mod tests { handler.list_groups().await.unwrap(), vec![ Group { + id: group_1, display_name: "Best Group".to_string(), users: vec!["bob".to_string(), "patrick".to_string()] }, Group { + id: group_3, + display_name: "Empty Group".to_string(), + users: vec![] + }, + Group { + id: group_2, display_name: "Worst Group".to_string(), users: vec!["John".to_string(), "patrick".to_string()] - } + }, ] ); } @@ -515,10 +530,10 @@ mod tests { insert_membership(&handler, group_1, "patrick").await; insert_membership(&handler, group_2, "patrick").await; let mut bob_groups = HashSet::new(); - bob_groups.insert("Group1".to_string()); + bob_groups.insert(GroupIdAndName(group_1, "Group1".to_string())); let mut patrick_groups = HashSet::new(); - patrick_groups.insert("Group1".to_string()); - patrick_groups.insert("Group2".to_string()); + patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string())); + patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string())); assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups); assert_eq!( handler.get_user_groups("patrick").await.unwrap(), diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs index 35f349d..67a5115 100644 --- a/server/src/domain/sql_tables.rs +++ b/server/src/domain/sql_tables.rs @@ -12,6 +12,31 @@ impl From for Value { } } +impl sqlx::Type for GroupId +where + DB: sqlx::Database, + i32: sqlx::Type, +{ + fn type_info() -> ::TypeInfo { + >::type_info() + } + fn compatible(ty: &::TypeInfo) -> bool { + >::compatible(ty) + } +} + +impl<'r, DB> sqlx::Decode<'r, DB> for GroupId +where + DB: sqlx::Database, + i32: sqlx::Decode<'r, DB>, +{ + fn decode( + value: >::ValueRef, + ) -> Result> { + >::decode(value).map(GroupId) + } +} + #[derive(Iden)] pub enum Users { Table, diff --git a/server/src/infra/auth_service.rs b/server/src/infra/auth_service.rs index ed0132e..98311d9 100644 --- a/server/src/infra/auth_service.rs +++ b/server/src/infra/auth_service.rs @@ -1,7 +1,7 @@ use crate::{ domain::{ error::DomainError, - handler::{BackendHandler, BindRequest, LoginHandler}, + handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler}, opaque_handler::OpaqueHandler, }, infra::{ @@ -32,12 +32,12 @@ use time::ext::NumericalDuration; type Token = jwt::Token; type SignedToken = Token; -fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken { +fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken { let claims = JWTClaims { exp: Utc::now() + chrono::Duration::days(1), iat: Utc::now(), user, - groups, + groups: groups.into_iter().map(|g| g.1).collect(), }; let header = jwt::Header { algorithm: jwt::AlgorithmType::Hs512, diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index 2c9a29c..a8dbcb7 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,10 +1,11 @@ -use crate::domain::handler::BackendHandler; +use crate::domain::handler::{BackendHandler, GroupIdAndName}; use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use serde::{Deserialize, Serialize}; use std::convert::TryInto; type DomainRequestFilter = crate::domain::handler::RequestFilter; type DomainUser = crate::domain::handler::User; +type DomainGroup = crate::domain::handler::Group; use super::api::Context; #[derive(PartialEq, Eq, Debug, GraphQLInputObject)] @@ -113,6 +114,17 @@ impl Query { .await .map(|v| v.into_iter().map(Into::into).collect())?) } + + async fn groups(context: &Context) -> FieldResult>> { + if !context.validation_result.is_admin { + return Err("Unauthorized access to group list".into()); + } + Ok(context + .handler + .list_groups() + .await + .map(|v| v.into_iter().map(Into::into).collect())?) + } } #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] @@ -179,14 +191,19 @@ impl From for User { #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] /// Represents a single group. pub struct Group { - group_id: String, + group_id: i32, + display_name: String, + members: Option>, _phantom: std::marker::PhantomData>, } #[graphql_object(context = Context)] impl Group { - fn id(&self) -> String { - self.group_id.clone() + fn id(&self) -> i32 { + self.group_id + } + fn display_name(&self) -> String { + self.display_name.clone() } /// The groups to which this user belongs. async fn users(&self, context: &Context) -> FieldResult>> { @@ -197,10 +214,23 @@ impl Group { } } -impl From for Group { - fn from(group_id: String) -> Self { +impl From for Group { + fn from(group_id_and_name: GroupIdAndName) -> Self { Self { - group_id, + group_id: group_id_and_name.0 .0, + display_name: group_id_and_name.1, + members: None, + _phantom: std::marker::PhantomData, + } + } +} + +impl From for Group { + fn from(group: DomainGroup) -> Self { + Self { + group_id: group.id.0, + display_name: group.display_name, + members: Some(group.users.into_iter().map(Into::into).collect()), _phantom: std::marker::PhantomData, } } @@ -209,7 +239,10 @@ impl From for Group { #[cfg(test)] mod tests { use super::*; - use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults}; + use crate::{ + domain::handler::{GroupId, GroupIdAndName, MockTestBackendHandler}, + infra::auth_service::ValidationResults, + }; use juniper::{ execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType, RootNode, Variables, @@ -250,8 +283,8 @@ mod tests { ..Default::default() }) }); - let mut groups = HashSet::::new(); - groups.insert("Bobbersons".to_string()); + let mut groups = HashSet::new(); + groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string())); mock.expect_get_user_groups() .with(eq("bob")) .return_once(|_| Ok(groups)); @@ -270,7 +303,7 @@ mod tests { "user": { "id": "bob", "email": "bob@bobbers.on", - "groups": [{"id": "Bobbersons"}] + "groups": [{"id": 3}] } }), vec![] diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index 03f019e..5899a40 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -29,7 +29,7 @@ mockall::mock! { async fn list_users(&self, filters: Option) -> DomainResult>; async fn list_groups(&self) -> DomainResult>; async fn get_user_details(&self, user_id: &str) -> DomainResult; - async fn get_user_groups(&self, user: &str) -> DomainResult>; + async fn get_user_groups(&self, user: &str) -> DomainResult>; async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>; async fn update_user(&self, request: UpdateUserRequest) -> DomainResult<()>; async fn delete_user(&self, user_id: &str) -> DomainResult<()>;