From d5cb53ae8a8646182c9383bac2de2038aaeaf34b Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Thu, 20 May 2021 17:40:30 +0200 Subject: [PATCH] Implement refresh tokens --- Cargo.toml | 1 + model/src/lib.rs | 1 + src/domain/handler.rs | 91 +++++++++++++++++++++--- src/domain/sql_tables.rs | 7 +- src/infra/jwt_sql_tables.rs | 92 +++++++++++++++++++++++++ src/infra/mod.rs | 1 + src/infra/tcp_server.rs | 134 +++++++++++++++++++++++++++++++----- src/main.rs | 3 +- 8 files changed, 301 insertions(+), 29 deletions(-) create mode 100644 src/infra/jwt_sql_tables.rs diff --git a/Cargo.toml b/Cargo.toml index 0619f50..275a679 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ tracing = "*" tracing-actix-web = "0.3.0-beta.2" tracing-log = "*" tracing-subscriber = "*" +rand = { version = "0.8", features = ["small_rng", "getrandom"] } [dependencies.sqlx] version = "0.5" diff --git a/model/src/lib.rs b/model/src/lib.rs index bc45991..eed7bac 100644 --- a/model/src/lib.rs +++ b/model/src/lib.rs @@ -55,6 +55,7 @@ pub struct Group { #[derive(Clone, Serialize, Deserialize)] pub struct JWTClaims { pub exp: DateTime, + pub iat: DateTime, pub user: String, pub groups: HashSet, } diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 0d5802b..b36d882 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -1,12 +1,13 @@ use super::sql_tables::*; use crate::domain::{error::*, sql_tables::Pool}; use crate::infra::configuration::Configuration; +use crate::infra::jwt_sql_tables::*; use async_trait::async_trait; use futures_util::StreamExt; use futures_util::TryStreamExt; use log::*; use sea_query::Iden; -use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder}; +use sea_query::{Expr, Order, Query, SimpleExpr}; use sqlx::Row; use std::collections::HashSet; @@ -72,7 +73,7 @@ impl BackendHandler for SqlBackendHandler { .column(Users::Password) .from(Users::Table) .and_where(Expr::col(Users::UserId).eq(request.name.as_str())) - .to_string(SqliteQueryBuilder); + .to_string(DbQueryBuilder {}); if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { if passwords_match( &request.password, @@ -109,7 +110,7 @@ impl BackendHandler for SqlBackendHandler { } } - query_builder.to_string(SqliteQueryBuilder) + query_builder.to_string(DbQueryBuilder {}) }; let results = sqlx::query_as::<_, User>(&query) @@ -132,7 +133,7 @@ impl BackendHandler for SqlBackendHandler { ) .order_by(Groups::DisplayName, Order::Asc) .order_by(Memberships::UserId, Order::Asc) - .to_string(SqliteQueryBuilder); + .to_string(DbQueryBuilder {}); let mut results = sqlx::query(&query).fetch(&self.sql_pool); let mut groups = Vec::new(); @@ -178,7 +179,7 @@ impl BackendHandler for SqlBackendHandler { .equals(Memberships::Table, Memberships::GroupId), ) .and_where(Expr::col(Memberships::UserId).eq(user)) - .to_string(SqliteQueryBuilder); + .to_string(DbQueryBuilder {}); sqlx::query(&query) // Extract the group id from the row. @@ -196,6 +197,80 @@ impl BackendHandler for SqlBackendHandler { } } +#[async_trait] +impl crate::infra::tcp_server::TcpBackendHandler for SqlBackendHandler { + async fn get_jwt_blacklist(&self) -> anyhow::Result> { + use sqlx::Result; + let query = Query::select() + .column(JwtBlacklist::JwtHash) + .from(JwtBlacklist::Table) + .to_string(DbQueryBuilder {}); + + sqlx::query(&query) + .map(|row: DbRow| row.get::(&*JwtBlacklist::JwtHash.to_string()) as u64) + .fetch(&self.sql_pool) + .collect::>>() + .await + .into_iter() + .collect::>>() + .map_err(|e| anyhow::anyhow!(e)) + } + + async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { + use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + // TODO: Initialize the rng only once. Maybe Arc? + let mut rng = SmallRng::from_entropy(); + let refresh_token: String = std::iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .map(char::from) + .take(100) + .collect(); + let refresh_token_hash = { + let mut s = DefaultHasher::new(); + refresh_token.hash(&mut s); + s.finish() + }; + let duration = chrono::Duration::days(30); + let query = Query::insert() + .into_table(JwtRefreshStorage::Table) + .columns(vec![ + JwtRefreshStorage::RefreshTokenHash, + JwtRefreshStorage::UserId, + JwtRefreshStorage::ExpiryDate, + ]) + .values_panic(vec![ + (refresh_token_hash as i64).into(), + user.into(), + (chrono::Utc::now() + duration).naive_utc().into(), + ]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(&self.sql_pool).await?; + Ok((refresh_token, duration)) + } + + async fn check_token(&self, token: &str, user: &str) -> Result { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let refresh_token_hash = { + let mut s = DefaultHasher::new(); + token.hash(&mut s); + s.finish() + }; + let query = Query::select() + .expr(SimpleExpr::Value(1.into())) + .from(JwtRefreshStorage::Table) + .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) + .and_where(Expr::col(JwtRefreshStorage::UserId).eq(user)) + .to_string(DbQueryBuilder {}); + Ok(sqlx::query(&query) + .fetch_optional(&self.sql_pool) + .await? + .is_some()) + } +} + #[cfg(test)] mockall::mock! { pub TestBackendHandler{} @@ -247,7 +322,7 @@ mod tests { chrono::NaiveDateTime::from_timestamp(0, 0).into(), pass.into(), ]) - .to_string(SqliteQueryBuilder); + .to_string(DbQueryBuilder {}); sqlx::query(&query).execute(sql_pool).await.unwrap(); } @@ -256,7 +331,7 @@ mod tests { .into_table(Groups::Table) .columns(vec![Groups::GroupId, Groups::DisplayName]) .values_panic(vec![id.into(), name.into()]) - .to_string(SqliteQueryBuilder); + .to_string(DbQueryBuilder {}); sqlx::query(&query).execute(sql_pool).await.unwrap(); } @@ -265,7 +340,7 @@ mod tests { .into_table(Memberships::Table) .columns(vec![Memberships::UserId, Memberships::GroupId]) .values_panic(vec![user_id.into(), group_id.into()]) - .to_string(SqliteQueryBuilder); + .to_string(DbQueryBuilder {}); sqlx::query(&query).execute(sql_pool).await.unwrap(); } diff --git a/src/domain/sql_tables.rs b/src/domain/sql_tables.rs index cc74834..d008fc7 100644 --- a/src/domain/sql_tables.rs +++ b/src/domain/sql_tables.rs @@ -3,6 +3,7 @@ use sea_query::*; pub type Pool = sqlx::sqlite::SqlitePool; pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions; pub type DbRow = sqlx::sqlite::SqliteRow; +pub type DbQueryBuilder = SqliteQueryBuilder; #[derive(Iden)] pub enum Users { @@ -60,7 +61,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { .col(ColumnDef::new(Users::Password).string_len(255).not_null()) .col(ColumnDef::new(Users::TotpSecret).string_len(64)) .col(ColumnDef::new(Users::MfaType).string_len(64)) - .to_string(SqliteQueryBuilder), + .to_string(DbQueryBuilder {}), ) .execute(pool) .await?; @@ -79,7 +80,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { .string_len(255) .not_null(), ) - .to_string(SqliteQueryBuilder), + .to_string(DbQueryBuilder {}), ) .execute(pool) .await?; @@ -109,7 +110,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ) - .to_string(SqliteQueryBuilder), + .to_string(DbQueryBuilder {}), ) .execute(pool) .await?; diff --git a/src/infra/jwt_sql_tables.rs b/src/infra/jwt_sql_tables.rs new file mode 100644 index 0000000..34137b6 --- /dev/null +++ b/src/infra/jwt_sql_tables.rs @@ -0,0 +1,92 @@ +use sea_query::*; + +pub use crate::domain::sql_tables::*; + +/// Contains the refresh tokens for a given user. +#[derive(Iden)] +pub enum JwtRefreshStorage { + Table, + RefreshTokenHash, + UserId, + ExpiryDate, +} + +/// Contains the blacklisted JWT that haven't expired yet. +#[derive(Iden)] +pub enum JwtBlacklist { + Table, + JwtHash, + UserId, + ExpiryDate, +} + +/// This needs to be initialized after the domain tables are. +pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { + sqlx::query( + &Table::create() + .table(JwtRefreshStorage::Table) + .if_not_exists() + .col( + ColumnDef::new(JwtRefreshStorage::RefreshTokenHash) + .big_integer() + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(JwtRefreshStorage::UserId) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(JwtRefreshStorage::ExpiryDate) + .date_time() + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("JwtRefreshStorageUserForeignKey") + .table(JwtRefreshStorage::Table, Users::Table) + .col(JwtRefreshStorage::UserId, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .to_string(DbQueryBuilder {}), + ) + .execute(pool) + .await?; + + sqlx::query( + &Table::create() + .table(JwtBlacklist::Table) + .if_not_exists() + .col( + ColumnDef::new(JwtBlacklist::JwtHash) + .big_integer() + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(JwtBlacklist::UserId) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(JwtBlacklist::ExpiryDate) + .date_time() + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("JwtBlacklistUserForeignKey") + .table(JwtBlacklist::Table, Users::Table) + .col(JwtBlacklist::UserId, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .to_string(DbQueryBuilder {}), + ) + .execute(pool) + .await?; + + Ok(()) +} diff --git a/src/infra/mod.rs b/src/infra/mod.rs index 177b2e5..2b65f96 100644 --- a/src/infra/mod.rs +++ b/src/infra/mod.rs @@ -1,5 +1,6 @@ pub mod cli; pub mod configuration; +pub mod jwt_sql_tables; pub mod ldap_handler; pub mod ldap_server; pub mod logging; diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index fed1598..209c33f 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -1,4 +1,4 @@ -use crate::domain::{error::Error, handler::*}; +use crate::domain::handler::*; use crate::infra::configuration::Configuration; use actix_files::{Files, NamedFile}; use actix_http::HttpServiceBuilder; @@ -12,6 +12,7 @@ use actix_web::{ }; use actix_web_httpauth::{extractors::bearer::BearerAuth, middleware::HttpAuthentication}; use anyhow::{Context, Result}; +use async_trait::async_trait; use chrono::prelude::*; use futures_util::FutureExt; use futures_util::TryFutureExt; @@ -19,13 +20,24 @@ use hmac::{Hmac, NewMac}; use jwt::{SignWithKey, VerifyWithKey}; use log::*; use sha2::Sha512; -use std::collections::HashSet; +use std::collections::{hash_map::DefaultHasher, HashSet}; +use std::hash::{Hash, Hasher}; use std::path::PathBuf; use time::ext::NumericalDuration; type Token = jwt::Token; type SignedToken = Token; +type DomainError = crate::domain::error::Error; +type DomainResult = crate::domain::error::Result; + +#[async_trait] +pub trait TcpBackendHandler: BackendHandler { + async fn get_jwt_blacklist(&self) -> Result>; + async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; + async fn check_token(&self, token: &str, user: &str) -> DomainResult; +} + async fn index(req: HttpRequest) -> actix_web::Result { let mut path = PathBuf::new(); path.push("app"); @@ -34,11 +46,11 @@ async fn index(req: HttpRequest) -> actix_web::Result { Ok(NamedFile::open(path)?) } -fn error_to_http_response(error: Error) -> ApiResult { +fn error_to_http_response(error: DomainError) -> ApiResult { ApiResult::Right( match error { - Error::AuthenticationError(_) => HttpResponse::Unauthorized(), - Error::DatabaseError(_) => HttpResponse::InternalServerError(), + DomainError::AuthenticationError(_) => HttpResponse::Unauthorized(), + DomainError::DatabaseError(_) => HttpResponse::InternalServerError(), } .body(error.to_string()), ) @@ -51,7 +63,7 @@ async fn user_list_handler( info: web::Json, ) -> ApiResult> where - Backend: BackendHandler + 'static, + Backend: TcpBackendHandler + 'static, { let req: ListUsersRequest = info.clone(); data.backend_handler @@ -64,6 +76,7 @@ where fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken { let claims = JWTClaims { exp: Utc::now() + chrono::Duration::days(1), + iat: Utc::now(), user, groups, }; @@ -74,12 +87,64 @@ fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> Sign jwt::Token::new(header, claims).sign_with_key(key).unwrap() } +async fn get_refresh( + data: web::Data>, + request: HttpRequest, +) -> ApiResult +where + Backend: TcpBackendHandler + 'static, +{ + let backend_handler = &data.backend_handler; + let jwt_key = &data.jwt_key; + let (refresh_token, user) = match request.cookie("refresh_token") { + None => { + return ApiResult::Right(HttpResponse::Unauthorized().body("Missing refresh token")) + } + Some(t) => match t.value().split_once("+") { + None => { + return ApiResult::Right(HttpResponse::Unauthorized().body("Invalid refresh token")) + } + Some((t, u)) => (t.to_string(), u.to_string()), + }, + }; + let res_found = data.backend_handler.check_token(&refresh_token, &user).await; + // Async closures are not supported yet. + match res_found { + Ok(found) => { + if found { + backend_handler.get_user_groups(user.to_string()).await + } else { + Err(DomainError::AuthenticationError( + "Invalid refresh token".to_string(), + )) + } + } + Err(e) => Err(e), + } + .map(|groups| create_jwt(jwt_key, user.to_string(), groups)) + .map(|token| { + ApiResult::Right( + HttpResponse::Ok() + .cookie( + Cookie::build("token", token.as_str()) + .max_age(1.days()) + .path("/api") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .body(token.as_str().to_owned()), + ) + }) + .unwrap_or_else(error_to_http_response) +} + async fn post_authorize( data: web::Data>, request: web::Json, ) -> ApiResult where - Backend: BackendHandler + 'static, + Backend: TcpBackendHandler + 'static, { let req: BindRequest = request.clone(); data.backend_handler @@ -87,8 +152,16 @@ where // If the authentication was successful, we need to fetch the groups to create the JWT // token. .and_then(|_| data.backend_handler.get_user_groups(request.name.clone())) + .and_then(|g| async { + Ok(( + g, + data.backend_handler + .create_refresh_token(&request.name) + .await?, + )) + }) .await - .map(|groups| { + .map(|(groups, (refresh_token, max_age))| { let token = create_jwt(&data.jwt_key, request.name.clone(), groups); ApiResult::Right( HttpResponse::Ok() @@ -100,6 +173,14 @@ where .same_site(SameSite::Strict) .finish(), ) + .cookie( + Cookie::build("refresh_token", refresh_token + "+" + &request.name) + .max_age(max_age.num_days().days()) + .path("/api/authorize/refresh") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) .body(token.as_str().to_owned()), ) }) @@ -108,7 +189,7 @@ where fn api_config(cfg: &mut web::ServiceConfig) where - Backend: BackendHandler + 'static, + Backend: TcpBackendHandler + 'static, { let json_config = web::JsonConfig::default() .limit(4096) @@ -134,7 +215,7 @@ async fn token_validator( credentials: BearerAuth, ) -> Result where - Backend: BackendHandler + 'static, + Backend: TcpBackendHandler + 'static, { let state = req .app_data::>>() @@ -144,6 +225,14 @@ where if token.claims().exp.lt(&Utc::now()) { return Err(ErrorUnauthorized("Expired JWT")); } + let jwt_hash = { + let mut s = DefaultHasher::new(); + credentials.token().hash(&mut s); + s.finish() + }; + if state.jwt_blacklist.contains(&jwt_hash) { + return Err(ErrorUnauthorized("JWT was logged out")); + } let groups = &token.claims().groups; if groups.contains("lldap_admin") { debug!("Got authorized token for user {}", &token.claims().user); @@ -155,13 +244,18 @@ where } } -fn http_config(cfg: &mut web::ServiceConfig, backend_handler: Backend, jwt_secret: String) -where - Backend: BackendHandler + 'static, +fn http_config( + cfg: &mut web::ServiceConfig, + backend_handler: Backend, + jwt_secret: String, + jwt_blacklist: HashSet, +) where + Backend: TcpBackendHandler + 'static, { cfg.data(AppState:: { backend_handler, jwt_key: Hmac::new_varkey(&jwt_secret.as_bytes()).unwrap(), + jwt_blacklist, }) // Serve index.html and main.js, and default to index.html. .route( @@ -169,6 +263,7 @@ where web::get().to(index), ) .service(web::resource("/api/authorize").route(web::post().to(post_authorize::))) + .service(web::resource("/api/authorize/refresh").route(web::get().to(get_refresh::))) // API endpoint. .service( web::scope("/api") @@ -200,28 +295,33 @@ where struct AppState where - Backend: BackendHandler + 'static, + Backend: TcpBackendHandler + 'static, { pub backend_handler: Backend, pub jwt_key: Hmac, + pub jwt_blacklist: HashSet, } -pub fn build_tcp_server( +pub async fn build_tcp_server( config: &Configuration, backend_handler: Backend, server_builder: ServerBuilder, ) -> Result where - Backend: BackendHandler + 'static, + Backend: TcpBackendHandler + 'static, { let jwt_secret = config.jwt_secret.clone(); + let jwt_blacklist = backend_handler.get_jwt_blacklist().await?; server_builder .bind("http", ("0.0.0.0", config.http_port), move || { let backend_handler = backend_handler.clone(); let jwt_secret = jwt_secret.clone(); + let jwt_blacklist = jwt_blacklist.clone(); HttpServiceBuilder::new() .finish(map_config( - App::new().configure(move |cfg| http_config(cfg, backend_handler, jwt_secret)), + App::new().configure(move |cfg| { + http_config(cfg, backend_handler, jwt_secret, jwt_blacklist) + }), |_| AppConfig::default(), )) .tcp() diff --git a/src/main.rs b/src/main.rs index d1aa757..6112f9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,8 +20,9 @@ async fn run_server(config: Configuration) -> Result<()> { backend_handler.clone(), actix_server::Server::build(), )?; + infra::jwt_sql_tables::init_table(&sql_pool).await?; let server_builder = - infra::tcp_server::build_tcp_server(&config, backend_handler, server_builder)?; + infra::tcp_server::build_tcp_server(&config, backend_handler, server_builder).await?; server_builder.workers(1).run().await?; Ok(()) }