mirror of
				https://github.com/nitnelave/lldap.git
				synced 2023-04-12 14:25:13 +00:00 
			
		
		
		
	Implement server-side logout
This commit is contained in:
		
							parent
							
								
									28a941924e
								
							
						
					
					
						commit
						10404abbb0
					
				@ -1,24 +1,30 @@
 | 
				
			|||||||
use crate::{domain::handler::*, infra::{tcp_backend_handler::*, tcp_server::{AppState, error_to_http_response}}};
 | 
					use crate::{
 | 
				
			||||||
 | 
					    domain::handler::*,
 | 
				
			||||||
 | 
					    infra::{
 | 
				
			||||||
 | 
					        tcp_backend_handler::*,
 | 
				
			||||||
 | 
					        tcp_server::{error_to_http_response, AppState},
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					use actix_web::{
 | 
				
			||||||
 | 
					    cookie::{Cookie, SameSite},
 | 
				
			||||||
 | 
					    dev::{Service, ServiceRequest, ServiceResponse, Transform},
 | 
				
			||||||
 | 
					    error::{ErrorBadRequest, ErrorUnauthorized},
 | 
				
			||||||
 | 
					    web, HttpRequest, HttpResponse,
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					use actix_web_httpauth::extractors::bearer::BearerAuth;
 | 
				
			||||||
 | 
					use anyhow::Result;
 | 
				
			||||||
 | 
					use chrono::prelude::*;
 | 
				
			||||||
 | 
					use futures::future::{ok, Ready};
 | 
				
			||||||
 | 
					use futures_util::{FutureExt, TryFutureExt};
 | 
				
			||||||
use hmac::Hmac;
 | 
					use hmac::Hmac;
 | 
				
			||||||
use jwt::{SignWithKey, VerifyWithKey};
 | 
					use jwt::{SignWithKey, VerifyWithKey};
 | 
				
			||||||
use log::*;
 | 
					use log::*;
 | 
				
			||||||
 | 
					use sha2::Sha512;
 | 
				
			||||||
use std::collections::{hash_map::DefaultHasher, HashSet};
 | 
					use std::collections::{hash_map::DefaultHasher, HashSet};
 | 
				
			||||||
use std::hash::{Hash, Hasher};
 | 
					use std::hash::{Hash, Hasher};
 | 
				
			||||||
use time::ext::NumericalDuration;
 | 
					 | 
				
			||||||
use actix_web_httpauth::extractors::bearer::BearerAuth;
 | 
					 | 
				
			||||||
use anyhow::Result;
 | 
					 | 
				
			||||||
use std::task::{Context, Poll};
 | 
					 | 
				
			||||||
use std::pin::Pin;
 | 
					use std::pin::Pin;
 | 
				
			||||||
use actix_web::{
 | 
					use std::task::{Context, Poll};
 | 
				
			||||||
    cookie::{Cookie, SameSite},
 | 
					use time::ext::NumericalDuration;
 | 
				
			||||||
    dev::{ServiceRequest, Service, Transform, ServiceResponse},
 | 
					 | 
				
			||||||
    error::{ErrorBadRequest, ErrorUnauthorized},
 | 
					 | 
				
			||||||
    web, HttpRequest, HttpResponse
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
use futures_util::{FutureExt, TryFutureExt};
 | 
					 | 
				
			||||||
use futures::future::{ok, Ready};
 | 
					 | 
				
			||||||
use chrono::prelude::*;
 | 
					 | 
				
			||||||
use sha2::Sha512;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Token<S> = jwt::Token<jwt::Header, JWTClaims, S>;
 | 
					type Token<S> = jwt::Token<jwt::Header, JWTClaims, S>;
 | 
				
			||||||
type SignedToken = Token<jwt::token::Signed>;
 | 
					type SignedToken = Token<jwt::token::Signed>;
 | 
				
			||||||
@ -37,6 +43,18 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<String>) -> Sign
 | 
				
			|||||||
    jwt::Token::new(header, claims).sign_with_key(key).unwrap()
 | 
					    jwt::Token::new(header, claims).sign_with_key(key).unwrap()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn get_refresh_token_from_cookie(
 | 
				
			||||||
 | 
					    request: HttpRequest,
 | 
				
			||||||
 | 
					) -> std::result::Result<(String, String), HttpResponse> {
 | 
				
			||||||
 | 
					    match request.cookie("refresh_token") {
 | 
				
			||||||
 | 
					        None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
 | 
				
			||||||
 | 
					        Some(t) => match t.value().split_once("+") {
 | 
				
			||||||
 | 
					            None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
 | 
				
			||||||
 | 
					            Some((t, u)) => Ok((t.to_string(), u.to_string())),
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async fn get_refresh<Backend>(
 | 
					async fn get_refresh<Backend>(
 | 
				
			||||||
    data: web::Data<AppState<Backend>>,
 | 
					    data: web::Data<AppState<Backend>>,
 | 
				
			||||||
    request: HttpRequest,
 | 
					    request: HttpRequest,
 | 
				
			||||||
@ -46,18 +64,14 @@ where
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
    let backend_handler = &data.backend_handler;
 | 
					    let backend_handler = &data.backend_handler;
 | 
				
			||||||
    let jwt_key = &data.jwt_key;
 | 
					    let jwt_key = &data.jwt_key;
 | 
				
			||||||
    let (refresh_token, user) = match request.cookie("refresh_token") {
 | 
					    let (refresh_token, user) = match get_refresh_token_from_cookie(request) {
 | 
				
			||||||
        None => {
 | 
					        Ok(t) => t,
 | 
				
			||||||
            return HttpResponse::Unauthorized().body("Missing refresh token")
 | 
					        Err(http_response) => return http_response,
 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        Some(t) => match t.value().split_once("+") {
 | 
					 | 
				
			||||||
            None => {
 | 
					 | 
				
			||||||
                return HttpResponse::Unauthorized().body("Invalid refresh token")
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            Some((t, u)) => (t.to_string(), u.to_string()),
 | 
					 | 
				
			||||||
        },
 | 
					 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    let res_found = data.backend_handler.check_token(&refresh_token, &user).await;
 | 
					    let res_found = data
 | 
				
			||||||
 | 
					        .backend_handler
 | 
				
			||||||
 | 
					        .check_token(&refresh_token, &user)
 | 
				
			||||||
 | 
					        .await;
 | 
				
			||||||
    // Async closures are not supported yet.
 | 
					    // Async closures are not supported yet.
 | 
				
			||||||
    match res_found {
 | 
					    match res_found {
 | 
				
			||||||
        Ok(found) => {
 | 
					        Ok(found) => {
 | 
				
			||||||
@ -73,20 +87,73 @@ where
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    .map(|groups| create_jwt(jwt_key, user.to_string(), groups))
 | 
					    .map(|groups| create_jwt(jwt_key, user.to_string(), groups))
 | 
				
			||||||
    .map(|token| {
 | 
					    .map(|token| {
 | 
				
			||||||
            HttpResponse::Ok()
 | 
					        HttpResponse::Ok()
 | 
				
			||||||
                .cookie(
 | 
					            .cookie(
 | 
				
			||||||
                    Cookie::build("token", token.as_str())
 | 
					                Cookie::build("token", token.as_str())
 | 
				
			||||||
                        .max_age(1.days())
 | 
					                    .max_age(1.days())
 | 
				
			||||||
                        .path("/api")
 | 
					                    .path("/api")
 | 
				
			||||||
                        .http_only(true)
 | 
					                    .http_only(true)
 | 
				
			||||||
                        .same_site(SameSite::Strict)
 | 
					                    .same_site(SameSite::Strict)
 | 
				
			||||||
                        .finish(),
 | 
					                    .finish(),
 | 
				
			||||||
                )
 | 
					            )
 | 
				
			||||||
                .body(token.as_str().to_owned())
 | 
					            .body(token.as_str().to_owned())
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    .unwrap_or_else(error_to_http_response)
 | 
					    .unwrap_or_else(error_to_http_response)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async fn get_logout<Backend>(
 | 
				
			||||||
 | 
					    data: web::Data<AppState<Backend>>,
 | 
				
			||||||
 | 
					    request: HttpRequest,
 | 
				
			||||||
 | 
					) -> HttpResponse
 | 
				
			||||||
 | 
					where
 | 
				
			||||||
 | 
					    Backend: TcpBackendHandler + BackendHandler + 'static,
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					    let (refresh_token, user) = match get_refresh_token_from_cookie(request) {
 | 
				
			||||||
 | 
					        Ok(t) => t,
 | 
				
			||||||
 | 
					        Err(http_response) => return http_response,
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    if let Err(response) = data
 | 
				
			||||||
 | 
					        .backend_handler
 | 
				
			||||||
 | 
					        .delete_refresh_token(&refresh_token)
 | 
				
			||||||
 | 
					        .map_err(error_to_http_response)
 | 
				
			||||||
 | 
					        .await
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return response;
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    match data
 | 
				
			||||||
 | 
					        .backend_handler
 | 
				
			||||||
 | 
					        .blacklist_jwts(&user)
 | 
				
			||||||
 | 
					        .map_err(error_to_http_response)
 | 
				
			||||||
 | 
					        .await
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        Ok(new_blacklisted_jwts) => {
 | 
				
			||||||
 | 
					            let mut jwt_blacklist = data.jwt_blacklist.write().unwrap();
 | 
				
			||||||
 | 
					            for jwt in new_blacklisted_jwts {
 | 
				
			||||||
 | 
					                jwt_blacklist.insert(jwt);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        Err(response) => return response,
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    HttpResponse::Ok()
 | 
				
			||||||
 | 
					        .cookie(
 | 
				
			||||||
 | 
					            Cookie::build("token", "")
 | 
				
			||||||
 | 
					                .max_age(0.days())
 | 
				
			||||||
 | 
					                .path("/api")
 | 
				
			||||||
 | 
					                .http_only(true)
 | 
				
			||||||
 | 
					                .same_site(SameSite::Strict)
 | 
				
			||||||
 | 
					                .finish(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        .cookie(
 | 
				
			||||||
 | 
					            Cookie::build("refresh_token", "")
 | 
				
			||||||
 | 
					                .max_age(0.days())
 | 
				
			||||||
 | 
					                .path("/api/authorize/refresh")
 | 
				
			||||||
 | 
					                .http_only(true)
 | 
				
			||||||
 | 
					                .same_site(SameSite::Strict)
 | 
				
			||||||
 | 
					                .finish(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        .finish()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async fn post_authorize<Backend>(
 | 
					async fn post_authorize<Backend>(
 | 
				
			||||||
    data: web::Data<AppState<Backend>>,
 | 
					    data: web::Data<AppState<Backend>>,
 | 
				
			||||||
    request: web::Json<BindRequest>,
 | 
					    request: web::Json<BindRequest>,
 | 
				
			||||||
@ -111,24 +178,24 @@ where
 | 
				
			|||||||
        .await
 | 
					        .await
 | 
				
			||||||
        .map(|(groups, (refresh_token, max_age))| {
 | 
					        .map(|(groups, (refresh_token, max_age))| {
 | 
				
			||||||
            let token = create_jwt(&data.jwt_key, request.name.clone(), groups);
 | 
					            let token = create_jwt(&data.jwt_key, request.name.clone(), groups);
 | 
				
			||||||
                HttpResponse::Ok()
 | 
					            HttpResponse::Ok()
 | 
				
			||||||
                    .cookie(
 | 
					                .cookie(
 | 
				
			||||||
                        Cookie::build("token", token.as_str())
 | 
					                    Cookie::build("token", token.as_str())
 | 
				
			||||||
                            .max_age(1.days())
 | 
					                        .max_age(1.days())
 | 
				
			||||||
                            .path("/api")
 | 
					                        .path("/api")
 | 
				
			||||||
                            .http_only(true)
 | 
					                        .http_only(true)
 | 
				
			||||||
                            .same_site(SameSite::Strict)
 | 
					                        .same_site(SameSite::Strict)
 | 
				
			||||||
                            .finish(),
 | 
					                        .finish(),
 | 
				
			||||||
                    )
 | 
					                )
 | 
				
			||||||
                    .cookie(
 | 
					                .cookie(
 | 
				
			||||||
                        Cookie::build("refresh_token", refresh_token + "+" + &request.name)
 | 
					                    Cookie::build("refresh_token", refresh_token + "+" + &request.name)
 | 
				
			||||||
                            .max_age(max_age.num_days().days())
 | 
					                        .max_age(max_age.num_days().days())
 | 
				
			||||||
                            .path("/api/authorize/refresh")
 | 
					                        .path("/api/authorize/refresh")
 | 
				
			||||||
                            .http_only(true)
 | 
					                        .http_only(true)
 | 
				
			||||||
                            .same_site(SameSite::Strict)
 | 
					                        .same_site(SameSite::Strict)
 | 
				
			||||||
                            .finish(),
 | 
					                        .finish(),
 | 
				
			||||||
                    )
 | 
					                )
 | 
				
			||||||
                    .body(token.as_str().to_owned())
 | 
					                .body(token.as_str().to_owned())
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        .unwrap_or_else(error_to_http_response)
 | 
					        .unwrap_or_else(error_to_http_response)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -137,9 +204,9 @@ pub struct CookieToHeaderTranslatorFactory;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
impl<S, B> Transform<S, ServiceRequest> for CookieToHeaderTranslatorFactory
 | 
					impl<S, B> Transform<S, ServiceRequest> for CookieToHeaderTranslatorFactory
 | 
				
			||||||
where
 | 
					where
 | 
				
			||||||
  S: Service<ServiceRequest, Response = ServiceResponse<B>, Error=actix_web::Error>,
 | 
					    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
 | 
				
			||||||
  S::Future: 'static,
 | 
					    S::Future: 'static,
 | 
				
			||||||
  B: 'static,
 | 
					    B: 'static,
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
    type Response = ServiceResponse<B>;
 | 
					    type Response = ServiceResponse<B>;
 | 
				
			||||||
    type Error = actix_web::Error;
 | 
					    type Error = actix_web::Error;
 | 
				
			||||||
@ -211,7 +278,7 @@ where
 | 
				
			|||||||
        credentials.token().hash(&mut s);
 | 
					        credentials.token().hash(&mut s);
 | 
				
			||||||
        s.finish()
 | 
					        s.finish()
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    if state.jwt_blacklist.contains(&jwt_hash) {
 | 
					    if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) {
 | 
				
			||||||
        return Err(ErrorUnauthorized("JWT was logged out"));
 | 
					        return Err(ErrorUnauthorized("JWT was logged out"));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    let groups = &token.claims().groups;
 | 
					    let groups = &token.claims().groups;
 | 
				
			||||||
@ -225,13 +292,11 @@ where
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig)
 | 
				
			||||||
pub fn configure_server<Backend>(
 | 
					where
 | 
				
			||||||
    cfg: &mut web::ServiceConfig,
 | 
					 | 
				
			||||||
) where
 | 
					 | 
				
			||||||
    Backend: TcpBackendHandler + BackendHandler + 'static,
 | 
					    Backend: TcpBackendHandler + BackendHandler + 'static,
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
    cfg
 | 
					    cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
 | 
				
			||||||
    .service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
 | 
					        .service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)))
 | 
				
			||||||
    .service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)));
 | 
					        .service(web::resource("/logout").route(web::get().to(get_logout::<Backend>)));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -13,11 +13,12 @@ pub enum JwtRefreshStorage {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
/// Contains the blacklisted JWT that haven't expired yet.
 | 
					/// Contains the blacklisted JWT that haven't expired yet.
 | 
				
			||||||
#[derive(Iden)]
 | 
					#[derive(Iden)]
 | 
				
			||||||
pub enum JwtBlacklist {
 | 
					pub enum JwtStorage {
 | 
				
			||||||
    Table,
 | 
					    Table,
 | 
				
			||||||
    JwtHash,
 | 
					    JwtHash,
 | 
				
			||||||
    UserId,
 | 
					    UserId,
 | 
				
			||||||
    ExpiryDate,
 | 
					    ExpiryDate,
 | 
				
			||||||
 | 
					    Blacklisted,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// This needs to be initialized after the domain tables are.
 | 
					/// This needs to be initialized after the domain tables are.
 | 
				
			||||||
@ -57,29 +58,35 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    sqlx::query(
 | 
					    sqlx::query(
 | 
				
			||||||
        &Table::create()
 | 
					        &Table::create()
 | 
				
			||||||
            .table(JwtBlacklist::Table)
 | 
					            .table(JwtStorage::Table)
 | 
				
			||||||
            .if_not_exists()
 | 
					            .if_not_exists()
 | 
				
			||||||
            .col(
 | 
					            .col(
 | 
				
			||||||
                ColumnDef::new(JwtBlacklist::JwtHash)
 | 
					                ColumnDef::new(JwtStorage::JwtHash)
 | 
				
			||||||
                    .big_integer()
 | 
					                    .big_integer()
 | 
				
			||||||
                    .not_null()
 | 
					                    .not_null()
 | 
				
			||||||
                    .primary_key(),
 | 
					                    .primary_key(),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            .col(
 | 
					            .col(
 | 
				
			||||||
                ColumnDef::new(JwtBlacklist::UserId)
 | 
					                ColumnDef::new(JwtStorage::UserId)
 | 
				
			||||||
                    .string_len(255)
 | 
					                    .string_len(255)
 | 
				
			||||||
                    .not_null(),
 | 
					                    .not_null(),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            .col(
 | 
					            .col(
 | 
				
			||||||
                ColumnDef::new(JwtBlacklist::ExpiryDate)
 | 
					                ColumnDef::new(JwtStorage::ExpiryDate)
 | 
				
			||||||
                    .date_time()
 | 
					                    .date_time()
 | 
				
			||||||
                    .not_null(),
 | 
					                    .not_null(),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            .col(
 | 
				
			||||||
 | 
					                ColumnDef::new(JwtStorage::Blacklisted)
 | 
				
			||||||
 | 
					                    .boolean()
 | 
				
			||||||
 | 
					                    .default(false)
 | 
				
			||||||
 | 
					                    .not_null(),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
            .foreign_key(
 | 
					            .foreign_key(
 | 
				
			||||||
                ForeignKey::create()
 | 
					                ForeignKey::create()
 | 
				
			||||||
                    .name("JwtBlacklistUserForeignKey")
 | 
					                    .name("JwtStorageUserForeignKey")
 | 
				
			||||||
                    .table(JwtBlacklist::Table, Users::Table)
 | 
					                    .table(JwtStorage::Table, Users::Table)
 | 
				
			||||||
                    .col(JwtBlacklist::UserId, Users::UserId)
 | 
					                    .col(JwtStorage::UserId, Users::UserId)
 | 
				
			||||||
                    .on_delete(ForeignKeyAction::Cascade)
 | 
					                    .on_delete(ForeignKeyAction::Cascade)
 | 
				
			||||||
                    .on_update(ForeignKeyAction::Cascade),
 | 
					                    .on_update(ForeignKeyAction::Cascade),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
				
			|||||||
@ -4,19 +4,21 @@ use async_trait::async_trait;
 | 
				
			|||||||
use futures_util::StreamExt;
 | 
					use futures_util::StreamExt;
 | 
				
			||||||
use sea_query::{Expr, Iden, Query, SimpleExpr};
 | 
					use sea_query::{Expr, Iden, Query, SimpleExpr};
 | 
				
			||||||
use sqlx::Row;
 | 
					use sqlx::Row;
 | 
				
			||||||
 | 
					use std::collections::hash_map::DefaultHasher;
 | 
				
			||||||
use std::collections::HashSet;
 | 
					use std::collections::HashSet;
 | 
				
			||||||
 | 
					use std::hash::{Hash, Hasher};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[async_trait]
 | 
					#[async_trait]
 | 
				
			||||||
impl TcpBackendHandler for SqlBackendHandler {
 | 
					impl TcpBackendHandler for SqlBackendHandler {
 | 
				
			||||||
    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
 | 
					    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
 | 
				
			||||||
        use sqlx::Result;
 | 
					        use sqlx::Result;
 | 
				
			||||||
        let query = Query::select()
 | 
					        let query = Query::select()
 | 
				
			||||||
            .column(JwtBlacklist::JwtHash)
 | 
					            .column(JwtStorage::JwtHash)
 | 
				
			||||||
            .from(JwtBlacklist::Table)
 | 
					            .from(JwtStorage::Table)
 | 
				
			||||||
            .to_string(DbQueryBuilder {});
 | 
					            .to_string(DbQueryBuilder {});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sqlx::query(&query)
 | 
					        sqlx::query(&query)
 | 
				
			||||||
            .map(|row: DbRow| row.get::<i64, _>(&*JwtBlacklist::JwtHash.to_string()) as u64)
 | 
					            .map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
 | 
				
			||||||
            .fetch(&self.sql_pool)
 | 
					            .fetch(&self.sql_pool)
 | 
				
			||||||
            .collect::<Vec<sqlx::Result<u64>>>()
 | 
					            .collect::<Vec<sqlx::Result<u64>>>()
 | 
				
			||||||
            .await
 | 
					            .await
 | 
				
			||||||
@ -60,8 +62,6 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async fn check_token(&self, token: &str, user: &str) -> Result<bool> {
 | 
					    async fn check_token(&self, token: &str, user: &str) -> Result<bool> {
 | 
				
			||||||
        use std::collections::hash_map::DefaultHasher;
 | 
					 | 
				
			||||||
        use std::hash::{Hash, Hasher};
 | 
					 | 
				
			||||||
        let refresh_token_hash = {
 | 
					        let refresh_token_hash = {
 | 
				
			||||||
            let mut s = DefaultHasher::new();
 | 
					            let mut s = DefaultHasher::new();
 | 
				
			||||||
            token.hash(&mut s);
 | 
					            token.hash(&mut s);
 | 
				
			||||||
@ -78,4 +78,40 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
				
			|||||||
            .await?
 | 
					            .await?
 | 
				
			||||||
            .is_some())
 | 
					            .is_some())
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> {
 | 
				
			||||||
 | 
					        use sqlx::Result;
 | 
				
			||||||
 | 
					        let query = Query::select()
 | 
				
			||||||
 | 
					            .column(JwtStorage::JwtHash)
 | 
				
			||||||
 | 
					            .from(JwtStorage::Table)
 | 
				
			||||||
 | 
					            .and_where(Expr::col(JwtStorage::UserId).eq(user))
 | 
				
			||||||
 | 
					            .and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
 | 
				
			||||||
 | 
					            .to_string(DbQueryBuilder {});
 | 
				
			||||||
 | 
					        let result = sqlx::query(&query)
 | 
				
			||||||
 | 
					            .map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
 | 
				
			||||||
 | 
					            .fetch(&self.sql_pool)
 | 
				
			||||||
 | 
					            .collect::<Vec<sqlx::Result<u64>>>()
 | 
				
			||||||
 | 
					            .await
 | 
				
			||||||
 | 
					            .into_iter()
 | 
				
			||||||
 | 
					            .collect::<Result<HashSet<u64>>>();
 | 
				
			||||||
 | 
					        let query = Query::update()
 | 
				
			||||||
 | 
					            .table(JwtStorage::Table)
 | 
				
			||||||
 | 
					            .values(vec![(JwtStorage::Blacklisted, true.into())])
 | 
				
			||||||
 | 
					            .and_where(Expr::col(JwtStorage::UserId).eq(user))
 | 
				
			||||||
 | 
					            .to_string(DbQueryBuilder {});
 | 
				
			||||||
 | 
					        sqlx::query(&query).execute(&self.sql_pool).await?;
 | 
				
			||||||
 | 
					        Ok(result?)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    async fn delete_refresh_token(&self, token: &str) -> DomainResult<()> {
 | 
				
			||||||
 | 
					        let refresh_token_hash = {
 | 
				
			||||||
 | 
					            let mut s = DefaultHasher::new();
 | 
				
			||||||
 | 
					            token.hash(&mut s);
 | 
				
			||||||
 | 
					            s.finish()
 | 
				
			||||||
 | 
					        };
 | 
				
			||||||
 | 
					        let query = Query::delete()
 | 
				
			||||||
 | 
					            .from_table(JwtRefreshStorage::Table)
 | 
				
			||||||
 | 
					            .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
 | 
				
			||||||
 | 
					            .to_string(DbQueryBuilder {});
 | 
				
			||||||
 | 
					        sqlx::query(&query).execute(&self.sql_pool).await?;
 | 
				
			||||||
 | 
					        Ok(())
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,10 +1,12 @@
 | 
				
			|||||||
use crate::{
 | 
					use crate::{
 | 
				
			||||||
    domain::handler::*,
 | 
					    domain::handler::*,
 | 
				
			||||||
    infra::{tcp_server::{AppState, error_to_http_response}, tcp_backend_handler::*},
 | 
					    infra::{
 | 
				
			||||||
 | 
					        tcp_backend_handler::*,
 | 
				
			||||||
 | 
					        tcp_server::{error_to_http_response, AppState},
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
use actix_web::{web, HttpResponse};
 | 
					use actix_web::{web, HttpResponse};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
 | 
					fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
 | 
				
			||||||
    ApiResult::Right(error_to_http_response(error))
 | 
					    ApiResult::Right(error_to_http_response(error))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -54,12 +56,15 @@ mod tests {
 | 
				
			|||||||
    use super::*;
 | 
					    use super::*;
 | 
				
			||||||
    use hmac::{Hmac, NewMac};
 | 
					    use hmac::{Hmac, NewMac};
 | 
				
			||||||
    use std::collections::HashSet;
 | 
					    use std::collections::HashSet;
 | 
				
			||||||
 | 
					    use std::sync::RwLock;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fn get_data(handler: MockTestTcpBackendHandler) -> web::Data<AppState<MockTestTcpBackendHandler>> {
 | 
					    fn get_data(
 | 
				
			||||||
 | 
					        handler: MockTestTcpBackendHandler,
 | 
				
			||||||
 | 
					    ) -> web::Data<AppState<MockTestTcpBackendHandler>> {
 | 
				
			||||||
        let app_state = AppState::<MockTestTcpBackendHandler> {
 | 
					        let app_state = AppState::<MockTestTcpBackendHandler> {
 | 
				
			||||||
            backend_handler: handler,
 | 
					            backend_handler: handler,
 | 
				
			||||||
            jwt_key: Hmac::new_varkey(b"jwt_secret").unwrap(),
 | 
					            jwt_key: Hmac::new_varkey(b"jwt_secret").unwrap(),
 | 
				
			||||||
            jwt_blacklist: HashSet::new(),
 | 
					            jwt_blacklist: RwLock::new(HashSet::new()),
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
        web::Data::<AppState<MockTestTcpBackendHandler>>::new(app_state)
 | 
					        web::Data::<AppState<MockTestTcpBackendHandler>>::new(app_state)
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,5 @@
 | 
				
			|||||||
use std::collections::HashSet;
 | 
					 | 
				
			||||||
use async_trait::async_trait;
 | 
					use async_trait::async_trait;
 | 
				
			||||||
 | 
					use std::collections::HashSet;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub type DomainError = crate::domain::error::Error;
 | 
					pub type DomainError = crate::domain::error::Error;
 | 
				
			||||||
pub type DomainResult<T> = crate::domain::error::Result<T>;
 | 
					pub type DomainResult<T> = crate::domain::error::Result<T>;
 | 
				
			||||||
@ -9,6 +9,8 @@ pub trait TcpBackendHandler {
 | 
				
			|||||||
    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
					    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
				
			||||||
    async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
					    async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
				
			||||||
    async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>;
 | 
					    async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>;
 | 
				
			||||||
 | 
					    async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
 | 
				
			||||||
 | 
					    async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[cfg(test)]
 | 
					#[cfg(test)]
 | 
				
			||||||
@ -31,5 +33,7 @@ mockall::mock! {
 | 
				
			|||||||
        async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
					        async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
				
			||||||
        async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
					        async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
				
			||||||
        async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>;
 | 
					        async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>;
 | 
				
			||||||
 | 
					        async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
 | 
				
			||||||
 | 
					        async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,6 @@
 | 
				
			|||||||
use crate::{
 | 
					use crate::{
 | 
				
			||||||
    domain::handler::*,
 | 
					    domain::handler::*,
 | 
				
			||||||
    infra::{auth_service, tcp_api, configuration::Configuration, tcp_backend_handler::*},
 | 
					    infra::{auth_service, configuration::Configuration, tcp_api, tcp_backend_handler::*},
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
use actix_files::{Files, NamedFile};
 | 
					use actix_files::{Files, NamedFile};
 | 
				
			||||||
use actix_http::HttpServiceBuilder;
 | 
					use actix_http::HttpServiceBuilder;
 | 
				
			||||||
@ -13,6 +13,7 @@ use hmac::{Hmac, NewMac};
 | 
				
			|||||||
use sha2::Sha512;
 | 
					use sha2::Sha512;
 | 
				
			||||||
use std::collections::HashSet;
 | 
					use std::collections::HashSet;
 | 
				
			||||||
use std::path::PathBuf;
 | 
					use std::path::PathBuf;
 | 
				
			||||||
 | 
					use std::sync::RwLock;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async fn index(req: HttpRequest) -> actix_web::Result<NamedFile> {
 | 
					async fn index(req: HttpRequest) -> actix_web::Result<NamedFile> {
 | 
				
			||||||
    let mut path = PathBuf::new();
 | 
					    let mut path = PathBuf::new();
 | 
				
			||||||
@ -41,7 +42,7 @@ fn http_config<Backend>(
 | 
				
			|||||||
    cfg.data(AppState::<Backend> {
 | 
					    cfg.data(AppState::<Backend> {
 | 
				
			||||||
        backend_handler,
 | 
					        backend_handler,
 | 
				
			||||||
        jwt_key: Hmac::new_varkey(&jwt_secret.as_bytes()).unwrap(),
 | 
					        jwt_key: Hmac::new_varkey(&jwt_secret.as_bytes()).unwrap(),
 | 
				
			||||||
        jwt_blacklist,
 | 
					        jwt_blacklist: RwLock::new(jwt_blacklist),
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    // Serve index.html and main.js, and default to index.html.
 | 
					    // Serve index.html and main.js, and default to index.html.
 | 
				
			||||||
    .route(
 | 
					    .route(
 | 
				
			||||||
@ -70,7 +71,7 @@ where
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
    pub backend_handler: Backend,
 | 
					    pub backend_handler: Backend,
 | 
				
			||||||
    pub jwt_key: Hmac<Sha512>,
 | 
					    pub jwt_key: Hmac<Sha512>,
 | 
				
			||||||
    pub jwt_blacklist: HashSet<u64>,
 | 
					    pub jwt_blacklist: RwLock<HashSet<u64>>,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub async fn build_tcp_server<Backend>(
 | 
					pub async fn build_tcp_server<Backend>(
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user