mirror of
				https://github.com/nitnelave/lldap.git
				synced 2023-04-12 14:25:13 +00:00 
			
		
		
		
	Hash refesh tokens earlier
This commit is contained in:
		
							parent
							
								
									10404abbb0
								
							
						
					
					
						commit
						28b7be0500
					
				@ -45,12 +45,19 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<String>) -> Sign
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
fn get_refresh_token_from_cookie(
 | 
					fn get_refresh_token_from_cookie(
 | 
				
			||||||
    request: HttpRequest,
 | 
					    request: HttpRequest,
 | 
				
			||||||
) -> std::result::Result<(String, String), HttpResponse> {
 | 
					) -> std::result::Result<(u64, String), HttpResponse> {
 | 
				
			||||||
    match request.cookie("refresh_token") {
 | 
					    match request.cookie("refresh_token") {
 | 
				
			||||||
        None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
 | 
					        None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
 | 
				
			||||||
        Some(t) => match t.value().split_once("+") {
 | 
					        Some(t) => match t.value().split_once("+") {
 | 
				
			||||||
            None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
 | 
					            None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
 | 
				
			||||||
            Some((t, u)) => Ok((t.to_string(), u.to_string())),
 | 
					            Some((token, u)) => {
 | 
				
			||||||
 | 
					                let refresh_token_hash = {
 | 
				
			||||||
 | 
					                    let mut s = DefaultHasher::new();
 | 
				
			||||||
 | 
					                    token.hash(&mut s);
 | 
				
			||||||
 | 
					                    s.finish()
 | 
				
			||||||
 | 
					                };
 | 
				
			||||||
 | 
					                Ok((refresh_token_hash, u.to_string()))
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -64,13 +71,13 @@ where
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
    let backend_handler = &data.backend_handler;
 | 
					    let backend_handler = &data.backend_handler;
 | 
				
			||||||
    let jwt_key = &data.jwt_key;
 | 
					    let jwt_key = &data.jwt_key;
 | 
				
			||||||
    let (refresh_token, user) = match get_refresh_token_from_cookie(request) {
 | 
					    let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
 | 
				
			||||||
        Ok(t) => t,
 | 
					        Ok(t) => t,
 | 
				
			||||||
        Err(http_response) => return http_response,
 | 
					        Err(http_response) => return http_response,
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    let res_found = data
 | 
					    let res_found = data
 | 
				
			||||||
        .backend_handler
 | 
					        .backend_handler
 | 
				
			||||||
        .check_token(&refresh_token, &user)
 | 
					        .check_token(refresh_token_hash, &user)
 | 
				
			||||||
        .await;
 | 
					        .await;
 | 
				
			||||||
    // Async closures are not supported yet.
 | 
					    // Async closures are not supported yet.
 | 
				
			||||||
    match res_found {
 | 
					    match res_found {
 | 
				
			||||||
@ -108,13 +115,13 @@ async fn get_logout<Backend>(
 | 
				
			|||||||
where
 | 
					where
 | 
				
			||||||
    Backend: TcpBackendHandler + BackendHandler + 'static,
 | 
					    Backend: TcpBackendHandler + BackendHandler + 'static,
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
    let (refresh_token, user) = match get_refresh_token_from_cookie(request) {
 | 
					    let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
 | 
				
			||||||
        Ok(t) => t,
 | 
					        Ok(t) => t,
 | 
				
			||||||
        Err(http_response) => return http_response,
 | 
					        Err(http_response) => return http_response,
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    if let Err(response) = data
 | 
					    if let Err(response) = data
 | 
				
			||||||
        .backend_handler
 | 
					        .backend_handler
 | 
				
			||||||
        .delete_refresh_token(&refresh_token)
 | 
					        .delete_refresh_token(refresh_token_hash)
 | 
				
			||||||
        .map_err(error_to_http_response)
 | 
					        .map_err(error_to_http_response)
 | 
				
			||||||
        .await
 | 
					        .await
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
@ -131,7 +138,7 @@ where
 | 
				
			|||||||
            for jwt in new_blacklisted_jwts {
 | 
					            for jwt in new_blacklisted_jwts {
 | 
				
			||||||
                jwt_blacklist.insert(jwt);
 | 
					                jwt_blacklist.insert(jwt);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        },
 | 
					        }
 | 
				
			||||||
        Err(response) => return response,
 | 
					        Err(response) => return response,
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    HttpResponse::Ok()
 | 
					    HttpResponse::Ok()
 | 
				
			||||||
 | 
				
			|||||||
@ -4,9 +4,7 @@ use async_trait::async_trait;
 | 
				
			|||||||
use futures_util::StreamExt;
 | 
					use futures_util::StreamExt;
 | 
				
			||||||
use sea_query::{Expr, Iden, Query, SimpleExpr};
 | 
					use sea_query::{Expr, Iden, Query, SimpleExpr};
 | 
				
			||||||
use sqlx::Row;
 | 
					use sqlx::Row;
 | 
				
			||||||
use std::collections::hash_map::DefaultHasher;
 | 
					 | 
				
			||||||
use std::collections::HashSet;
 | 
					use std::collections::HashSet;
 | 
				
			||||||
use std::hash::{Hash, Hasher};
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[async_trait]
 | 
					#[async_trait]
 | 
				
			||||||
impl TcpBackendHandler for SqlBackendHandler {
 | 
					impl TcpBackendHandler for SqlBackendHandler {
 | 
				
			||||||
@ -61,12 +59,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
				
			|||||||
        Ok((refresh_token, duration))
 | 
					        Ok((refresh_token, duration))
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async fn check_token(&self, token: &str, user: &str) -> Result<bool> {
 | 
					    async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
 | 
				
			||||||
        let refresh_token_hash = {
 | 
					 | 
				
			||||||
            let mut s = DefaultHasher::new();
 | 
					 | 
				
			||||||
            token.hash(&mut s);
 | 
					 | 
				
			||||||
            s.finish()
 | 
					 | 
				
			||||||
        };
 | 
					 | 
				
			||||||
        let query = Query::select()
 | 
					        let query = Query::select()
 | 
				
			||||||
            .expr(SimpleExpr::Value(1.into()))
 | 
					            .expr(SimpleExpr::Value(1.into()))
 | 
				
			||||||
            .from(JwtRefreshStorage::Table)
 | 
					            .from(JwtRefreshStorage::Table)
 | 
				
			||||||
@ -101,12 +94,7 @@ impl TcpBackendHandler for SqlBackendHandler {
 | 
				
			|||||||
        sqlx::query(&query).execute(&self.sql_pool).await?;
 | 
					        sqlx::query(&query).execute(&self.sql_pool).await?;
 | 
				
			||||||
        Ok(result?)
 | 
					        Ok(result?)
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    async fn delete_refresh_token(&self, token: &str) -> DomainResult<()> {
 | 
					    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
 | 
				
			||||||
        let refresh_token_hash = {
 | 
					 | 
				
			||||||
            let mut s = DefaultHasher::new();
 | 
					 | 
				
			||||||
            token.hash(&mut s);
 | 
					 | 
				
			||||||
            s.finish()
 | 
					 | 
				
			||||||
        };
 | 
					 | 
				
			||||||
        let query = Query::delete()
 | 
					        let query = Query::delete()
 | 
				
			||||||
            .from_table(JwtRefreshStorage::Table)
 | 
					            .from_table(JwtRefreshStorage::Table)
 | 
				
			||||||
            .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
 | 
					            .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
 | 
				
			||||||
 | 
				
			|||||||
@ -8,9 +8,9 @@ pub type DomainResult<T> = crate::domain::error::Result<T>;
 | 
				
			|||||||
pub trait TcpBackendHandler {
 | 
					pub trait TcpBackendHandler {
 | 
				
			||||||
    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
					    async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
				
			||||||
    async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
					    async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
				
			||||||
    async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>;
 | 
					    async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
 | 
				
			||||||
    async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
 | 
					    async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
 | 
				
			||||||
    async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>;
 | 
					    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[cfg(test)]
 | 
					#[cfg(test)]
 | 
				
			||||||
@ -32,8 +32,8 @@ mockall::mock! {
 | 
				
			|||||||
    impl TcpBackendHandler for TestTcpBackendHandler {
 | 
					    impl TcpBackendHandler for TestTcpBackendHandler {
 | 
				
			||||||
        async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
					        async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
 | 
				
			||||||
        async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
					        async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
 | 
				
			||||||
        async fn check_token(&self, token: &str, user: &str) -> DomainResult<bool>;
 | 
					        async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
 | 
				
			||||||
        async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
 | 
					        async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
 | 
				
			||||||
        async fn delete_refresh_token(&self, token: &str) -> DomainResult<()>;
 | 
					        async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user