server: Add methods to get/set a password reset token

This commit is contained in:
Valentin Tolmer 2021-11-21 11:33:11 +01:00 committed by nitnelave
parent 88732556c1
commit e1503743b5
2 changed files with 92 additions and 34 deletions

View File

@ -6,10 +6,19 @@ use sea_query::{Expr, Iden, Query, SimpleExpr};
use sqlx::Row; use sqlx::Row;
use std::collections::HashSet; use std::collections::HashSet;
fn gen_random_string(len: usize) -> String {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
let mut rng = SmallRng::from_entropy();
std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(len)
.collect()
}
#[async_trait] #[async_trait]
impl TcpBackendHandler for SqlBackendHandler { impl TcpBackendHandler for SqlBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> { async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
use sqlx::Result;
let query = Query::select() let query = Query::select()
.column(JwtStorage::JwtHash) .column(JwtStorage::JwtHash)
.from(JwtStorage::Table) .from(JwtStorage::Table)
@ -21,21 +30,15 @@ impl TcpBackendHandler for SqlBackendHandler {
.collect::<Vec<sqlx::Result<u64>>>() .collect::<Vec<sqlx::Result<u64>>>()
.await .await
.into_iter() .into_iter()
.collect::<Result<HashSet<u64>>>() .collect::<sqlx::Result<HashSet<u64>>>()
.map_err(|e| anyhow::anyhow!(e)) .map_err(|e| anyhow::anyhow!(e))
} }
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> { async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>? // TODO: Initialize the rng only once. Maybe Arc<Cell>?
let mut rng = SmallRng::from_entropy(); let refresh_token = gen_random_string(100);
let refresh_token: String = std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(100)
.collect();
let refresh_token_hash = { let refresh_token_hash = {
let mut s = DefaultHasher::new(); let mut s = DefaultHasher::new();
refresh_token.hash(&mut s); refresh_token.hash(&mut s);
@ -71,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler {
.await? .await?
.is_some()) .is_some())
} }
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> { async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>> {
use sqlx::Result; use sqlx::Result;
let query = Query::select() let query = Query::select()
.column(JwtStorage::JwtHash) .column(JwtStorage::JwtHash)
@ -94,7 +97,7 @@ impl TcpBackendHandler for SqlBackendHandler {
sqlx::query(&query).execute(&self.sql_pool).await?; sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(result?) Ok(result?)
} }
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> { async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
let query = Query::delete() let query = Query::delete()
.from_table(JwtRefreshStorage::Table) .from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash)) .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
@ -102,4 +105,50 @@ impl TcpBackendHandler for SqlBackendHandler {
sqlx::query(&query).execute(&self.sql_pool).await?; sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(()) Ok(())
} }
async fn start_password_reset(&self, user: &str) -> Result<Option<String>> {
let query = Query::select()
.column(Users::UserId)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user))
.to_string(DbQueryBuilder {});
// Check that the user exists.
if sqlx::query(&query).fetch_one(&self.sql_pool).await.is_err() {
return Ok(None);
}
let token = gen_random_string(100);
let duration = chrono::Duration::minutes(10);
let query = Query::insert()
.into_table(PasswordResetTokens::Table)
.columns(vec![
PasswordResetTokens::Token,
PasswordResetTokens::UserId,
PasswordResetTokens::ExpiryDate,
])
.values_panic(vec![
token.clone().into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(Some(token))
}
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String> {
let query = Query::select()
.column(PasswordResetTokens::UserId)
.from(PasswordResetTokens::Table)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
.and_where(
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
)
.to_string(DbQueryBuilder {});
let (user_id,) = sqlx::query_as(&query).fetch_one(&self.sql_pool).await?;
Ok(user_id)
}
} }

View File

@ -1,15 +1,22 @@
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashSet; use std::collections::HashSet;
pub type DomainResult<T> = crate::domain::error::Result<T>; use crate::domain::error::Result;
#[async_trait] #[async_trait]
pub trait TcpBackendHandler { pub trait TcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>; async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>; async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>; async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
/// Request a token to reset a user's password.
/// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`.
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
/// Get the user ID associated with a password reset token.
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
} }
#[cfg(test)] #[cfg(test)]
@ -22,30 +29,32 @@ mockall::mock! {
} }
#[async_trait] #[async_trait]
impl LoginHandler for TestTcpBackendHandler { impl LoginHandler for TestTcpBackendHandler {
async fn bind(&self, request: BindRequest) -> DomainResult<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
} }
#[async_trait] #[async_trait]
impl BackendHandler for TestTcpBackendHandler { impl BackendHandler for TestTcpBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> DomainResult<Vec<User>>; async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self) -> DomainResult<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> DomainResult<User>; async fn get_user_details(&self, user_id: &str) -> Result<User>;
async fn get_group_details(&self, group_id: GroupId) -> DomainResult<GroupIdAndName>; async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
async fn get_user_groups(&self, user: &str) -> DomainResult<HashSet<GroupIdAndName>>; async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn update_user(&self, request: UpdateUserRequest) -> DomainResult<()>; async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
async fn update_group(&self, request: UpdateGroupRequest) -> DomainResult<()>; async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> DomainResult<()>; async fn delete_user(&self, user_id: &str) -> Result<()>;
async fn create_group(&self, group_name: &str) -> DomainResult<GroupId>; async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn delete_group(&self, group_id: GroupId) -> DomainResult<()>; async fn delete_group(&self, group_id: GroupId) -> Result<()>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>; async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>; async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
} }
#[async_trait] #[async_trait]
impl TcpBackendHandler for TestTcpBackendHandler { impl TcpBackendHandler for TestTcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>; async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>; async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>; async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>; async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
} }
} }