From 28a941924eedfd0b9407ddd8d7dd371c6cb1971a Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Thu, 20 May 2021 19:18:15 +0200 Subject: [PATCH] Split big files into little ones --- src/domain/handler.rs | 496 +----------------------------- src/domain/mod.rs | 1 + src/domain/sql_backend_handler.rs | 420 +++++++++++++++++++++++++ src/infra/auth_service.rs | 237 ++++++++++++++ src/infra/mod.rs | 4 + src/infra/sql_backend_handler.rs | 81 +++++ src/infra/tcp_api.rs | 97 ++++++ src/infra/tcp_backend_handler.rs | 35 +++ src/infra/tcp_server.rs | 314 ++----------------- src/main.rs | 8 +- 10 files changed, 904 insertions(+), 789 deletions(-) create mode 100644 src/domain/sql_backend_handler.rs create mode 100644 src/infra/auth_service.rs create mode 100644 src/infra/sql_backend_handler.rs create mode 100644 src/infra/tcp_api.rs create mode 100644 src/infra/tcp_backend_handler.rs diff --git a/src/domain/handler.rs b/src/domain/handler.rs index b36d882..1f03611 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -1,14 +1,5 @@ -use super::sql_tables::*; -use crate::domain::{error::*, sql_tables::Pool}; -use crate::infra::configuration::Configuration; -use crate::infra::jwt_sql_tables::*; +use super::error::*; use async_trait::async_trait; -use futures_util::StreamExt; -use futures_util::TryStreamExt; -use log::*; -use sea_query::Iden; -use sea_query::{Expr, Order, Query, SimpleExpr}; -use sqlx::Row; use std::collections::HashSet; pub use lldap_model::*; @@ -21,256 +12,6 @@ pub trait BackendHandler: Clone + Send { async fn get_user_groups(&self, user: String) -> Result>; } -#[derive(Debug, Clone)] -pub struct SqlBackendHandler { - config: Configuration, - sql_pool: Pool, -} - -impl SqlBackendHandler { - pub fn new(config: Configuration, sql_pool: Pool) -> Self { - SqlBackendHandler { config, sql_pool } - } -} - -fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool { - encrypted_password == clear_password -} - -fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { - use RequestFilter::*; - fn get_repeated_filter( - fs: Vec, - field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, - ) -> SimpleExpr { - let mut it = fs.into_iter(); - let first_expr = match it.next() { - None => return Expr::value(true), - Some(f) => get_filter_expr(f), - }; - it.fold(first_expr, |e, f| field(e, get_filter_expr(f))) - } - match filter { - And(fs) => get_repeated_filter(fs, &SimpleExpr::and), - Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), - Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))), - Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2), - } -} - -#[async_trait] -impl BackendHandler for SqlBackendHandler { - async fn bind(&self, request: BindRequest) -> Result<()> { - if request.name == self.config.ldap_user_dn { - if request.password == self.config.ldap_user_pass { - return Ok(()); - } else { - debug!(r#"Invalid password for LDAP bind user"#); - return Err(Error::AuthenticationError(request.name)); - } - } - let query = Query::select() - .column(Users::Password) - .from(Users::Table) - .and_where(Expr::col(Users::UserId).eq(request.name.as_str())) - .to_string(DbQueryBuilder {}); - if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { - if passwords_match( - &request.password, - &row.get::(&*Users::Password.to_string()), - ) { - return Ok(()); - } else { - debug!(r#"Invalid password for "{}""#, request.name); - } - } else { - debug!(r#"No user found for "{}""#, request.name); - } - Err(Error::AuthenticationError(request.name)) - } - - async fn list_users(&self, request: ListUsersRequest) -> Result> { - let query = { - let mut query_builder = Query::select() - .column(Users::UserId) - .column(Users::Email) - .column(Users::DisplayName) - .column(Users::FirstName) - .column(Users::LastName) - .column(Users::Avatar) - .column(Users::CreationDate) - .from(Users::Table) - .order_by(Users::UserId, Order::Asc) - .to_owned(); - if let Some(filter) = request.filters { - if filter != RequestFilter::And(Vec::new()) - && filter != RequestFilter::Or(Vec::new()) - { - query_builder.and_where(get_filter_expr(filter)); - } - } - - query_builder.to_string(DbQueryBuilder {}) - }; - - let results = sqlx::query_as::<_, User>(&query) - .fetch(&self.sql_pool) - .collect::>>() - .await; - - Ok(results.into_iter().collect::>>()?) - } - - async fn list_groups(&self) -> Result> { - let query: String = Query::select() - .column(Groups::DisplayName) - .column(Memberships::UserId) - .from(Groups::Table) - .left_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .order_by(Groups::DisplayName, Order::Asc) - .order_by(Memberships::UserId, Order::Asc) - .to_string(DbQueryBuilder {}); - - let mut results = sqlx::query(&query).fetch(&self.sql_pool); - let mut groups = Vec::new(); - // The rows are ordered by group, user, so we need to group them into vectors. - { - let mut current_group = String::new(); - let mut current_users = Vec::new(); - while let Some(row) = results.try_next().await? { - let display_name = row.get::(&*Groups::DisplayName.to_string()); - if display_name != current_group { - if !current_group.is_empty() { - groups.push(Group { - display_name: current_group, - users: current_users, - }); - current_users = Vec::new(); - } - current_group = display_name.clone(); - } - current_users.push(row.get::(&*Memberships::UserId.to_string())); - } - groups.push(Group { - display_name: current_group, - users: current_users, - }); - } - - Ok(groups) - } - - async fn get_user_groups(&self, user: String) -> Result> { - if user == self.config.ldap_user_dn { - let mut groups = HashSet::new(); - groups.insert("lldap_admin".to_string()); - return Ok(groups); - } - let query: String = Query::select() - .column(Groups::DisplayName) - .from(Groups::Table) - .inner_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .and_where(Expr::col(Memberships::UserId).eq(user)) - .to_string(DbQueryBuilder {}); - - sqlx::query(&query) - // Extract the group id from the row. - .map(|row: DbRow| row.get::(&*Groups::DisplayName.to_string())) - .fetch(&self.sql_pool) - // Collect the vector of rows, each potentially an error. - .collect::>>() - .await - .into_iter() - // Transform it into a single result (the first error if any), and group the group_ids - // into a HashSet. - .collect::>>() - // Map the sqlx::Error into a domain::Error. - .map_err(Error::DatabaseError) - } -} - -#[async_trait] -impl crate::infra::tcp_server::TcpBackendHandler for SqlBackendHandler { - async fn get_jwt_blacklist(&self) -> anyhow::Result> { - use sqlx::Result; - let query = Query::select() - .column(JwtBlacklist::JwtHash) - .from(JwtBlacklist::Table) - .to_string(DbQueryBuilder {}); - - sqlx::query(&query) - .map(|row: DbRow| row.get::(&*JwtBlacklist::JwtHash.to_string()) as u64) - .fetch(&self.sql_pool) - .collect::>>() - .await - .into_iter() - .collect::>>() - .map_err(|e| anyhow::anyhow!(e)) - } - - async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { - use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - // TODO: Initialize the rng only once. Maybe Arc? - let mut rng = SmallRng::from_entropy(); - let refresh_token: String = std::iter::repeat(()) - .map(|()| rng.sample(Alphanumeric)) - .map(char::from) - .take(100) - .collect(); - let refresh_token_hash = { - let mut s = DefaultHasher::new(); - refresh_token.hash(&mut s); - s.finish() - }; - let duration = chrono::Duration::days(30); - let query = Query::insert() - .into_table(JwtRefreshStorage::Table) - .columns(vec![ - JwtRefreshStorage::RefreshTokenHash, - JwtRefreshStorage::UserId, - JwtRefreshStorage::ExpiryDate, - ]) - .values_panic(vec![ - (refresh_token_hash as i64).into(), - user.into(), - (chrono::Utc::now() + duration).naive_utc().into(), - ]) - .to_string(DbQueryBuilder {}); - sqlx::query(&query).execute(&self.sql_pool).await?; - Ok((refresh_token, duration)) - } - - async fn check_token(&self, token: &str, user: &str) -> Result { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let refresh_token_hash = { - let mut s = DefaultHasher::new(); - token.hash(&mut s); - s.finish() - }; - let query = Query::select() - .expr(SimpleExpr::Value(1.into())) - .from(JwtRefreshStorage::Table) - .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) - .and_where(Expr::col(JwtRefreshStorage::UserId).eq(user)) - .to_string(DbQueryBuilder {}); - Ok(sqlx::query(&query) - .fetch_optional(&self.sql_pool) - .await? - .is_some()) - } -} - #[cfg(test)] mockall::mock! { pub TestBackendHandler{} @@ -285,238 +26,3 @@ mockall::mock! { async fn get_user_groups(&self, user: String) -> Result>; } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::domain::sql_tables::init_table; - - async fn get_in_memory_db() -> Pool { - PoolOptions::new().connect("sqlite::memory:").await.unwrap() - } - - async fn get_initialized_db() -> Pool { - let sql_pool = get_in_memory_db().await; - init_table(&sql_pool).await.unwrap(); - sql_pool - } - - async fn insert_user(sql_pool: &Pool, name: &str, pass: &str) { - let query = Query::insert() - .into_table(Users::Table) - .columns(vec![ - Users::UserId, - Users::Email, - Users::DisplayName, - Users::FirstName, - Users::LastName, - Users::CreationDate, - Users::Password, - ]) - .values_panic(vec![ - name.into(), - "bob@bob".into(), - "Bob Böbberson".into(), - "Bob".into(), - "Böbberson".into(), - chrono::NaiveDateTime::from_timestamp(0, 0).into(), - pass.into(), - ]) - .to_string(DbQueryBuilder {}); - sqlx::query(&query).execute(sql_pool).await.unwrap(); - } - - async fn insert_group(sql_pool: &Pool, id: u32, name: &str) { - let query = Query::insert() - .into_table(Groups::Table) - .columns(vec![Groups::GroupId, Groups::DisplayName]) - .values_panic(vec![id.into(), name.into()]) - .to_string(DbQueryBuilder {}); - sqlx::query(&query).execute(sql_pool).await.unwrap(); - } - - async fn insert_membership(sql_pool: &Pool, group_id: u32, user_id: &str) { - let query = Query::insert() - .into_table(Memberships::Table) - .columns(vec![Memberships::UserId, Memberships::GroupId]) - .values_panic(vec![user_id.into(), group_id.into()]) - .to_string(DbQueryBuilder {}); - sqlx::query(&query).execute(sql_pool).await.unwrap(); - } - - #[tokio::test] - async fn test_bind_admin() { - let sql_pool = get_in_memory_db().await; - let config = Configuration { - ldap_user_dn: "admin".to_string(), - ldap_user_pass: "test".to_string(), - ..Default::default() - }; - let handler = SqlBackendHandler::new(config, sql_pool); - handler - .bind(BindRequest { - name: "admin".to_string(), - password: "test".to_string(), - }) - .await - .unwrap(); - } - - #[tokio::test] - async fn test_bind_user() { - let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); - handler - .bind(BindRequest { - name: "bob".to_string(), - password: "bob00".to_string(), - }) - .await - .unwrap(); - handler - .bind(BindRequest { - name: "andrew".to_string(), - password: "bob00".to_string(), - }) - .await - .unwrap_err(); - handler - .bind(BindRequest { - name: "bob".to_string(), - password: "wrong_password".to_string(), - }) - .await - .unwrap_err(); - } - - #[tokio::test] - async fn test_list_users() { - let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - insert_user(&sql_pool, "patrick", "pass").await; - insert_user(&sql_pool, "John", "Pa33w0rd!").await; - let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); - { - let users = handler - .list_users(ListUsersRequest { filters: None }) - .await - .unwrap() - .into_iter() - .map(|u| u.user_id) - .collect::>(); - assert_eq!(users, vec!["John", "bob", "patrick"]); - } - { - let users = handler - .list_users(ListUsersRequest { - filters: Some(RequestFilter::Equality( - "user_id".to_string(), - "bob".to_string(), - )), - }) - .await - .unwrap() - .into_iter() - .map(|u| u.user_id) - .collect::>(); - assert_eq!(users, vec!["bob"]); - } - { - let users = handler - .list_users(ListUsersRequest { - filters: Some(RequestFilter::Or(vec![ - RequestFilter::Equality("user_id".to_string(), "bob".to_string()), - RequestFilter::Equality("user_id".to_string(), "John".to_string()), - ])), - }) - .await - .unwrap() - .into_iter() - .map(|u| u.user_id) - .collect::>(); - assert_eq!(users, vec!["John", "bob"]); - } - { - let users = handler - .list_users(ListUsersRequest { - filters: Some(RequestFilter::Not(Box::new(RequestFilter::Equality( - "user_id".to_string(), - "bob".to_string(), - )))), - }) - .await - .unwrap() - .into_iter() - .map(|u| u.user_id) - .collect::>(); - assert_eq!(users, vec!["John", "patrick"]); - } - } - - #[tokio::test] - async fn test_list_groups() { - let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - insert_user(&sql_pool, "patrick", "pass").await; - insert_user(&sql_pool, "John", "Pa33w0rd!").await; - insert_group(&sql_pool, 1, "Best Group").await; - insert_group(&sql_pool, 2, "Worst Group").await; - insert_membership(&sql_pool, 1, "bob").await; - insert_membership(&sql_pool, 1, "patrick").await; - insert_membership(&sql_pool, 2, "patrick").await; - insert_membership(&sql_pool, 2, "John").await; - let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); - assert_eq!( - handler.list_groups().await.unwrap(), - vec![ - Group { - display_name: "Best Group".to_string(), - users: vec!["bob".to_string(), "patrick".to_string()] - }, - Group { - display_name: "Worst Group".to_string(), - users: vec!["John".to_string(), "patrick".to_string()] - } - ] - ); - } - - #[tokio::test] - async fn test_get_user_groups() { - let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - insert_user(&sql_pool, "patrick", "pass").await; - insert_user(&sql_pool, "John", "Pa33w0rd!").await; - insert_group(&sql_pool, 1, "Group1").await; - insert_group(&sql_pool, 2, "Group2").await; - insert_membership(&sql_pool, 1, "bob").await; - insert_membership(&sql_pool, 1, "patrick").await; - insert_membership(&sql_pool, 2, "patrick").await; - let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); - let mut bob_groups = HashSet::new(); - bob_groups.insert("Group1".to_string()); - let mut patrick_groups = HashSet::new(); - patrick_groups.insert("Group1".to_string()); - patrick_groups.insert("Group2".to_string()); - assert_eq!( - handler.get_user_groups("bob".to_string()).await.unwrap(), - bob_groups - ); - assert_eq!( - handler - .get_user_groups("patrick".to_string()) - .await - .unwrap(), - patrick_groups - ); - assert_eq!( - handler.get_user_groups("John".to_string()).await.unwrap(), - HashSet::new() - ); - } -} diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 00320b6..ede7b41 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -1,3 +1,4 @@ pub mod error; pub mod handler; +pub mod sql_backend_handler; pub mod sql_tables; diff --git a/src/domain/sql_backend_handler.rs b/src/domain/sql_backend_handler.rs new file mode 100644 index 0000000..c4beb5e --- /dev/null +++ b/src/domain/sql_backend_handler.rs @@ -0,0 +1,420 @@ +use super::{error::*, handler::*, sql_tables::*}; +use crate::infra::configuration::Configuration; +use async_trait::async_trait; +use futures_util::StreamExt; +use futures_util::TryStreamExt; +use log::*; +use sea_query::{Expr, Iden, Order, Query, SimpleExpr}; +use sqlx::Row; +use std::collections::HashSet; + +#[derive(Debug, Clone)] +pub struct SqlBackendHandler { + pub(crate) config: Configuration, + pub(crate) sql_pool: Pool, +} + +impl SqlBackendHandler { + pub fn new(config: Configuration, sql_pool: Pool) -> Self { + SqlBackendHandler { config, sql_pool } + } +} + +fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool { + encrypted_password == clear_password +} + +fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { + use RequestFilter::*; + fn get_repeated_filter( + fs: Vec, + field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, + ) -> SimpleExpr { + let mut it = fs.into_iter(); + let first_expr = match it.next() { + None => return Expr::value(true), + Some(f) => get_filter_expr(f), + }; + it.fold(first_expr, |e, f| field(e, get_filter_expr(f))) + } + match filter { + And(fs) => get_repeated_filter(fs, &SimpleExpr::and), + Or(fs) => get_repeated_filter(fs, &SimpleExpr::or), + Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))), + Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2), + } +} + +#[async_trait] +impl BackendHandler for SqlBackendHandler { + async fn bind(&self, request: BindRequest) -> Result<()> { + if request.name == self.config.ldap_user_dn { + if request.password == self.config.ldap_user_pass { + return Ok(()); + } else { + debug!(r#"Invalid password for LDAP bind user"#); + return Err(Error::AuthenticationError(request.name)); + } + } + let query = Query::select() + .column(Users::Password) + .from(Users::Table) + .and_where(Expr::col(Users::UserId).eq(request.name.as_str())) + .to_string(DbQueryBuilder {}); + if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { + if passwords_match( + &request.password, + &row.get::(&*Users::Password.to_string()), + ) { + return Ok(()); + } else { + debug!(r#"Invalid password for "{}""#, request.name); + } + } else { + debug!(r#"No user found for "{}""#, request.name); + } + Err(Error::AuthenticationError(request.name)) + } + + async fn list_users(&self, request: ListUsersRequest) -> Result> { + let query = { + let mut query_builder = Query::select() + .column(Users::UserId) + .column(Users::Email) + .column(Users::DisplayName) + .column(Users::FirstName) + .column(Users::LastName) + .column(Users::Avatar) + .column(Users::CreationDate) + .from(Users::Table) + .order_by(Users::UserId, Order::Asc) + .to_owned(); + if let Some(filter) = request.filters { + if filter != RequestFilter::And(Vec::new()) + && filter != RequestFilter::Or(Vec::new()) + { + query_builder.and_where(get_filter_expr(filter)); + } + } + + query_builder.to_string(DbQueryBuilder {}) + }; + + let results = sqlx::query_as::<_, User>(&query) + .fetch(&self.sql_pool) + .collect::>>() + .await; + + Ok(results.into_iter().collect::>>()?) + } + + async fn list_groups(&self) -> Result> { + let query: String = Query::select() + .column(Groups::DisplayName) + .column(Memberships::UserId) + .from(Groups::Table) + .left_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .order_by(Groups::DisplayName, Order::Asc) + .order_by(Memberships::UserId, Order::Asc) + .to_string(DbQueryBuilder {}); + + let mut results = sqlx::query(&query).fetch(&self.sql_pool); + let mut groups = Vec::new(); + // The rows are ordered by group, user, so we need to group them into vectors. + { + let mut current_group = String::new(); + let mut current_users = Vec::new(); + while let Some(row) = results.try_next().await? { + let display_name = row.get::(&*Groups::DisplayName.to_string()); + if display_name != current_group { + if !current_group.is_empty() { + groups.push(Group { + display_name: current_group, + users: current_users, + }); + current_users = Vec::new(); + } + current_group = display_name.clone(); + } + current_users.push(row.get::(&*Memberships::UserId.to_string())); + } + groups.push(Group { + display_name: current_group, + users: current_users, + }); + } + + Ok(groups) + } + + async fn get_user_groups(&self, user: String) -> Result> { + if user == self.config.ldap_user_dn { + let mut groups = HashSet::new(); + groups.insert("lldap_admin".to_string()); + return Ok(groups); + } + let query: String = Query::select() + .column(Groups::DisplayName) + .from(Groups::Table) + .inner_join( + Memberships::Table, + Expr::tbl(Groups::Table, Groups::GroupId) + .equals(Memberships::Table, Memberships::GroupId), + ) + .and_where(Expr::col(Memberships::UserId).eq(user)) + .to_string(DbQueryBuilder {}); + + sqlx::query(&query) + // Extract the group id from the row. + .map(|row: DbRow| row.get::(&*Groups::DisplayName.to_string())) + .fetch(&self.sql_pool) + // Collect the vector of rows, each potentially an error. + .collect::>>() + .await + .into_iter() + // Transform it into a single result (the first error if any), and group the group_ids + // into a HashSet. + .collect::>>() + // Map the sqlx::Error into a domain::Error. + .map_err(Error::DatabaseError) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::sql_tables::init_table; + + async fn get_in_memory_db() -> Pool { + PoolOptions::new().connect("sqlite::memory:").await.unwrap() + } + + async fn get_initialized_db() -> Pool { + let sql_pool = get_in_memory_db().await; + init_table(&sql_pool).await.unwrap(); + sql_pool + } + + async fn insert_user(sql_pool: &Pool, name: &str, pass: &str) { + let query = Query::insert() + .into_table(Users::Table) + .columns(vec![ + Users::UserId, + Users::Email, + Users::DisplayName, + Users::FirstName, + Users::LastName, + Users::CreationDate, + Users::Password, + ]) + .values_panic(vec![ + name.into(), + "bob@bob".into(), + "Bob Böbberson".into(), + "Bob".into(), + "Böbberson".into(), + chrono::NaiveDateTime::from_timestamp(0, 0).into(), + pass.into(), + ]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(sql_pool).await.unwrap(); + } + + async fn insert_group(sql_pool: &Pool, id: u32, name: &str) { + let query = Query::insert() + .into_table(Groups::Table) + .columns(vec![Groups::GroupId, Groups::DisplayName]) + .values_panic(vec![id.into(), name.into()]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(sql_pool).await.unwrap(); + } + + async fn insert_membership(sql_pool: &Pool, group_id: u32, user_id: &str) { + let query = Query::insert() + .into_table(Memberships::Table) + .columns(vec![Memberships::UserId, Memberships::GroupId]) + .values_panic(vec![user_id.into(), group_id.into()]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(sql_pool).await.unwrap(); + } + + #[tokio::test] + async fn test_bind_admin() { + let sql_pool = get_in_memory_db().await; + let config = Configuration { + ldap_user_dn: "admin".to_string(), + ldap_user_pass: "test".to_string(), + ..Default::default() + }; + let handler = SqlBackendHandler::new(config, sql_pool); + handler + .bind(BindRequest { + name: "admin".to_string(), + password: "test".to_string(), + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_bind_user() { + let sql_pool = get_initialized_db().await; + insert_user(&sql_pool, "bob", "bob00").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool); + handler + .bind(BindRequest { + name: "bob".to_string(), + password: "bob00".to_string(), + }) + .await + .unwrap(); + handler + .bind(BindRequest { + name: "andrew".to_string(), + password: "bob00".to_string(), + }) + .await + .unwrap_err(); + handler + .bind(BindRequest { + name: "bob".to_string(), + password: "wrong_password".to_string(), + }) + .await + .unwrap_err(); + } + + #[tokio::test] + async fn test_list_users() { + let sql_pool = get_initialized_db().await; + insert_user(&sql_pool, "bob", "bob00").await; + insert_user(&sql_pool, "patrick", "pass").await; + insert_user(&sql_pool, "John", "Pa33w0rd!").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool); + { + let users = handler + .list_users(ListUsersRequest { filters: None }) + .await + .unwrap() + .into_iter() + .map(|u| u.user_id) + .collect::>(); + assert_eq!(users, vec!["John", "bob", "patrick"]); + } + { + let users = handler + .list_users(ListUsersRequest { + filters: Some(RequestFilter::Equality( + "user_id".to_string(), + "bob".to_string(), + )), + }) + .await + .unwrap() + .into_iter() + .map(|u| u.user_id) + .collect::>(); + assert_eq!(users, vec!["bob"]); + } + { + let users = handler + .list_users(ListUsersRequest { + filters: Some(RequestFilter::Or(vec![ + RequestFilter::Equality("user_id".to_string(), "bob".to_string()), + RequestFilter::Equality("user_id".to_string(), "John".to_string()), + ])), + }) + .await + .unwrap() + .into_iter() + .map(|u| u.user_id) + .collect::>(); + assert_eq!(users, vec!["John", "bob"]); + } + { + let users = handler + .list_users(ListUsersRequest { + filters: Some(RequestFilter::Not(Box::new(RequestFilter::Equality( + "user_id".to_string(), + "bob".to_string(), + )))), + }) + .await + .unwrap() + .into_iter() + .map(|u| u.user_id) + .collect::>(); + assert_eq!(users, vec!["John", "patrick"]); + } + } + + #[tokio::test] + async fn test_list_groups() { + let sql_pool = get_initialized_db().await; + insert_user(&sql_pool, "bob", "bob00").await; + insert_user(&sql_pool, "patrick", "pass").await; + insert_user(&sql_pool, "John", "Pa33w0rd!").await; + insert_group(&sql_pool, 1, "Best Group").await; + insert_group(&sql_pool, 2, "Worst Group").await; + insert_membership(&sql_pool, 1, "bob").await; + insert_membership(&sql_pool, 1, "patrick").await; + insert_membership(&sql_pool, 2, "patrick").await; + insert_membership(&sql_pool, 2, "John").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool); + assert_eq!( + handler.list_groups().await.unwrap(), + vec![ + Group { + display_name: "Best Group".to_string(), + users: vec!["bob".to_string(), "patrick".to_string()] + }, + Group { + display_name: "Worst Group".to_string(), + users: vec!["John".to_string(), "patrick".to_string()] + } + ] + ); + } + + #[tokio::test] + async fn test_get_user_groups() { + let sql_pool = get_initialized_db().await; + insert_user(&sql_pool, "bob", "bob00").await; + insert_user(&sql_pool, "patrick", "pass").await; + insert_user(&sql_pool, "John", "Pa33w0rd!").await; + insert_group(&sql_pool, 1, "Group1").await; + insert_group(&sql_pool, 2, "Group2").await; + insert_membership(&sql_pool, 1, "bob").await; + insert_membership(&sql_pool, 1, "patrick").await; + insert_membership(&sql_pool, 2, "patrick").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool); + let mut bob_groups = HashSet::new(); + bob_groups.insert("Group1".to_string()); + let mut patrick_groups = HashSet::new(); + patrick_groups.insert("Group1".to_string()); + patrick_groups.insert("Group2".to_string()); + assert_eq!( + handler.get_user_groups("bob".to_string()).await.unwrap(), + bob_groups + ); + assert_eq!( + handler + .get_user_groups("patrick".to_string()) + .await + .unwrap(), + patrick_groups + ); + assert_eq!( + handler.get_user_groups("John".to_string()).await.unwrap(), + HashSet::new() + ); + } +} diff --git a/src/infra/auth_service.rs b/src/infra/auth_service.rs new file mode 100644 index 0000000..b69e84e --- /dev/null +++ b/src/infra/auth_service.rs @@ -0,0 +1,237 @@ +use crate::{domain::handler::*, infra::{tcp_backend_handler::*, tcp_server::{AppState, error_to_http_response}}}; +use hmac::Hmac; +use jwt::{SignWithKey, VerifyWithKey}; +use log::*; +use std::collections::{hash_map::DefaultHasher, HashSet}; +use std::hash::{Hash, Hasher}; +use time::ext::NumericalDuration; +use actix_web_httpauth::extractors::bearer::BearerAuth; +use anyhow::Result; +use std::task::{Context, Poll}; +use std::pin::Pin; +use actix_web::{ + cookie::{Cookie, SameSite}, + dev::{ServiceRequest, Service, Transform, ServiceResponse}, + error::{ErrorBadRequest, ErrorUnauthorized}, + web, HttpRequest, HttpResponse +}; +use futures_util::{FutureExt, TryFutureExt}; +use futures::future::{ok, Ready}; +use chrono::prelude::*; +use sha2::Sha512; + +type Token = jwt::Token; +type SignedToken = Token; + +fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken { + let claims = JWTClaims { + exp: Utc::now() + chrono::Duration::days(1), + iat: Utc::now(), + user, + groups, + }; + let header = jwt::Header { + algorithm: jwt::AlgorithmType::Hs512, + ..Default::default() + }; + jwt::Token::new(header, claims).sign_with_key(key).unwrap() +} + +async fn get_refresh( + data: web::Data>, + request: HttpRequest, +) -> HttpResponse +where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + let backend_handler = &data.backend_handler; + let jwt_key = &data.jwt_key; + let (refresh_token, user) = match request.cookie("refresh_token") { + None => { + return HttpResponse::Unauthorized().body("Missing refresh token") + } + Some(t) => match t.value().split_once("+") { + None => { + return HttpResponse::Unauthorized().body("Invalid refresh token") + } + Some((t, u)) => (t.to_string(), u.to_string()), + }, + }; + let res_found = data.backend_handler.check_token(&refresh_token, &user).await; + // Async closures are not supported yet. + match res_found { + Ok(found) => { + if found { + backend_handler.get_user_groups(user.to_string()).await + } else { + Err(DomainError::AuthenticationError( + "Invalid refresh token".to_string(), + )) + } + } + Err(e) => Err(e), + } + .map(|groups| create_jwt(jwt_key, user.to_string(), groups)) + .map(|token| { + HttpResponse::Ok() + .cookie( + Cookie::build("token", token.as_str()) + .max_age(1.days()) + .path("/api") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .body(token.as_str().to_owned()) + }) + .unwrap_or_else(error_to_http_response) +} + +async fn post_authorize( + data: web::Data>, + request: web::Json, +) -> HttpResponse +where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + let req: BindRequest = request.clone(); + data.backend_handler + .bind(req) + // If the authentication was successful, we need to fetch the groups to create the JWT + // token. + .and_then(|_| data.backend_handler.get_user_groups(request.name.clone())) + .and_then(|g| async { + Ok(( + g, + data.backend_handler + .create_refresh_token(&request.name) + .await?, + )) + }) + .await + .map(|(groups, (refresh_token, max_age))| { + let token = create_jwt(&data.jwt_key, request.name.clone(), groups); + HttpResponse::Ok() + .cookie( + Cookie::build("token", token.as_str()) + .max_age(1.days()) + .path("/api") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .cookie( + Cookie::build("refresh_token", refresh_token + "+" + &request.name) + .max_age(max_age.num_days().days()) + .path("/api/authorize/refresh") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .body(token.as_str().to_owned()) + }) + .unwrap_or_else(error_to_http_response) +} + +pub struct CookieToHeaderTranslatorFactory; + +impl Transform for CookieToHeaderTranslatorFactory +where + S: Service, Error=actix_web::Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + type InitError = (); + type Transform = CookieToHeaderTranslator; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CookieToHeaderTranslator { service }) + } +} + +pub struct CookieToHeaderTranslator { + service: S, +} + +impl Service for CookieToHeaderTranslator +where + S: Service, Error = actix_web::Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + #[allow(clippy::type_complexity)] + type Future = Pin>>>; + + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&self, mut req: ServiceRequest) -> Self::Future { + if let Some(token_cookie) = req.cookie("token") { + if let Ok(header_value) = actix_http::header::HeaderValue::from_str(&format!( + "Bearer {}", + token_cookie.value() + )) { + req.headers_mut() + .insert(actix_http::header::AUTHORIZATION, header_value); + } else { + return async move { + Ok(req.error_response(ErrorBadRequest("Invalid token cookie"))) + } + .boxed_local(); + } + }; + + Box::pin(self.service.call(req)) + } +} + +pub async fn token_validator( + req: ServiceRequest, + credentials: BearerAuth, +) -> Result +where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + let state = req + .app_data::>>() + .expect("Invalid app config"); + let token: Token<_> = VerifyWithKey::verify_with_key(credentials.token(), &state.jwt_key) + .map_err(|_| ErrorUnauthorized("Invalid JWT"))?; + if token.claims().exp.lt(&Utc::now()) { + return Err(ErrorUnauthorized("Expired JWT")); + } + let jwt_hash = { + let mut s = DefaultHasher::new(); + credentials.token().hash(&mut s); + s.finish() + }; + if state.jwt_blacklist.contains(&jwt_hash) { + return Err(ErrorUnauthorized("JWT was logged out")); + } + let groups = &token.claims().groups; + if groups.contains("lldap_admin") { + debug!("Got authorized token for user {}", &token.claims().user); + Ok(req) + } else { + Err(ErrorUnauthorized( + "JWT error: User is not in group lldap_admin", + )) + } +} + + +pub fn configure_server( + cfg: &mut web::ServiceConfig, +) where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + cfg + .service(web::resource("").route(web::post().to(post_authorize::))) + .service(web::resource("/refresh").route(web::get().to(get_refresh::))); +} diff --git a/src/infra/mod.rs b/src/infra/mod.rs index 2b65f96..2660e94 100644 --- a/src/infra/mod.rs +++ b/src/infra/mod.rs @@ -1,7 +1,11 @@ +pub mod auth_service; pub mod cli; pub mod configuration; pub mod jwt_sql_tables; pub mod ldap_handler; pub mod ldap_server; pub mod logging; +pub mod sql_backend_handler; +pub mod tcp_api; +pub mod tcp_backend_handler; pub mod tcp_server; diff --git a/src/infra/sql_backend_handler.rs b/src/infra/sql_backend_handler.rs new file mode 100644 index 0000000..0f22ded --- /dev/null +++ b/src/infra/sql_backend_handler.rs @@ -0,0 +1,81 @@ +use super::{jwt_sql_tables::*, tcp_backend_handler::*}; +use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler}; +use async_trait::async_trait; +use futures_util::StreamExt; +use sea_query::{Expr, Iden, Query, SimpleExpr}; +use sqlx::Row; +use std::collections::HashSet; + +#[async_trait] +impl TcpBackendHandler for SqlBackendHandler { + async fn get_jwt_blacklist(&self) -> anyhow::Result> { + use sqlx::Result; + let query = Query::select() + .column(JwtBlacklist::JwtHash) + .from(JwtBlacklist::Table) + .to_string(DbQueryBuilder {}); + + sqlx::query(&query) + .map(|row: DbRow| row.get::(&*JwtBlacklist::JwtHash.to_string()) as u64) + .fetch(&self.sql_pool) + .collect::>>() + .await + .into_iter() + .collect::>>() + .map_err(|e| anyhow::anyhow!(e)) + } + + async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { + use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + // TODO: Initialize the rng only once. Maybe Arc? + let mut rng = SmallRng::from_entropy(); + let refresh_token: String = std::iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .map(char::from) + .take(100) + .collect(); + let refresh_token_hash = { + let mut s = DefaultHasher::new(); + refresh_token.hash(&mut s); + s.finish() + }; + let duration = chrono::Duration::days(30); + let query = Query::insert() + .into_table(JwtRefreshStorage::Table) + .columns(vec![ + JwtRefreshStorage::RefreshTokenHash, + JwtRefreshStorage::UserId, + JwtRefreshStorage::ExpiryDate, + ]) + .values_panic(vec![ + (refresh_token_hash as i64).into(), + user.into(), + (chrono::Utc::now() + duration).naive_utc().into(), + ]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(&self.sql_pool).await?; + Ok((refresh_token, duration)) + } + + async fn check_token(&self, token: &str, user: &str) -> Result { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let refresh_token_hash = { + let mut s = DefaultHasher::new(); + token.hash(&mut s); + s.finish() + }; + let query = Query::select() + .expr(SimpleExpr::Value(1.into())) + .from(JwtRefreshStorage::Table) + .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) + .and_where(Expr::col(JwtRefreshStorage::UserId).eq(user)) + .to_string(DbQueryBuilder {}); + Ok(sqlx::query(&query) + .fetch_optional(&self.sql_pool) + .await? + .is_some()) + } +} diff --git a/src/infra/tcp_api.rs b/src/infra/tcp_api.rs new file mode 100644 index 0000000..546e480 --- /dev/null +++ b/src/infra/tcp_api.rs @@ -0,0 +1,97 @@ +use crate::{ + domain::handler::*, + infra::{tcp_server::{AppState, error_to_http_response}, tcp_backend_handler::*}, +}; +use actix_web::{web, HttpResponse}; + + +fn error_to_api_response(error: DomainError) -> ApiResult { + ApiResult::Right(error_to_http_response(error)) +} + +type ApiResult = actix_web::Either, HttpResponse>; + +async fn user_list_handler( + data: web::Data>, + info: web::Json, +) -> ApiResult> +where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + let req: ListUsersRequest = info.clone(); + data.backend_handler + .list_users(req) + .await + .map(|res| ApiResult::Left(web::Json(res))) + .unwrap_or_else(error_to_api_response) +} + +pub fn api_config(cfg: &mut web::ServiceConfig) +where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + let json_config = web::JsonConfig::default() + .limit(4096) + .error_handler(|err, _req| { + // create custom error response + log::error!("API error: {}", err); + let msg = err.to_string(); + actix_web::error::InternalError::from_response( + err, + HttpResponse::BadRequest().body(msg).into(), + ) + .into() + }); + cfg.service( + web::resource("/users") + .app_data(json_config) + .route(web::post().to(user_list_handler::)), + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use hmac::{Hmac, NewMac}; + use std::collections::HashSet; + + fn get_data(handler: MockTestTcpBackendHandler) -> web::Data> { + let app_state = AppState:: { + backend_handler: handler, + jwt_key: Hmac::new_varkey(b"jwt_secret").unwrap(), + jwt_blacklist: HashSet::new(), + }; + web::Data::>::new(app_state) + } + + fn expect_json(result: ApiResult) -> T { + if let ApiResult::Left(res) = result { + res.0 + } else { + panic!("Expected Json result, got: {:?}", result); + } + } + + #[actix_rt::test] + async fn test_user_list_ok() { + let mut backend_handler = MockTestTcpBackendHandler::new(); + backend_handler + .expect_list_users() + .times(1) + .return_once(|_| { + Ok(vec![User { + user_id: "bob".to_string(), + ..Default::default() + }]) + }); + let json = web::Json(ListUsersRequest { filters: None }); + let resp = user_list_handler(get_data(backend_handler), json).await; + assert_eq!( + expect_json(resp), + vec![User { + user_id: "bob".to_string(), + ..Default::default() + }] + ); + } +} diff --git a/src/infra/tcp_backend_handler.rs b/src/infra/tcp_backend_handler.rs new file mode 100644 index 0000000..46d9417 --- /dev/null +++ b/src/infra/tcp_backend_handler.rs @@ -0,0 +1,35 @@ +use std::collections::HashSet; +use async_trait::async_trait; + +pub type DomainError = crate::domain::error::Error; +pub type DomainResult = crate::domain::error::Result; + +#[async_trait] +pub trait TcpBackendHandler { + async fn get_jwt_blacklist(&self) -> anyhow::Result>; + async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; + async fn check_token(&self, token: &str, user: &str) -> DomainResult; +} + +#[cfg(test)] +use crate::domain::handler::*; +#[cfg(test)] +mockall::mock! { + pub TestTcpBackendHandler{} + impl Clone for TestTcpBackendHandler { + fn clone(&self) -> Self; + } + #[async_trait] + impl BackendHandler for TestTcpBackendHandler { + async fn bind(&self, request: BindRequest) -> DomainResult<()>; + async fn list_users(&self, request: ListUsersRequest) -> DomainResult>; + async fn list_groups(&self) -> DomainResult>; + async fn get_user_groups(&self, user: String) -> DomainResult>; + } + #[async_trait] + impl TcpBackendHandler for TestTcpBackendHandler { + async fn get_jwt_blacklist(&self) -> anyhow::Result>; + async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; + async fn check_token(&self, token: &str, user: &str) -> DomainResult; + } +} diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index 84c5ee3..5775ef6 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -1,42 +1,18 @@ -use crate::domain::handler::*; -use crate::infra::configuration::Configuration; +use crate::{ + domain::handler::*, + infra::{auth_service, tcp_api, configuration::Configuration, tcp_backend_handler::*}, +}; use actix_files::{Files, NamedFile}; use actix_http::HttpServiceBuilder; use actix_server::ServerBuilder; -use actix_service::{map_config, Service}; -use actix_web::{ - cookie::{Cookie, SameSite}, - dev::{AppConfig, ServiceRequest}, - error::{ErrorBadRequest, ErrorUnauthorized}, - web, App, HttpRequest, HttpResponse, -}; -use actix_web_httpauth::{extractors::bearer::BearerAuth, middleware::HttpAuthentication}; +use actix_service::map_config; +use actix_web::{dev::AppConfig, web, App, HttpRequest, HttpResponse}; +use actix_web_httpauth::middleware::HttpAuthentication; use anyhow::{Context, Result}; -use async_trait::async_trait; -use chrono::prelude::*; -use futures_util::FutureExt; -use futures_util::TryFutureExt; use hmac::{Hmac, NewMac}; -use jwt::{SignWithKey, VerifyWithKey}; -use log::*; use sha2::Sha512; -use std::collections::{hash_map::DefaultHasher, HashSet}; -use std::hash::{Hash, Hasher}; +use std::collections::HashSet; use std::path::PathBuf; -use time::ext::NumericalDuration; - -type Token = jwt::Token; -type SignedToken = Token; - -type DomainError = crate::domain::error::Error; -type DomainResult = crate::domain::error::Result; - -#[async_trait] -pub trait TcpBackendHandler: BackendHandler { - async fn get_jwt_blacklist(&self) -> Result>; - async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; - async fn check_token(&self, token: &str, user: &str) -> DomainResult; -} async fn index(req: HttpRequest) -> actix_web::Result { let mut path = PathBuf::new(); @@ -46,202 +22,12 @@ async fn index(req: HttpRequest) -> actix_web::Result { Ok(NamedFile::open(path)?) } -fn error_to_http_response(error: DomainError) -> HttpResponse { - match error { - DomainError::AuthenticationError(_) => HttpResponse::Unauthorized(), - DomainError::DatabaseError(_) => HttpResponse::InternalServerError(), - } - .body(error.to_string()) -} - -fn error_to_api_response(error: DomainError) -> ApiResult { - ApiResult::Right( - error_to_http_response(error) - ) -} - -type ApiResult = actix_web::Either, HttpResponse>; - -async fn user_list_handler( - data: web::Data>, - info: web::Json, -) -> ApiResult> -where - Backend: TcpBackendHandler + 'static, -{ - let req: ListUsersRequest = info.clone(); - data.backend_handler - .list_users(req) - .await - .map(|res| ApiResult::Left(web::Json(res))) - .unwrap_or_else(error_to_api_response) -} - -fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken { - let claims = JWTClaims { - exp: Utc::now() + chrono::Duration::days(1), - iat: Utc::now(), - user, - groups, - }; - let header = jwt::Header { - algorithm: jwt::AlgorithmType::Hs512, - ..Default::default() - }; - jwt::Token::new(header, claims).sign_with_key(key).unwrap() -} - -async fn get_refresh( - data: web::Data>, - request: HttpRequest, -) -> HttpResponse -where - Backend: TcpBackendHandler + 'static, -{ - let backend_handler = &data.backend_handler; - let jwt_key = &data.jwt_key; - let (refresh_token, user) = match request.cookie("refresh_token") { - None => { - return HttpResponse::Unauthorized().body("Missing refresh token") - } - Some(t) => match t.value().split_once("+") { - None => { - return HttpResponse::Unauthorized().body("Invalid refresh token") - } - Some((t, u)) => (t.to_string(), u.to_string()), - }, - }; - let res_found = data.backend_handler.check_token(&refresh_token, &user).await; - // Async closures are not supported yet. - match res_found { - Ok(found) => { - if found { - backend_handler.get_user_groups(user.to_string()).await - } else { - Err(DomainError::AuthenticationError( - "Invalid refresh token".to_string(), - )) - } - } - Err(e) => Err(e), - } - .map(|groups| create_jwt(jwt_key, user.to_string(), groups)) - .map(|token| { - HttpResponse::Ok() - .cookie( - Cookie::build("token", token.as_str()) - .max_age(1.days()) - .path("/api") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .body(token.as_str().to_owned()) - }) - .unwrap_or_else(error_to_http_response) -} - -async fn post_authorize( - data: web::Data>, - request: web::Json, -) -> HttpResponse -where - Backend: TcpBackendHandler + 'static, -{ - let req: BindRequest = request.clone(); - data.backend_handler - .bind(req) - // If the authentication was successful, we need to fetch the groups to create the JWT - // token. - .and_then(|_| data.backend_handler.get_user_groups(request.name.clone())) - .and_then(|g| async { - Ok(( - g, - data.backend_handler - .create_refresh_token(&request.name) - .await?, - )) - }) - .await - .map(|(groups, (refresh_token, max_age))| { - let token = create_jwt(&data.jwt_key, request.name.clone(), groups); - HttpResponse::Ok() - .cookie( - Cookie::build("token", token.as_str()) - .max_age(1.days()) - .path("/api") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .cookie( - Cookie::build("refresh_token", refresh_token + "+" + &request.name) - .max_age(max_age.num_days().days()) - .path("/api/authorize/refresh") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .body(token.as_str().to_owned()) - }) - .unwrap_or_else(error_to_http_response) -} - -fn api_config(cfg: &mut web::ServiceConfig) -where - Backend: TcpBackendHandler + 'static, -{ - let json_config = web::JsonConfig::default() - .limit(4096) - .error_handler(|err, _req| { - // create custom error response - log::error!("API error: {}", err); - let msg = err.to_string(); - actix_web::error::InternalError::from_response( - err, - HttpResponse::BadRequest().body(msg).into(), - ) - .into() - }); - cfg.service( - web::resource("/users") - .app_data(json_config) - .route(web::post().to(user_list_handler::)), - ); -} - -async fn token_validator( - req: ServiceRequest, - credentials: BearerAuth, -) -> Result -where - Backend: TcpBackendHandler + 'static, -{ - let state = req - .app_data::>>() - .expect("Invalid app config"); - let token: Token<_> = VerifyWithKey::verify_with_key(credentials.token(), &state.jwt_key) - .map_err(|_| ErrorUnauthorized("Invalid JWT"))?; - if token.claims().exp.lt(&Utc::now()) { - return Err(ErrorUnauthorized("Expired JWT")); - } - let jwt_hash = { - let mut s = DefaultHasher::new(); - credentials.token().hash(&mut s); - s.finish() - }; - if state.jwt_blacklist.contains(&jwt_hash) { - return Err(ErrorUnauthorized("JWT was logged out")); - } - let groups = &token.claims().groups; - if groups.contains("lldap_admin") { - debug!("Got authorized token for user {}", &token.claims().user); - Ok(req) - } else { - Err(ErrorUnauthorized( - "JWT error: User is not in group lldap_admin", - )) +pub(crate) fn error_to_http_response(error: DomainError) -> HttpResponse { + match error { + DomainError::AuthenticationError(_) => HttpResponse::Unauthorized(), + DomainError::DatabaseError(_) => HttpResponse::InternalServerError(), } + .body(error.to_string()) } fn http_config( @@ -250,7 +36,7 @@ fn http_config( jwt_secret: String, jwt_blacklist: HashSet, ) where - Backend: TcpBackendHandler + 'static, + Backend: TcpBackendHandler + BackendHandler + 'static, { cfg.data(AppState:: { backend_handler, @@ -262,30 +48,15 @@ fn http_config( "/{filename:(index\\.html|main\\.js)?}", web::get().to(index), ) - .service(web::resource("/auth").route(web::post().to(post_authorize::))) - .service(web::resource("/auth/refresh").route(web::get().to(get_refresh::))) + .service(web::scope("/auth").configure(auth_service::configure_server::)) // API endpoint. .service( web::scope("/api") - .wrap(HttpAuthentication::bearer(token_validator::)) - .wrap_fn(|mut req, srv| { - if let Some(token_cookie) = req.cookie("token") { - if let Ok(header_value) = actix_http::header::HeaderValue::from_str(&format!( - "Bearer {}", - token_cookie.value() - )) { - req.headers_mut() - .insert(actix_http::header::AUTHORIZATION, header_value); - } else { - return async move { - Ok(req.error_response(ErrorBadRequest("Invalid token cookie"))) - } - .boxed_local(); - } - }; - Box::pin(srv.call(req)) - }) - .configure(api_config::), + .wrap(HttpAuthentication::bearer( + auth_service::token_validator::, + )) + .wrap(auth_service::CookieToHeaderTranslatorFactory) + .configure(tcp_api::api_config::), ) // Serve the /pkg path with the compiled WASM app. .service(Files::new("/pkg", "./app/pkg")) @@ -293,9 +64,9 @@ fn http_config( .service(web::scope("/").route("/.*", web::get().to(index))); } -struct AppState +pub(crate) struct AppState where - Backend: TcpBackendHandler + 'static, + Backend: TcpBackendHandler + BackendHandler + 'static, { pub backend_handler: Backend, pub jwt_key: Hmac, @@ -308,7 +79,7 @@ pub async fn build_tcp_server( server_builder: ServerBuilder, ) -> Result where - Backend: TcpBackendHandler + 'static, + Backend: TcpBackendHandler + BackendHandler + 'static, { let jwt_secret = config.jwt_secret.clone(); let jwt_blacklist = backend_handler.get_jwt_blacklist().await?; @@ -340,22 +111,6 @@ mod tests { use actix_web::test::TestRequest; use std::path::Path; - fn get_data(handler: MockTestBackendHandler) -> web::Data> { - let app_state = AppState:: { - backend_handler: handler, - jwt_key: Hmac::new_varkey(b"jwt_secret").unwrap(), - }; - web::Data::>::new(app_state) - } - - fn expect_json(result: ApiResult) -> T { - if let ApiResult::Left(res) = result { - res.0 - } else { - panic!("Expected Json result, got: {:?}", result); - } - } - #[actix_rt::test] async fn test_index_ok() { let req = TestRequest::default().to_http_request(); @@ -371,27 +126,4 @@ mod tests { let resp = index(req).await.unwrap(); assert_eq!(resp.path(), Path::new("app/main.js")); } - - #[actix_rt::test] - async fn test_user_list_ok() { - let mut backend_handler = MockTestBackendHandler::new(); - backend_handler - .expect_list_users() - .times(1) - .return_once(|_| { - Ok(vec![User { - user_id: "bob".to_string(), - ..Default::default() - }]) - }); - let json = web::Json(ListUsersRequest { filters: None }); - let resp = user_list_handler(get_data(backend_handler), json).await; - assert_eq!( - expect_json(resp), - vec![User { - user_id: "bob".to_string(), - ..Default::default() - }] - ); - } } diff --git a/src/main.rs b/src/main.rs index 6112f9e..be09a70 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ #![forbid(unsafe_code)] -use crate::domain::sql_tables::PoolOptions; -use crate::infra::configuration::Configuration; +use crate::{ + domain::{sql_backend_handler::SqlBackendHandler, sql_tables::PoolOptions}, + infra::configuration::Configuration, +}; use anyhow::Result; use futures_util::TryFutureExt; use log::*; @@ -14,7 +16,7 @@ async fn run_server(config: Configuration) -> Result<()> { .connect(&config.database_url) .await?; domain::sql_tables::init_table(&sql_pool).await?; - let backend_handler = domain::handler::SqlBackendHandler::new(config.clone(), sql_pool.clone()); + let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone()); let server_builder = infra::ldap_server::build_ldap_server( &config, backend_handler.clone(),