diff --git a/Cargo.lock b/Cargo.lock
index 596c283..ec5f6b5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1509,6 +1509,15 @@ dependencies = [
"cfg-if 1.0.0",
]
+[[package]]
+name = "itertools"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf"
+dependencies = [
+ "either",
+]
+
[[package]]
name = "itoa"
version = "0.4.8"
@@ -1705,6 +1714,7 @@ dependencies = [
"futures-util",
"hmac 0.10.1",
"http",
+ "itertools",
"juniper",
"juniper_actix",
"jwt",
diff --git a/app/queries/get_user_details.graphql b/app/queries/get_user_details.graphql
index 57b20ef..150b80e 100644
--- a/app/queries/get_user_details.graphql
+++ b/app/queries/get_user_details.graphql
@@ -7,7 +7,7 @@ query GetUserDetails($id: String!) {
lastName
creationDate
groups {
- id
+ displayName
}
}
}
diff --git a/app/src/components/user_details.rs b/app/src/components/user_details.rs
index 58fae4f..f144d41 100644
--- a/app/src/components/user_details.rs
+++ b/app/src/components/user_details.rs
@@ -219,7 +219,7 @@ impl Component for UserDetails {
html! {
|
- {&group.id} |
+ {&group.display_name} |
}
};
diff --git a/schema.graphql b/schema.graphql
index 9060d1d..c167e30 100644
--- a/schema.graphql
+++ b/schema.graphql
@@ -11,7 +11,8 @@ type Mutation {
}
type Group {
- id: String!
+ id: Int!
+ displayName: String!
"The groups to which this user belongs."
users: [User!]!
}
@@ -30,6 +31,13 @@ input RequestFilter {
"DateTime"
scalar DateTimeUtc
+type Query {
+ apiVersion: String!
+ user(userId: String!): User!
+ users(filters: RequestFilter): [User!]!
+ groups: [Group!]!
+}
+
"The details required to create a user."
input CreateUserInput {
id: String!
@@ -39,12 +47,6 @@ input CreateUserInput {
lastName: String
}
-type Query {
- apiVersion: String!
- user(userId: String!): User!
- users(filters: RequestFilter): [User!]!
-}
-
type User {
id: String!
email: String!
diff --git a/server/Cargo.toml b/server/Cargo.toml
index 1c8997e..e98cb55 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -45,6 +45,7 @@ tracing-subscriber = "*"
rand = { version = "0.8", features = ["small_rng", "getrandom"] }
juniper_actix = "0.4.0"
juniper = "0.15.6"
+itertools = "0.10.1"
# TODO: update to 0.6 when out.
[dependencies.opaque-ke]
diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs
index 19ccb09..fbee794 100644
--- a/server/src/domain/handler.rs
+++ b/server/src/domain/handler.rs
@@ -31,6 +31,7 @@ impl Default for User {
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Group {
+ pub id: GroupId,
pub display_name: String,
pub users: Vec,
}
@@ -74,9 +75,12 @@ pub trait LoginHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GroupId(pub i32);
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct GroupIdAndName(pub GroupId, pub String);
+
#[async_trait]
pub trait BackendHandler: Clone + Send {
async fn list_users(&self, filters: Option) -> Result>;
@@ -88,7 +92,7 @@ pub trait BackendHandler: Clone + Send {
async fn create_group(&self, group_name: &str) -> Result;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
- async fn get_user_groups(&self, user: &str) -> Result>;
+ async fn get_user_groups(&self, user: &str) -> Result>;
}
#[cfg(test)]
@@ -106,7 +110,7 @@ mockall::mock! {
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result;
- async fn get_user_groups(&self, user: &str) -> Result>;
+ async fn get_user_groups(&self, user: &str) -> Result>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
}
diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs
index 49afa09..26e7d89 100644
--- a/server/src/domain/sql_backend_handler.rs
+++ b/server/src/domain/sql_backend_handler.rs
@@ -2,7 +2,6 @@ use super::{error::*, handler::*, sql_tables::*};
use crate::infra::configuration::Configuration;
use async_trait::async_trait;
use futures_util::StreamExt;
-use futures_util::TryStreamExt;
use sea_query::{Expr, Iden, Order, Query, SimpleExpr};
use sqlx::Row;
use std::collections::HashSet;
@@ -76,6 +75,7 @@ impl BackendHandler for SqlBackendHandler {
async fn list_groups(&self) -> Result> {
let query: String = Query::select()
+ .column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
.column(Memberships::UserId)
.from(Groups::Table)
@@ -88,32 +88,33 @@ impl BackendHandler for SqlBackendHandler {
.order_by(Memberships::UserId, Order::Asc)
.to_string(DbQueryBuilder {});
- let mut results = sqlx::query(&query).fetch(&self.sql_pool);
+ // For group_by.
+ use itertools::Itertools;
let mut groups = Vec::new();
- // The rows are ordered by group, user, so we need to group them into vectors.
+ // The rows are returned sorted by display_name, equivalent to group_id. We group them by
+ // this key which gives us one element (`rows`) per group.
+ for ((group_id, display_name), rows) in &sqlx::query(&query)
+ .fetch_all(&self.sql_pool)
+ .await?
+ .into_iter()
+ .group_by(|row| {
+ (
+ GroupId(row.get::(&*Groups::GroupId.to_string())),
+ row.get::(&*Groups::DisplayName.to_string()),
+ )
+ })
{
- let mut current_group = String::new();
- let mut current_users = Vec::new();
- while let Some(row) = results.try_next().await? {
- let display_name = row.get::(&*Groups::DisplayName.to_string());
- if display_name != current_group {
- if !current_group.is_empty() {
- groups.push(Group {
- display_name: current_group,
- users: current_users,
- });
- current_users = Vec::new();
- }
- current_group = display_name.clone();
- }
- current_users.push(row.get::(&*Memberships::UserId.to_string()));
- }
groups.push(Group {
- display_name: current_group,
- users: current_users,
+ id: group_id,
+ display_name,
+ users: rows
+ .map(|row| row.get::(&*Memberships::UserId.to_string()))
+ // If a group has no users, an empty string is returned because of the left
+ // join.
+ .filter(|s| !s.is_empty())
+ .collect(),
});
}
-
Ok(groups)
}
@@ -135,13 +136,14 @@ impl BackendHandler for SqlBackendHandler {
.await?)
}
- async fn get_user_groups(&self, user: &str) -> Result> {
+ async fn get_user_groups(&self, user: &str) -> Result> {
if user == self.config.ldap_user_dn {
let mut groups = HashSet::new();
- groups.insert("lldap_admin".to_string());
+ groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
return Ok(groups);
}
let query: String = Query::select()
+ .column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
.from(Groups::Table)
.inner_join(
@@ -154,10 +156,15 @@ impl BackendHandler for SqlBackendHandler {
sqlx::query(&query)
// Extract the group id from the row.
- .map(|row: DbRow| row.get::(&*Groups::DisplayName.to_string()))
+ .map(|row: DbRow| {
+ GroupIdAndName(
+ row.get::(&*Groups::GroupId.to_string()),
+ row.get::(&*Groups::DisplayName.to_string()),
+ )
+ })
.fetch(&self.sql_pool)
// Collect the vector of rows, each potentially an error.
- .collect::>>()
+ .collect::>>()
.await
.into_iter()
// Transform it into a single result (the first error if any), and group the group_ids
@@ -468,6 +475,7 @@ mod tests {
insert_user(&handler, "John", "Pa33w0rd!").await;
let group_1 = insert_group(&handler, "Best Group").await;
let group_2 = insert_group(&handler, "Worst Group").await;
+ let group_3 = insert_group(&handler, "Empty Group").await;
insert_membership(&handler, group_1, "bob").await;
insert_membership(&handler, group_1, "patrick").await;
insert_membership(&handler, group_2, "patrick").await;
@@ -476,13 +484,20 @@ mod tests {
handler.list_groups().await.unwrap(),
vec![
Group {
+ id: group_1,
display_name: "Best Group".to_string(),
users: vec!["bob".to_string(), "patrick".to_string()]
},
Group {
+ id: group_3,
+ display_name: "Empty Group".to_string(),
+ users: vec![]
+ },
+ Group {
+ id: group_2,
display_name: "Worst Group".to_string(),
users: vec!["John".to_string(), "patrick".to_string()]
- }
+ },
]
);
}
@@ -515,10 +530,10 @@ mod tests {
insert_membership(&handler, group_1, "patrick").await;
insert_membership(&handler, group_2, "patrick").await;
let mut bob_groups = HashSet::new();
- bob_groups.insert("Group1".to_string());
+ bob_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
let mut patrick_groups = HashSet::new();
- patrick_groups.insert("Group1".to_string());
- patrick_groups.insert("Group2".to_string());
+ patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
+ patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string()));
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
assert_eq!(
handler.get_user_groups("patrick").await.unwrap(),
diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs
index 35f349d..67a5115 100644
--- a/server/src/domain/sql_tables.rs
+++ b/server/src/domain/sql_tables.rs
@@ -12,6 +12,31 @@ impl From for Value {
}
}
+impl sqlx::Type for GroupId
+where
+ DB: sqlx::Database,
+ i32: sqlx::Type,
+{
+ fn type_info() -> ::TypeInfo {
+ >::type_info()
+ }
+ fn compatible(ty: &::TypeInfo) -> bool {
+ >::compatible(ty)
+ }
+}
+
+impl<'r, DB> sqlx::Decode<'r, DB> for GroupId
+where
+ DB: sqlx::Database,
+ i32: sqlx::Decode<'r, DB>,
+{
+ fn decode(
+ value: >::ValueRef,
+ ) -> Result> {
+ >::decode(value).map(GroupId)
+ }
+}
+
#[derive(Iden)]
pub enum Users {
Table,
diff --git a/server/src/infra/auth_service.rs b/server/src/infra/auth_service.rs
index ed0132e..98311d9 100644
--- a/server/src/infra/auth_service.rs
+++ b/server/src/infra/auth_service.rs
@@ -1,7 +1,7 @@
use crate::{
domain::{
error::DomainError,
- handler::{BackendHandler, BindRequest, LoginHandler},
+ handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{
@@ -32,12 +32,12 @@ use time::ext::NumericalDuration;
type Token = jwt::Token;
type SignedToken = Token;
-fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken {
+fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken {
let claims = JWTClaims {
exp: Utc::now() + chrono::Duration::days(1),
iat: Utc::now(),
user,
- groups,
+ groups: groups.into_iter().map(|g| g.1).collect(),
};
let header = jwt::Header {
algorithm: jwt::AlgorithmType::Hs512,
diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs
index 2c9a29c..a8dbcb7 100644
--- a/server/src/infra/graphql/query.rs
+++ b/server/src/infra/graphql/query.rs
@@ -1,10 +1,11 @@
-use crate::domain::handler::BackendHandler;
+use crate::domain::handler::{BackendHandler, GroupIdAndName};
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
type DomainRequestFilter = crate::domain::handler::RequestFilter;
type DomainUser = crate::domain::handler::User;
+type DomainGroup = crate::domain::handler::Group;
use super::api::Context;
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
@@ -113,6 +114,17 @@ impl Query {
.await
.map(|v| v.into_iter().map(Into::into).collect())?)
}
+
+ async fn groups(context: &Context) -> FieldResult>> {
+ if !context.validation_result.is_admin {
+ return Err("Unauthorized access to group list".into());
+ }
+ Ok(context
+ .handler
+ .list_groups()
+ .await
+ .map(|v| v.into_iter().map(Into::into).collect())?)
+ }
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
@@ -179,14 +191,19 @@ impl From for User {
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
/// Represents a single group.
pub struct Group {
- group_id: String,
+ group_id: i32,
+ display_name: String,
+ members: Option>,
_phantom: std::marker::PhantomData>,
}
#[graphql_object(context = Context)]
impl Group {
- fn id(&self) -> String {
- self.group_id.clone()
+ fn id(&self) -> i32 {
+ self.group_id
+ }
+ fn display_name(&self) -> String {
+ self.display_name.clone()
}
/// The groups to which this user belongs.
async fn users(&self, context: &Context) -> FieldResult>> {
@@ -197,10 +214,23 @@ impl Group {
}
}
-impl From for Group {
- fn from(group_id: String) -> Self {
+impl From for Group {
+ fn from(group_id_and_name: GroupIdAndName) -> Self {
Self {
- group_id,
+ group_id: group_id_and_name.0 .0,
+ display_name: group_id_and_name.1,
+ members: None,
+ _phantom: std::marker::PhantomData,
+ }
+ }
+}
+
+impl From for Group {
+ fn from(group: DomainGroup) -> Self {
+ Self {
+ group_id: group.id.0,
+ display_name: group.display_name,
+ members: Some(group.users.into_iter().map(Into::into).collect()),
_phantom: std::marker::PhantomData,
}
}
@@ -209,7 +239,10 @@ impl From for Group {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults};
+ use crate::{
+ domain::handler::{GroupId, GroupIdAndName, MockTestBackendHandler},
+ infra::auth_service::ValidationResults,
+ };
use juniper::{
execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType,
RootNode, Variables,
@@ -250,8 +283,8 @@ mod tests {
..Default::default()
})
});
- let mut groups = HashSet::::new();
- groups.insert("Bobbersons".to_string());
+ let mut groups = HashSet::new();
+ groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string()));
mock.expect_get_user_groups()
.with(eq("bob"))
.return_once(|_| Ok(groups));
@@ -270,7 +303,7 @@ mod tests {
"user": {
"id": "bob",
"email": "bob@bobbers.on",
- "groups": [{"id": "Bobbersons"}]
+ "groups": [{"id": 3}]
}
}),
vec![]
diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs
index 03f019e..5899a40 100644
--- a/server/src/infra/tcp_backend_handler.rs
+++ b/server/src/infra/tcp_backend_handler.rs
@@ -29,7 +29,7 @@ mockall::mock! {
async fn list_users(&self, filters: Option) -> DomainResult>;
async fn list_groups(&self) -> DomainResult>;
async fn get_user_details(&self, user_id: &str) -> DomainResult;
- async fn get_user_groups(&self, user: &str) -> DomainResult>;
+ async fn get_user_groups(&self, user: &str) -> DomainResult>;
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
async fn update_user(&self, request: UpdateUserRequest) -> DomainResult<()>;
async fn delete_user(&self, user_id: &str) -> DomainResult<()>;