mirror of
https://github.com/nitnelave/lldap.git
synced 2023-04-12 14:25:13 +00:00
server: Add methods to get/set a password reset token
This commit is contained in:
parent
88732556c1
commit
e1503743b5
@ -6,10 +6,19 @@ use sea_query::{Expr, Iden, Query, SimpleExpr};
|
|||||||
use sqlx::Row;
|
use sqlx::Row;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
fn gen_random_string(len: usize) -> String {
|
||||||
|
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
|
||||||
|
let mut rng = SmallRng::from_entropy();
|
||||||
|
std::iter::repeat(())
|
||||||
|
.map(|()| rng.sample(Alphanumeric))
|
||||||
|
.map(char::from)
|
||||||
|
.take(len)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TcpBackendHandler for SqlBackendHandler {
|
impl TcpBackendHandler for SqlBackendHandler {
|
||||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
|
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
|
||||||
use sqlx::Result;
|
|
||||||
let query = Query::select()
|
let query = Query::select()
|
||||||
.column(JwtStorage::JwtHash)
|
.column(JwtStorage::JwtHash)
|
||||||
.from(JwtStorage::Table)
|
.from(JwtStorage::Table)
|
||||||
@ -21,21 +30,15 @@ impl TcpBackendHandler for SqlBackendHandler {
|
|||||||
.collect::<Vec<sqlx::Result<u64>>>()
|
.collect::<Vec<sqlx::Result<u64>>>()
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<Result<HashSet<u64>>>()
|
.collect::<sqlx::Result<HashSet<u64>>>()
|
||||||
.map_err(|e| anyhow::anyhow!(e))
|
.map_err(|e| anyhow::anyhow!(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
|
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
|
||||||
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
|
|
||||||
use std::collections::hash_map::DefaultHasher;
|
use std::collections::hash_map::DefaultHasher;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
|
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
|
||||||
let mut rng = SmallRng::from_entropy();
|
let refresh_token = gen_random_string(100);
|
||||||
let refresh_token: String = std::iter::repeat(())
|
|
||||||
.map(|()| rng.sample(Alphanumeric))
|
|
||||||
.map(char::from)
|
|
||||||
.take(100)
|
|
||||||
.collect();
|
|
||||||
let refresh_token_hash = {
|
let refresh_token_hash = {
|
||||||
let mut s = DefaultHasher::new();
|
let mut s = DefaultHasher::new();
|
||||||
refresh_token.hash(&mut s);
|
refresh_token.hash(&mut s);
|
||||||
@ -71,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
|||||||
.await?
|
.await?
|
||||||
.is_some())
|
.is_some())
|
||||||
}
|
}
|
||||||
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> {
|
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>> {
|
||||||
use sqlx::Result;
|
use sqlx::Result;
|
||||||
let query = Query::select()
|
let query = Query::select()
|
||||||
.column(JwtStorage::JwtHash)
|
.column(JwtStorage::JwtHash)
|
||||||
@ -94,7 +97,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
|||||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||||
Ok(result?)
|
Ok(result?)
|
||||||
}
|
}
|
||||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
|
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
|
||||||
let query = Query::delete()
|
let query = Query::delete()
|
||||||
.from_table(JwtRefreshStorage::Table)
|
.from_table(JwtRefreshStorage::Table)
|
||||||
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
|
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
|
||||||
@ -102,4 +105,50 @@ impl TcpBackendHandler for SqlBackendHandler {
|
|||||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn start_password_reset(&self, user: &str) -> Result<Option<String>> {
|
||||||
|
let query = Query::select()
|
||||||
|
.column(Users::UserId)
|
||||||
|
.from(Users::Table)
|
||||||
|
.and_where(Expr::col(Users::UserId).eq(user))
|
||||||
|
.to_string(DbQueryBuilder {});
|
||||||
|
|
||||||
|
// Check that the user exists.
|
||||||
|
if sqlx::query(&query).fetch_one(&self.sql_pool).await.is_err() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let token = gen_random_string(100);
|
||||||
|
let duration = chrono::Duration::minutes(10);
|
||||||
|
|
||||||
|
let query = Query::insert()
|
||||||
|
.into_table(PasswordResetTokens::Table)
|
||||||
|
.columns(vec![
|
||||||
|
PasswordResetTokens::Token,
|
||||||
|
PasswordResetTokens::UserId,
|
||||||
|
PasswordResetTokens::ExpiryDate,
|
||||||
|
])
|
||||||
|
.values_panic(vec![
|
||||||
|
token.clone().into(),
|
||||||
|
user.into(),
|
||||||
|
(chrono::Utc::now() + duration).naive_utc().into(),
|
||||||
|
])
|
||||||
|
.to_string(DbQueryBuilder {});
|
||||||
|
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||||
|
Ok(Some(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String> {
|
||||||
|
let query = Query::select()
|
||||||
|
.column(PasswordResetTokens::UserId)
|
||||||
|
.from(PasswordResetTokens::Table)
|
||||||
|
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
|
||||||
|
.and_where(
|
||||||
|
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
|
||||||
|
)
|
||||||
|
.to_string(DbQueryBuilder {});
|
||||||
|
|
||||||
|
let (user_id,) = sqlx::query_as(&query).fetch_one(&self.sql_pool).await?;
|
||||||
|
Ok(user_id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,22 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
pub type DomainResult<T> = crate::domain::error::Result<T>;
|
use crate::domain::error::Result;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait TcpBackendHandler {
|
pub trait TcpBackendHandler {
|
||||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
||||||
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
|
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
|
||||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
|
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
|
||||||
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
|
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
|
||||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
|
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
|
||||||
|
|
||||||
|
/// Request a token to reset a user's password.
|
||||||
|
/// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`.
|
||||||
|
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
|
||||||
|
|
||||||
|
/// Get the user ID associated with a password reset token.
|
||||||
|
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -22,30 +29,32 @@ mockall::mock! {
|
|||||||
}
|
}
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl LoginHandler for TestTcpBackendHandler {
|
impl LoginHandler for TestTcpBackendHandler {
|
||||||
async fn bind(&self, request: BindRequest) -> DomainResult<()>;
|
async fn bind(&self, request: BindRequest) -> Result<()>;
|
||||||
}
|
}
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl BackendHandler for TestTcpBackendHandler {
|
impl BackendHandler for TestTcpBackendHandler {
|
||||||
async fn list_users(&self, filters: Option<RequestFilter>) -> DomainResult<Vec<User>>;
|
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
|
||||||
async fn list_groups(&self) -> DomainResult<Vec<Group>>;
|
async fn list_groups(&self) -> Result<Vec<Group>>;
|
||||||
async fn get_user_details(&self, user_id: &str) -> DomainResult<User>;
|
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||||
async fn get_group_details(&self, group_id: GroupId) -> DomainResult<GroupIdAndName>;
|
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
|
||||||
async fn get_user_groups(&self, user: &str) -> DomainResult<HashSet<GroupIdAndName>>;
|
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||||
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
|
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||||
async fn update_user(&self, request: UpdateUserRequest) -> DomainResult<()>;
|
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
|
||||||
async fn update_group(&self, request: UpdateGroupRequest) -> DomainResult<()>;
|
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
|
||||||
async fn delete_user(&self, user_id: &str) -> DomainResult<()>;
|
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||||
async fn create_group(&self, group_name: &str) -> DomainResult<GroupId>;
|
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||||
async fn delete_group(&self, group_id: GroupId) -> DomainResult<()>;
|
async fn delete_group(&self, group_id: GroupId) -> Result<()>;
|
||||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
|
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
|
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||||
}
|
}
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TcpBackendHandler for TestTcpBackendHandler {
|
impl TcpBackendHandler for TestTcpBackendHandler {
|
||||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
||||||
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
|
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
|
||||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
|
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
|
||||||
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
|
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
|
||||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
|
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
|
||||||
|
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
|
||||||
|
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user