diff --git a/src/infra/auth_service.rs b/src/infra/auth_service.rs index b69e84e..0714c28 100644 --- a/src/infra/auth_service.rs +++ b/src/infra/auth_service.rs @@ -1,24 +1,30 @@ -use crate::{domain::handler::*, infra::{tcp_backend_handler::*, tcp_server::{AppState, error_to_http_response}}}; +use crate::{ + domain::handler::*, + infra::{ + tcp_backend_handler::*, + tcp_server::{error_to_http_response, AppState}, + }, +}; +use actix_web::{ + cookie::{Cookie, SameSite}, + dev::{Service, ServiceRequest, ServiceResponse, Transform}, + error::{ErrorBadRequest, ErrorUnauthorized}, + web, HttpRequest, HttpResponse, +}; +use actix_web_httpauth::extractors::bearer::BearerAuth; +use anyhow::Result; +use chrono::prelude::*; +use futures::future::{ok, Ready}; +use futures_util::{FutureExt, TryFutureExt}; use hmac::Hmac; use jwt::{SignWithKey, VerifyWithKey}; use log::*; +use sha2::Sha512; use std::collections::{hash_map::DefaultHasher, HashSet}; use std::hash::{Hash, Hasher}; -use time::ext::NumericalDuration; -use actix_web_httpauth::extractors::bearer::BearerAuth; -use anyhow::Result; -use std::task::{Context, Poll}; use std::pin::Pin; -use actix_web::{ - cookie::{Cookie, SameSite}, - dev::{ServiceRequest, Service, Transform, ServiceResponse}, - error::{ErrorBadRequest, ErrorUnauthorized}, - web, HttpRequest, HttpResponse -}; -use futures_util::{FutureExt, TryFutureExt}; -use futures::future::{ok, Ready}; -use chrono::prelude::*; -use sha2::Sha512; +use std::task::{Context, Poll}; +use time::ext::NumericalDuration; type Token = jwt::Token; type SignedToken = Token; @@ -37,6 +43,18 @@ fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> Sign jwt::Token::new(header, claims).sign_with_key(key).unwrap() } +fn get_refresh_token_from_cookie( + request: HttpRequest, +) -> std::result::Result<(String, String), HttpResponse> { + match request.cookie("refresh_token") { + None => Err(HttpResponse::Unauthorized().body("Missing refresh token")), + Some(t) => match t.value().split_once("+") { + None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")), + Some((t, u)) => Ok((t.to_string(), u.to_string())), + }, + } +} + async fn get_refresh( data: web::Data>, request: HttpRequest, @@ -46,18 +64,14 @@ where { let backend_handler = &data.backend_handler; let jwt_key = &data.jwt_key; - let (refresh_token, user) = match request.cookie("refresh_token") { - None => { - return HttpResponse::Unauthorized().body("Missing refresh token") - } - Some(t) => match t.value().split_once("+") { - None => { - return HttpResponse::Unauthorized().body("Invalid refresh token") - } - Some((t, u)) => (t.to_string(), u.to_string()), - }, + let (refresh_token, user) = match get_refresh_token_from_cookie(request) { + Ok(t) => t, + Err(http_response) => return http_response, }; - let res_found = data.backend_handler.check_token(&refresh_token, &user).await; + let res_found = data + .backend_handler + .check_token(&refresh_token, &user) + .await; // Async closures are not supported yet. match res_found { Ok(found) => { @@ -73,20 +87,73 @@ where } .map(|groups| create_jwt(jwt_key, user.to_string(), groups)) .map(|token| { - HttpResponse::Ok() - .cookie( - Cookie::build("token", token.as_str()) - .max_age(1.days()) - .path("/api") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .body(token.as_str().to_owned()) + HttpResponse::Ok() + .cookie( + Cookie::build("token", token.as_str()) + .max_age(1.days()) + .path("/api") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .body(token.as_str().to_owned()) }) .unwrap_or_else(error_to_http_response) } +async fn get_logout( + data: web::Data>, + request: HttpRequest, +) -> HttpResponse +where + Backend: TcpBackendHandler + BackendHandler + 'static, +{ + let (refresh_token, user) = match get_refresh_token_from_cookie(request) { + Ok(t) => t, + Err(http_response) => return http_response, + }; + if let Err(response) = data + .backend_handler + .delete_refresh_token(&refresh_token) + .map_err(error_to_http_response) + .await + { + return response; + }; + match data + .backend_handler + .blacklist_jwts(&user) + .map_err(error_to_http_response) + .await + { + Ok(new_blacklisted_jwts) => { + let mut jwt_blacklist = data.jwt_blacklist.write().unwrap(); + for jwt in new_blacklisted_jwts { + jwt_blacklist.insert(jwt); + } + }, + Err(response) => return response, + }; + HttpResponse::Ok() + .cookie( + Cookie::build("token", "") + .max_age(0.days()) + .path("/api") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .cookie( + Cookie::build("refresh_token", "") + .max_age(0.days()) + .path("/api/authorize/refresh") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .finish() +} + async fn post_authorize( data: web::Data>, request: web::Json, @@ -111,24 +178,24 @@ where .await .map(|(groups, (refresh_token, max_age))| { let token = create_jwt(&data.jwt_key, request.name.clone(), groups); - HttpResponse::Ok() - .cookie( - Cookie::build("token", token.as_str()) - .max_age(1.days()) - .path("/api") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .cookie( - Cookie::build("refresh_token", refresh_token + "+" + &request.name) - .max_age(max_age.num_days().days()) - .path("/api/authorize/refresh") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .body(token.as_str().to_owned()) + HttpResponse::Ok() + .cookie( + Cookie::build("token", token.as_str()) + .max_age(1.days()) + .path("/api") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .cookie( + Cookie::build("refresh_token", refresh_token + "+" + &request.name) + .max_age(max_age.num_days().days()) + .path("/api/authorize/refresh") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .body(token.as_str().to_owned()) }) .unwrap_or_else(error_to_http_response) } @@ -137,9 +204,9 @@ pub struct CookieToHeaderTranslatorFactory; impl Transform for CookieToHeaderTranslatorFactory where - S: Service, Error=actix_web::Error>, - S::Future: 'static, - B: 'static, + S: Service, Error = actix_web::Error>, + S::Future: 'static, + B: 'static, { type Response = ServiceResponse; type Error = actix_web::Error; @@ -211,7 +278,7 @@ where credentials.token().hash(&mut s); s.finish() }; - if state.jwt_blacklist.contains(&jwt_hash) { + if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) { return Err(ErrorUnauthorized("JWT was logged out")); } let groups = &token.claims().groups; @@ -225,13 +292,11 @@ where } } - -pub fn configure_server( - cfg: &mut web::ServiceConfig, -) where +pub fn configure_server(cfg: &mut web::ServiceConfig) +where Backend: TcpBackendHandler + BackendHandler + 'static, { - cfg - .service(web::resource("").route(web::post().to(post_authorize::))) - .service(web::resource("/refresh").route(web::get().to(get_refresh::))); + cfg.service(web::resource("").route(web::post().to(post_authorize::))) + .service(web::resource("/refresh").route(web::get().to(get_refresh::))) + .service(web::resource("/logout").route(web::get().to(get_logout::))); } diff --git a/src/infra/jwt_sql_tables.rs b/src/infra/jwt_sql_tables.rs index 34137b6..7025cc8 100644 --- a/src/infra/jwt_sql_tables.rs +++ b/src/infra/jwt_sql_tables.rs @@ -13,11 +13,12 @@ pub enum JwtRefreshStorage { /// Contains the blacklisted JWT that haven't expired yet. #[derive(Iden)] -pub enum JwtBlacklist { +pub enum JwtStorage { Table, JwtHash, UserId, ExpiryDate, + Blacklisted, } /// This needs to be initialized after the domain tables are. @@ -57,29 +58,35 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { sqlx::query( &Table::create() - .table(JwtBlacklist::Table) + .table(JwtStorage::Table) .if_not_exists() .col( - ColumnDef::new(JwtBlacklist::JwtHash) + ColumnDef::new(JwtStorage::JwtHash) .big_integer() .not_null() .primary_key(), ) .col( - ColumnDef::new(JwtBlacklist::UserId) + ColumnDef::new(JwtStorage::UserId) .string_len(255) .not_null(), ) .col( - ColumnDef::new(JwtBlacklist::ExpiryDate) + ColumnDef::new(JwtStorage::ExpiryDate) .date_time() .not_null(), ) + .col( + ColumnDef::new(JwtStorage::Blacklisted) + .boolean() + .default(false) + .not_null(), + ) .foreign_key( ForeignKey::create() - .name("JwtBlacklistUserForeignKey") - .table(JwtBlacklist::Table, Users::Table) - .col(JwtBlacklist::UserId, Users::UserId) + .name("JwtStorageUserForeignKey") + .table(JwtStorage::Table, Users::Table) + .col(JwtStorage::UserId, Users::UserId) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ) diff --git a/src/infra/sql_backend_handler.rs b/src/infra/sql_backend_handler.rs index 0f22ded..d1c1484 100644 --- a/src/infra/sql_backend_handler.rs +++ b/src/infra/sql_backend_handler.rs @@ -4,19 +4,21 @@ use async_trait::async_trait; use futures_util::StreamExt; use sea_query::{Expr, Iden, Query, SimpleExpr}; use sqlx::Row; +use std::collections::hash_map::DefaultHasher; use std::collections::HashSet; +use std::hash::{Hash, Hasher}; #[async_trait] impl TcpBackendHandler for SqlBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result> { use sqlx::Result; let query = Query::select() - .column(JwtBlacklist::JwtHash) - .from(JwtBlacklist::Table) + .column(JwtStorage::JwtHash) + .from(JwtStorage::Table) .to_string(DbQueryBuilder {}); sqlx::query(&query) - .map(|row: DbRow| row.get::(&*JwtBlacklist::JwtHash.to_string()) as u64) + .map(|row: DbRow| row.get::(&*JwtStorage::JwtHash.to_string()) as u64) .fetch(&self.sql_pool) .collect::>>() .await @@ -60,8 +62,6 @@ impl TcpBackendHandler for SqlBackendHandler { } async fn check_token(&self, token: &str, user: &str) -> Result { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; let refresh_token_hash = { let mut s = DefaultHasher::new(); token.hash(&mut s); @@ -78,4 +78,40 @@ impl TcpBackendHandler for SqlBackendHandler { .await? .is_some()) } + async fn blacklist_jwts(&self, user: &str) -> DomainResult> { + use sqlx::Result; + let query = Query::select() + .column(JwtStorage::JwtHash) + .from(JwtStorage::Table) + .and_where(Expr::col(JwtStorage::UserId).eq(user)) + .and_where(Expr::col(JwtStorage::Blacklisted).eq(true)) + .to_string(DbQueryBuilder {}); + let result = sqlx::query(&query) + .map(|row: DbRow| row.get::(&*JwtStorage::JwtHash.to_string()) as u64) + .fetch(&self.sql_pool) + .collect::>>() + .await + .into_iter() + .collect::>>(); + let query = Query::update() + .table(JwtStorage::Table) + .values(vec![(JwtStorage::Blacklisted, true.into())]) + .and_where(Expr::col(JwtStorage::UserId).eq(user)) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(&self.sql_pool).await?; + Ok(result?) + } + async fn delete_refresh_token(&self, token: &str) -> DomainResult<()> { + let refresh_token_hash = { + let mut s = DefaultHasher::new(); + token.hash(&mut s); + s.finish() + }; + let query = Query::delete() + .from_table(JwtRefreshStorage::Table) + .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash)) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(&self.sql_pool).await?; + Ok(()) + } } diff --git a/src/infra/tcp_api.rs b/src/infra/tcp_api.rs index 546e480..8dd1644 100644 --- a/src/infra/tcp_api.rs +++ b/src/infra/tcp_api.rs @@ -1,10 +1,12 @@ use crate::{ domain::handler::*, - infra::{tcp_server::{AppState, error_to_http_response}, tcp_backend_handler::*}, + infra::{ + tcp_backend_handler::*, + tcp_server::{error_to_http_response, AppState}, + }, }; use actix_web::{web, HttpResponse}; - fn error_to_api_response(error: DomainError) -> ApiResult { ApiResult::Right(error_to_http_response(error)) } @@ -54,12 +56,15 @@ mod tests { use super::*; use hmac::{Hmac, NewMac}; use std::collections::HashSet; + use std::sync::RwLock; - fn get_data(handler: MockTestTcpBackendHandler) -> web::Data> { + fn get_data( + handler: MockTestTcpBackendHandler, + ) -> web::Data> { let app_state = AppState:: { backend_handler: handler, jwt_key: Hmac::new_varkey(b"jwt_secret").unwrap(), - jwt_blacklist: HashSet::new(), + jwt_blacklist: RwLock::new(HashSet::new()), }; web::Data::>::new(app_state) } diff --git a/src/infra/tcp_backend_handler.rs b/src/infra/tcp_backend_handler.rs index 46d9417..7351528 100644 --- a/src/infra/tcp_backend_handler.rs +++ b/src/infra/tcp_backend_handler.rs @@ -1,5 +1,5 @@ -use std::collections::HashSet; use async_trait::async_trait; +use std::collections::HashSet; pub type DomainError = crate::domain::error::Error; pub type DomainResult = crate::domain::error::Result; @@ -9,6 +9,8 @@ pub trait TcpBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result>; async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; async fn check_token(&self, token: &str, user: &str) -> DomainResult; + async fn blacklist_jwts(&self, user: &str) -> DomainResult>; + async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>; } #[cfg(test)] @@ -31,5 +33,7 @@ mockall::mock! { async fn get_jwt_blacklist(&self) -> anyhow::Result>; async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; async fn check_token(&self, token: &str, user: &str) -> DomainResult; + async fn blacklist_jwts(&self, user: &str) -> DomainResult>; + async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>; } } diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index 5775ef6..8190db3 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -1,6 +1,6 @@ use crate::{ domain::handler::*, - infra::{auth_service, tcp_api, configuration::Configuration, tcp_backend_handler::*}, + infra::{auth_service, configuration::Configuration, tcp_api, tcp_backend_handler::*}, }; use actix_files::{Files, NamedFile}; use actix_http::HttpServiceBuilder; @@ -13,6 +13,7 @@ use hmac::{Hmac, NewMac}; use sha2::Sha512; use std::collections::HashSet; use std::path::PathBuf; +use std::sync::RwLock; async fn index(req: HttpRequest) -> actix_web::Result { let mut path = PathBuf::new(); @@ -41,7 +42,7 @@ fn http_config( cfg.data(AppState:: { backend_handler, jwt_key: Hmac::new_varkey(&jwt_secret.as_bytes()).unwrap(), - jwt_blacklist, + jwt_blacklist: RwLock::new(jwt_blacklist), }) // Serve index.html and main.js, and default to index.html. .route( @@ -70,7 +71,7 @@ where { pub backend_handler: Backend, pub jwt_key: Hmac, - pub jwt_blacklist: HashSet, + pub jwt_blacklist: RwLock>, } pub async fn build_tcp_server(