From e1503743b56d200a340e1cd4db0518c3a195db18 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Sun, 21 Nov 2021 11:33:11 +0100 Subject: [PATCH] server: Add methods to get/set a password reset token --- server/src/infra/sql_backend_handler.rs | 71 +++++++++++++++++++++---- server/src/infra/tcp_backend_handler.rs | 55 +++++++++++-------- 2 files changed, 92 insertions(+), 34 deletions(-) diff --git a/server/src/infra/sql_backend_handler.rs b/server/src/infra/sql_backend_handler.rs index 0f1ab58..353924b 100644 --- a/server/src/infra/sql_backend_handler.rs +++ b/server/src/infra/sql_backend_handler.rs @@ -6,10 +6,19 @@ use sea_query::{Expr, Iden, Query, SimpleExpr}; use sqlx::Row; use std::collections::HashSet; +fn gen_random_string(len: usize) -> String { + use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; + let mut rng = SmallRng::from_entropy(); + std::iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .map(char::from) + .take(len) + .collect() +} + #[async_trait] impl TcpBackendHandler for SqlBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result> { - use sqlx::Result; let query = Query::select() .column(JwtStorage::JwtHash) .from(JwtStorage::Table) @@ -21,21 +30,15 @@ impl TcpBackendHandler for SqlBackendHandler { .collect::>>() .await .into_iter() - .collect::>>() + .collect::>>() .map_err(|e| anyhow::anyhow!(e)) } async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { - use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; // TODO: Initialize the rng only once. Maybe Arc? - let mut rng = SmallRng::from_entropy(); - let refresh_token: String = std::iter::repeat(()) - .map(|()| rng.sample(Alphanumeric)) - .map(char::from) - .take(100) - .collect(); + let refresh_token = gen_random_string(100); let refresh_token_hash = { let mut s = DefaultHasher::new(); refresh_token.hash(&mut s); @@ -71,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler { .await? .is_some()) } - async fn blacklist_jwts(&self, user: &str) -> DomainResult> { + async fn blacklist_jwts(&self, user: &str) -> Result> { use sqlx::Result; let query = Query::select() .column(JwtStorage::JwtHash) @@ -94,7 +97,7 @@ impl TcpBackendHandler for SqlBackendHandler { sqlx::query(&query).execute(&self.sql_pool).await?; Ok(result?) } - async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> { + async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> { let query = Query::delete() .from_table(JwtRefreshStorage::Table) .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash)) @@ -102,4 +105,50 @@ impl TcpBackendHandler for SqlBackendHandler { sqlx::query(&query).execute(&self.sql_pool).await?; Ok(()) } + + async fn start_password_reset(&self, user: &str) -> Result> { + let query = Query::select() + .column(Users::UserId) + .from(Users::Table) + .and_where(Expr::col(Users::UserId).eq(user)) + .to_string(DbQueryBuilder {}); + + // Check that the user exists. + if sqlx::query(&query).fetch_one(&self.sql_pool).await.is_err() { + return Ok(None); + } + + let token = gen_random_string(100); + let duration = chrono::Duration::minutes(10); + + let query = Query::insert() + .into_table(PasswordResetTokens::Table) + .columns(vec![ + PasswordResetTokens::Token, + PasswordResetTokens::UserId, + PasswordResetTokens::ExpiryDate, + ]) + .values_panic(vec![ + token.clone().into(), + user.into(), + (chrono::Utc::now() + duration).naive_utc().into(), + ]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(&self.sql_pool).await?; + Ok(Some(token)) + } + + async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result { + let query = Query::select() + .column(PasswordResetTokens::UserId) + .from(PasswordResetTokens::Table) + .and_where(Expr::col(PasswordResetTokens::Token).eq(token)) + .and_where( + Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()), + ) + .to_string(DbQueryBuilder {}); + + let (user_id,) = sqlx::query_as(&query).fetch_one(&self.sql_pool).await?; + Ok(user_id) + } } diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index 194001e..c5d3e02 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -1,15 +1,22 @@ use async_trait::async_trait; use std::collections::HashSet; -pub type DomainResult = crate::domain::error::Result; +use crate::domain::error::Result; #[async_trait] pub trait TcpBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result>; - async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; - async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult; - async fn blacklist_jwts(&self, user: &str) -> DomainResult>; - async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>; + async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>; + async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result; + async fn blacklist_jwts(&self, user: &str) -> Result>; + async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; + + /// Request a token to reset a user's password. + /// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`. + async fn start_password_reset(&self, user: &str) -> Result>; + + /// Get the user ID associated with a password reset token. + async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; } #[cfg(test)] @@ -22,30 +29,32 @@ mockall::mock! { } #[async_trait] impl LoginHandler for TestTcpBackendHandler { - async fn bind(&self, request: BindRequest) -> DomainResult<()>; + async fn bind(&self, request: BindRequest) -> Result<()>; } #[async_trait] impl BackendHandler for TestTcpBackendHandler { - async fn list_users(&self, filters: Option) -> DomainResult>; - async fn list_groups(&self) -> DomainResult>; - async fn get_user_details(&self, user_id: &str) -> DomainResult; - async fn get_group_details(&self, group_id: GroupId) -> DomainResult; - async fn get_user_groups(&self, user: &str) -> DomainResult>; - async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>; - async fn update_user(&self, request: UpdateUserRequest) -> DomainResult<()>; - async fn update_group(&self, request: UpdateGroupRequest) -> DomainResult<()>; - async fn delete_user(&self, user_id: &str) -> DomainResult<()>; - async fn create_group(&self, group_name: &str) -> DomainResult; - async fn delete_group(&self, group_id: GroupId) -> DomainResult<()>; - async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>; - async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>; + async fn list_users(&self, filters: Option) -> Result>; + async fn list_groups(&self) -> Result>; + async fn get_user_details(&self, user_id: &str) -> Result; + async fn get_group_details(&self, group_id: GroupId) -> Result; + async fn get_user_groups(&self, user: &str) -> Result>; + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; + async fn update_user(&self, request: UpdateUserRequest) -> Result<()>; + async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>; + async fn delete_user(&self, user_id: &str) -> Result<()>; + async fn create_group(&self, group_name: &str) -> Result; + async fn delete_group(&self, group_id: GroupId) -> Result<()>; + async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; + async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>; } #[async_trait] impl TcpBackendHandler for TestTcpBackendHandler { async fn get_jwt_blacklist(&self) -> anyhow::Result>; - async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; - async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult; - async fn blacklist_jwts(&self, user: &str) -> DomainResult>; - async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>; + async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>; + async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result; + async fn blacklist_jwts(&self, user: &str) -> Result>; + async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>; + async fn start_password_reset(&self, user: &str) -> Result>; + async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result; } }