From a51965a61a89eeca4c4ef88b377e34b3308c1ef2 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Thu, 26 Aug 2021 09:52:56 +0200 Subject: [PATCH] Implement basic GraphQL endpoint with auth --- Cargo.lock | 172 +++++++++++++++++++ Cargo.toml | 2 + src/infra/auth_service.rs | 64 +++++-- src/infra/graphql/api.rs | 70 ++++++++ src/infra/graphql/mod.rs | 2 + src/infra/graphql/query.rs | 340 +++++++++++++++++++++++++++++++++++++ src/infra/mod.rs | 1 + src/infra/tcp_api.rs | 3 +- src/infra/tcp_server.rs | 4 +- 9 files changed, 637 insertions(+), 21 deletions(-) create mode 100644 src/infra/graphql/api.rs create mode 100644 src/infra/graphql/mod.rs create mode 100644 src/infra/graphql/query.rs diff --git a/Cargo.lock b/Cargo.lock index a731ffc..378dfec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,6 +238,23 @@ dependencies = [ "url", ] +[[package]] +name = "actix-web-actors" +version = "4.0.0-beta.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7db5c2c78a2606e6634abee4973a4924221cfab66e48f23844256e4fb8ce0f42" +dependencies = [ + "actix", + "actix-codec", + "actix-http", + "actix-web", + "bytes", + "bytestring", + "futures-core", + "pin-project", + "tokio", +] + [[package]] name = "actix-web-codegen" version = "0.5.0-beta.3" @@ -348,6 +365,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "ascii" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e" + [[package]] name = "askama_escape" version = "0.10.1" @@ -491,6 +514,23 @@ dependencies = [ "libc", ] +[[package]] +name = "bson" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "903a4f4c7aa97921f1703acac1fd524e9e082b3228edd34dde07758c0c92c672" +dependencies = [ + "base64", + "chrono", + "hex", + "lazy_static", + "linked-hash-map", + "rand 0.7.3", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "build_const" version = "0.2.2" @@ -597,6 +637,19 @@ dependencies = [ "syn", ] +[[package]] +name = "combine" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3da6baa321ec19e1cc41d31bf599f00c783d0517095cdaf0332e3fe8d20680" +dependencies = [ + "ascii", + "byteorder", + "either", + "memchr", + "unreachable", +] + [[package]] name = "console_error_panic_hook" version = "0.1.6" @@ -838,6 +891,17 @@ dependencies = [ "syn", ] +[[package]] +name = "derive_utils" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "532b4c15dccee12c7044f1fcad956e98410860b22231e44a3b827464797ca7bf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "difference" version = "2.0.0" @@ -1006,6 +1070,17 @@ version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" +[[package]] +name = "futures-enum" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3422d14de7903a52e9dbc10ae05a7e14445ec61890100e098754e120b2bd7b1e" +dependencies = [ + "derive_utils", + "quote", + "syn", +] + [[package]] name = "futures-executor" version = "0.3.15" @@ -1167,6 +1242,16 @@ dependencies = [ "web-sys", ] +[[package]] +name = "graphql-parser" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1abd4ce5247dfc04a03ccde70f87a048458c9356c7e41d21ad8c407b3dde6f2" +dependencies = [ + "combine", + "thiserror", +] + [[package]] name = "h2" version = "0.3.3" @@ -1306,6 +1391,7 @@ checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" dependencies = [ "autocfg 1.0.1", "hashbrown", + "serde", ] [[package]] @@ -1347,6 +1433,59 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "juniper" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "637ffa8a8d8a05aed3331449e311f145864adcd82442d82e54d0522decb7cecf" +dependencies = [ + "async-trait", + "bson", + "chrono", + "fnv", + "futures", + "futures-enum", + "graphql-parser", + "indexmap", + "juniper_codegen", + "serde", + "smartstring", + "static_assertions", + "url", + "uuid", +] + +[[package]] +name = "juniper_actix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc44af18ae1f551076171e24eb453c52132a19c219d1f1a1c3068ab363b946b5" +dependencies = [ + "actix", + "actix-http", + "actix-web", + "actix-web-actors", + "anyhow", + "futures", + "http", + "juniper", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "juniper_codegen" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a040e09482a45e77dd2dafa0d9d2651d17faf0ac674da0c93eabc3075ee24997" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "jwt" version = "0.13.0" @@ -1435,6 +1574,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linked-hash-map" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fb9b38af92608140b86b693604b9ffcc5824240a484d1ecd4795bacb2fe88f3" + [[package]] name = "lldap" version = "0.1.0" @@ -1460,6 +1605,8 @@ dependencies = [ "futures-util", "hmac 0.10.1", "http", + "juniper", + "juniper_actix", "jwt", "ldap3_server", "lldap_model", @@ -2400,6 +2547,7 @@ version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" dependencies = [ + "indexmap", "itoa", "ryu", "serde", @@ -2490,6 +2638,15 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +[[package]] +name = "smartstring" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31aa6a31c0c2b21327ce875f7e8952322acfcfd0c27569a6e18a647281352c9b" +dependencies = [ + "static_assertions", +] + [[package]] name = "socket2" version = "0.4.0" @@ -3097,6 +3254,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + [[package]] name = "url" version = "2.2.2" @@ -3136,6 +3302,12 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + [[package]] name = "wasi" version = "0.9.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 96ae800..eaf8af4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,8 @@ tracing-actix-web = "0.4.0-beta.7" tracing-log = "*" tracing-subscriber = "*" rand = { version = "0.8", features = ["small_rng", "getrandom"] } +juniper_actix = "0.4.0" +juniper = "0.15.6" # TODO: update to 0.6 when out. [dependencies.opaque-ke] diff --git a/src/infra/auth_service.rs b/src/infra/auth_service.rs index 514eeb0..83af89e 100644 --- a/src/infra/auth_service.rs +++ b/src/infra/auth_service.rs @@ -337,6 +337,49 @@ where } } +pub struct ValidationResults { + pub user: String, + pub is_admin: bool, +} + +impl ValidationResults { + #[cfg(test)] + pub fn admin() -> Self { + Self { + user: "admin".to_string(), + is_admin: true, + } + } + + pub fn can_access(&self, user: &str) -> bool { + self.is_admin || self.user == user + } +} + +pub(crate) fn check_if_token_is_valid( + state: &AppState, + token_str: &str, +) -> Result { + let token: Token<_> = VerifyWithKey::verify_with_key(token_str, &state.jwt_key) + .map_err(|_| ErrorUnauthorized("Invalid JWT"))?; + if token.claims().exp.lt(&Utc::now()) { + return Err(ErrorUnauthorized("Expired JWT")); + } + let jwt_hash = { + let mut s = DefaultHasher::new(); + token_str.hash(&mut s); + s.finish() + }; + if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) { + return Err(ErrorUnauthorized("JWT was logged out")); + } + let is_admin = token.claims().groups.contains("lldap_admin"); + Ok(ValidationResults { + user: token.claims().user.clone(), + is_admin, + }) +} + pub async fn token_validator( req: ServiceRequest, credentials: BearerAuth, @@ -348,24 +391,9 @@ where let state = req .app_data::>>() .expect("Invalid app config"); - let token: Token<_> = VerifyWithKey::verify_with_key(credentials.token(), &state.jwt_key) - .map_err(|_| ErrorUnauthorized("Invalid JWT"))?; - if token.claims().exp.lt(&Utc::now()) { - return Err(ErrorUnauthorized("Expired JWT")); - } - let jwt_hash = { - let mut s = DefaultHasher::new(); - credentials.token().hash(&mut s); - s.finish() - }; - if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) { - return Err(ErrorUnauthorized("JWT was logged out")); - } - let is_admin = token.claims().groups.contains("lldap_admin"); - if is_admin - || (!admin_required && req.match_info().get("user_id") == Some(&token.claims().user)) - { - debug!("Got authorized token for user {}", &token.claims().user); + let ValidationResults { user, is_admin } = check_if_token_is_valid(state, credentials.token())?; + if is_admin || (!admin_required && req.match_info().get("user_id") == Some(&user)) { + debug!("Got authorized token for user {}", &user); Ok(req) } else { Err(ErrorUnauthorized( diff --git a/src/infra/graphql/api.rs b/src/infra/graphql/api.rs new file mode 100644 index 0000000..c3ecc56 --- /dev/null +++ b/src/infra/graphql/api.rs @@ -0,0 +1,70 @@ +use crate::{ + domain::handler::BackendHandler, + infra::{ + auth_service::{check_if_token_is_valid, ValidationResults}, + tcp_server::AppState, + }, +}; +use actix_web::{web, Error, HttpResponse}; +use actix_web_httpauth::extractors::bearer::BearerAuth; +use juniper::{EmptyMutation, EmptySubscription, RootNode}; +use juniper_actix::{graphiql_handler, graphql_handler, playground_handler}; + +use super::query::Query; + +pub struct Context { + pub handler: Box, + pub validation_result: ValidationResults, +} + +impl juniper::Context for Context {} + +type Schema = RootNode< + 'static, + Query, + EmptyMutation>, + EmptySubscription>, +>; + +fn schema() -> Schema { + Schema::new( + Query::::new(), + EmptyMutation::>::new(), + EmptySubscription::>::new(), + ) +} + +async fn graphiql_route() -> Result { + graphiql_handler("/api/graphql", None).await +} +async fn playground_route() -> Result { + playground_handler("/api/graphql", None).await +} + +async fn graphql_route( + req: actix_web::HttpRequest, + mut payload: actix_web::web::Payload, + data: web::Data>, +) -> Result { + use actix_web::FromRequest; + let bearer = BearerAuth::from_request(&req, &mut payload.0).await?; + let validation_result = check_if_token_is_valid(&data, bearer.token())?; + let context = Context:: { + handler: Box::new(data.backend_handler.clone()), + validation_result, + }; + graphql_handler(&schema(), &context, req, payload).await +} + +pub fn configure_endpoint(cfg: &mut web::ServiceConfig) +where + Backend: BackendHandler + Sync + 'static, +{ + cfg.service( + web::resource("/graphql") + .route(web::post().to(graphql_route::)) + .route(web::get().to(graphql_route::)), + ); + cfg.service(web::resource("/graphql/playground").route(web::get().to(playground_route))); + cfg.service(web::resource("/graphql/graphiql").route(web::get().to(graphiql_route))); +} diff --git a/src/infra/graphql/mod.rs b/src/infra/graphql/mod.rs new file mode 100644 index 0000000..1ff2efd --- /dev/null +++ b/src/infra/graphql/mod.rs @@ -0,0 +1,2 @@ +pub mod api; +pub mod query; diff --git a/src/infra/graphql/query.rs b/src/infra/graphql/query.rs new file mode 100644 index 0000000..a2c66ee --- /dev/null +++ b/src/infra/graphql/query.rs @@ -0,0 +1,340 @@ +use crate::domain::handler::BackendHandler; +use juniper::{graphql_object, FieldResult, GraphQLInputObject}; +use lldap_model::{ListUsersRequest, UserDetailsRequest}; +use serde::{Deserialize, Serialize}; +use std::convert::TryInto; + +use super::api::Context; + +#[derive(PartialEq, Eq, Debug, GraphQLInputObject)] +/// A filter for requests, specifying a boolean expression based on field constraints. Only one of +/// the fields can be set at a time. +pub struct RequestFilter { + any: Option>, + all: Option>, + not: Option>, + eq: Option, +} + +impl TryInto for RequestFilter { + type Error = String; + fn try_into(self) -> Result { + let mut field_count = 0; + if self.any.is_some() { + field_count += 1; + } + if self.all.is_some() { + field_count += 1; + } + if self.not.is_some() { + field_count += 1; + } + if self.eq.is_some() { + field_count += 1; + } + if field_count == 0 { + return Err("No field specified in request filter".to_string()); + } + if field_count > 1 { + return Err("Multiple fields specified in request filter".to_string()); + } + if let Some(e) = self.eq { + return Ok(lldap_model::RequestFilter::Equality(e.field, e.value)); + } + if let Some(c) = self.any { + return Ok(lldap_model::RequestFilter::Or( + c.into_iter() + .map(TryInto::try_into) + .collect::, String>>()?, + )); + } + if let Some(c) = self.all { + return Ok(lldap_model::RequestFilter::And( + c.into_iter() + .map(TryInto::try_into) + .collect::, String>>()?, + )); + } + if let Some(c) = self.not { + return Ok(lldap_model::RequestFilter::Not(Box::new((*c).try_into()?))); + } + unreachable!(); + } +} + +#[derive(PartialEq, Eq, Debug, GraphQLInputObject)] +pub struct EqualityConstraint { + field: String, + value: String, +} + +#[derive(PartialEq, Eq, Debug)] +/// The top-level GraphQL query type. +pub struct Query { + _phantom: std::marker::PhantomData>, +} + +impl Query { + pub fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } +} + +#[graphql_object(context = Context)] +impl Query { + fn api_version() -> &'static str { + "1.0" + } + + async fn user(context: &Context, user_id: String) -> FieldResult> { + if !context.validation_result.can_access(&user_id) { + return Err("Unauthorized access to user data".into()); + } + Ok(context + .handler + .get_user_details(UserDetailsRequest { user_id }) + .await + .map(Into::into)?) + } + + async fn users( + context: &Context, + #[graphql(name = "where")] filters: Option, + ) -> FieldResult>> { + if !context.validation_result.is_admin { + return Err("Unauthorized access to user list".into()); + } + Ok(context + .handler + .list_users(ListUsersRequest { + filters: match filters { + None => None, + Some(f) => Some(f.try_into()?), + }, + }) + .await + .map(|v| v.into_iter().map(Into::into).collect())?) + } +} + +#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] +/// Represents a single user. +pub struct User { + user: lldap_model::User, + _phantom: std::marker::PhantomData>, +} + +impl Default for User { + fn default() -> Self { + Self { + user: lldap_model::User::default(), + _phantom: std::marker::PhantomData, + } + } +} + +#[graphql_object(context = Context)] +impl User { + fn id(&self) -> &str { + &self.user.user_id + } + + fn email(&self) -> &str { + &self.user.email + } + + /// The groups to which this user belongs. + async fn groups(&self, context: &Context) -> FieldResult>> { + Ok(context + .handler + .get_user_groups(self.user.user_id.clone()) + .await + .map(|set| set.into_iter().map(Into::into).collect())?) + } +} + +impl From for User { + fn from(user: lldap_model::User) -> Self { + Self { + user, + _phantom: std::marker::PhantomData, + } + } +} + +#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] +/// Represents a single group. +pub struct Group { + group_id: String, + _phantom: std::marker::PhantomData>, +} + +#[graphql_object(context = Context)] +impl Group { + fn id(&self) -> String { + self.group_id.clone() + } + /// The groups to which this user belongs. + async fn users(&self, context: &Context) -> FieldResult>> { + if !context.validation_result.is_admin { + return Err("Unauthorized access to group data".into()); + } + unimplemented!() + } +} + +impl From for Group { + fn from(group_id: String) -> Self { + Self { + group_id, + _phantom: std::marker::PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults}; + use juniper::{ + execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType, + RootNode, Variables, + }; + use mockall::predicate::eq; + use std::collections::HashSet; + + fn schema<'q, C, Q>(query_root: Q) -> RootNode<'q, Q, EmptyMutation, EmptySubscription> + where + Q: GraphQLType + 'q, + { + RootNode::new( + query_root, + EmptyMutation::::new(), + EmptySubscription::::new(), + ) + } + + #[tokio::test] + async fn get_user_by_id() { + const QUERY: &str = r#"{ + user(userId: "bob") { + id + email + groups { + id + } + } + }"#; + + let mut mock = MockTestBackendHandler::new(); + mock.expect_get_user_details() + .with(eq(UserDetailsRequest { + user_id: "bob".to_string(), + })) + .return_once(|_| { + Ok(lldap_model::User { + user_id: "bob".to_string(), + email: "bob@bobbers.on".to_string(), + ..Default::default() + }) + }); + let mut groups = HashSet::::new(); + groups.insert("Bobbersons".to_string()); + mock.expect_get_user_groups() + .with(eq("bob".to_string())) + .return_once(|_| Ok(groups)); + + let context = Context:: { + handler: Box::new(mock), + validation_result: ValidationResults::admin(), + }; + + let schema = schema(Query::::new()); + assert_eq!( + execute(QUERY, None, &schema, &Variables::new(), &context).await, + Ok(( + graphql_value!( + { + "user": { + "id": "bob", + "email": "bob@bobbers.on", + "groups": [{"id": "Bobbersons"}] + } + }), + vec![] + )) + ); + } + + #[tokio::test] + async fn list_users() { + const QUERY: &str = r#"{ + users(filters: { + any: [ + {eq: { + field: "id" + value: "bob" + }}, + {eq: { + field: "email" + value: "robert@bobbers.on" + }} + ]}) { + id + email + } + }"#; + + let mut mock = MockTestBackendHandler::new(); + use lldap_model::{RequestFilter, User}; + mock.expect_list_users() + .with(eq(ListUsersRequest { + filters: Some(RequestFilter::Or(vec![ + RequestFilter::Equality("id".to_string(), "bob".to_string()), + RequestFilter::Equality("email".to_string(), "robert@bobbers.on".to_string()), + ])), + })) + .return_once(|_| { + Ok(vec![ + User { + user_id: "bob".to_string(), + email: "bob@bobbers.on".to_string(), + ..Default::default() + }, + User { + user_id: "robert".to_string(), + email: "robert@bobbers.on".to_string(), + ..Default::default() + }, + ]) + }); + + let context = Context:: { + handler: Box::new(mock), + validation_result: ValidationResults::admin(), + }; + + let schema = schema(Query::::new()); + assert_eq!( + execute(QUERY, None, &schema, &Variables::new(), &context).await, + Ok(( + graphql_value!( + { + "users": [ + { + "id": "bob", + "email": "bob@bobbers.on" + }, + { + "id": "robert", + "email": "robert@bobbers.on" + }, + ] + }), + vec![] + )) + ); + } +} diff --git a/src/infra/mod.rs b/src/infra/mod.rs index c9472ec..26d9171 100644 --- a/src/infra/mod.rs +++ b/src/infra/mod.rs @@ -2,6 +2,7 @@ pub mod auth_service; pub mod cli; pub mod configuration; pub mod db_cleaner; +pub mod graphql; pub mod jwt_sql_tables; pub mod ldap_handler; pub mod ldap_server; diff --git a/src/infra/tcp_api.rs b/src/infra/tcp_api.rs index dd343e0..e687c40 100644 --- a/src/infra/tcp_api.rs +++ b/src/infra/tcp_api.rs @@ -62,7 +62,7 @@ where pub fn api_config(cfg: &mut web::ServiceConfig) where - Backend: TcpBackendHandler + BackendHandler + 'static, + Backend: TcpBackendHandler + BackendHandler + Sync + 'static, { let json_config = web::JsonConfig::default() .limit(4096) @@ -77,6 +77,7 @@ where .into() }); cfg.app_data(json_config); + super::graphql::api::configure_endpoint::(cfg); cfg.service( web::resource("/user/{user_id}") .route(web::get().to(user_details_handler::)) diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index 0861702..c4735ee 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -47,7 +47,7 @@ fn http_config( jwt_secret: String, jwt_blacklist: HashSet, ) where - Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + 'static, + Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static, { cfg.app_data(web::Data::new(AppState:: { backend_handler, @@ -84,7 +84,7 @@ pub async fn build_tcp_server( server_builder: ServerBuilder, ) -> Result where - Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + 'static, + Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static, { let jwt_secret = config.jwt_secret.clone(); let jwt_blacklist = backend_handler.get_jwt_blacklist().await?;