Hash refesh tokens earlier

This commit is contained in:
Valentin Tolmer 2021-05-23 16:26:24 +02:00
parent 10404abbb0
commit 28b7be0500
3 changed files with 20 additions and 25 deletions

View File

@ -45,12 +45,19 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<String>) -> Sign
fn get_refresh_token_from_cookie( fn get_refresh_token_from_cookie(
request: HttpRequest, request: HttpRequest,
) -> std::result::Result<(String, String), HttpResponse> { ) -> std::result::Result<(u64, String), HttpResponse> {
match request.cookie("refresh_token") { match request.cookie("refresh_token") {
None => Err(HttpResponse::Unauthorized().body("Missing refresh token")), None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
Some(t) => match t.value().split_once("+") { Some(t) => match t.value().split_once("+") {
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")), None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
Some((t, u)) => Ok((t.to_string(), u.to_string())), Some((token, u)) => {
let refresh_token_hash = {
let mut s = DefaultHasher::new();
token.hash(&mut s);
s.finish()
};
Ok((refresh_token_hash, u.to_string()))
}
}, },
} }
} }
@ -64,13 +71,13 @@ where
{ {
let backend_handler = &data.backend_handler; let backend_handler = &data.backend_handler;
let jwt_key = &data.jwt_key; let jwt_key = &data.jwt_key;
let (refresh_token, user) = match get_refresh_token_from_cookie(request) { let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
Ok(t) => t, Ok(t) => t,
Err(http_response) => return http_response, Err(http_response) => return http_response,
}; };
let res_found = data let res_found = data
.backend_handler .backend_handler
.check_token(&refresh_token, &user) .check_token(refresh_token_hash, &user)
.await; .await;
// Async closures are not supported yet. // Async closures are not supported yet.
match res_found { match res_found {
@ -108,13 +115,13 @@ async fn get_logout<Backend>(
where where
Backend: TcpBackendHandler + BackendHandler + 'static, Backend: TcpBackendHandler + BackendHandler + 'static,
{ {
let (refresh_token, user) = match get_refresh_token_from_cookie(request) { let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
Ok(t) => t, Ok(t) => t,
Err(http_response) => return http_response, Err(http_response) => return http_response,
}; };
if let Err(response) = data if let Err(response) = data
.backend_handler .backend_handler
.delete_refresh_token(&refresh_token) .delete_refresh_token(refresh_token_hash)
.map_err(error_to_http_response) .map_err(error_to_http_response)
.await .await
{ {
@ -131,7 +138,7 @@ where
for jwt in new_blacklisted_jwts { for jwt in new_blacklisted_jwts {
jwt_blacklist.insert(jwt); jwt_blacklist.insert(jwt);
} }
}, }
Err(response) => return response, Err(response) => return response,
}; };
HttpResponse::Ok() HttpResponse::Ok()

View File

@ -4,9 +4,7 @@ use async_trait::async_trait;
use futures_util::StreamExt; use futures_util::StreamExt;
use sea_query::{Expr, Iden, Query, SimpleExpr}; use sea_query::{Expr, Iden, Query, SimpleExpr};
use sqlx::Row; use sqlx::Row;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashSet; use std::collections::HashSet;
use std::hash::{Hash, Hasher};
#[async_trait] #[async_trait]
impl TcpBackendHandler for SqlBackendHandler { impl TcpBackendHandler for SqlBackendHandler {
@ -61,12 +59,7 @@ impl TcpBackendHandler for SqlBackendHandler {
Ok((refresh_token, duration)) Ok((refresh_token, duration))
} }
async fn check_token(&self, token: &str, user: &str) -> Result<bool> { async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
let refresh_token_hash = {
let mut s = DefaultHasher::new();
token.hash(&mut s);
s.finish()
};
let query = Query::select() let query = Query::select()
.expr(SimpleExpr::Value(1.into())) .expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table) .from(JwtRefreshStorage::Table)
@ -101,12 +94,7 @@ impl TcpBackendHandler for SqlBackendHandler {
sqlx::query(&query).execute(&self.sql_pool).await?; sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(result?) Ok(result?)
} }
async fn delete_refresh_token(&self, token: &str) -> DomainResult<()> { async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
let refresh_token_hash = {
let mut s = DefaultHasher::new();
token.hash(&mut s);
s.finish()
};
let query = Query::delete() let query = Query::delete()
.from_table(JwtRefreshStorage::Table) .from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash)) .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))

View File

@ -8,9 +8,9 @@ pub type DomainResult<T> = crate::domain::error::Result<T>;
pub trait TcpBackendHandler { pub trait TcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>; async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>; async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>; async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
} }
#[cfg(test)] #[cfg(test)]
@ -32,8 +32,8 @@ mockall::mock! {
impl TcpBackendHandler for TestTcpBackendHandler { impl TcpBackendHandler for TestTcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>; async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>; async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>; async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
} }
} }