From bf616493c5685776e4b68759153fee627c8a7456 Mon Sep 17 00:00:00 2001
From: Valentin Tolmer <valentin@tolmer.fr>
Date: Sun, 21 Nov 2021 11:33:11 +0100
Subject: [PATCH] server: Add methods to get/set a password reset token

---
 server/src/infra/sql_backend_handler.rs | 71 +++++++++++++++++++++----
 server/src/infra/tcp_backend_handler.rs | 55 +++++++++++--------
 2 files changed, 92 insertions(+), 34 deletions(-)

diff --git a/server/src/infra/sql_backend_handler.rs b/server/src/infra/sql_backend_handler.rs
index 0f1ab58..353924b 100644
--- a/server/src/infra/sql_backend_handler.rs
+++ b/server/src/infra/sql_backend_handler.rs
@@ -6,10 +6,19 @@ use sea_query::{Expr, Iden, Query, SimpleExpr};
 use sqlx::Row;
 use std::collections::HashSet;
 
+fn gen_random_string(len: usize) -> String {
+    use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
+    let mut rng = SmallRng::from_entropy();
+    std::iter::repeat(())
+        .map(|()| rng.sample(Alphanumeric))
+        .map(char::from)
+        .take(len)
+        .collect()
+}
+
 #[async_trait]
 impl TcpBackendHandler for SqlBackendHandler {
     async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
-        use sqlx::Result;
         let query = Query::select()
             .column(JwtStorage::JwtHash)
             .from(JwtStorage::Table)
@@ -21,21 +30,15 @@ impl TcpBackendHandler for SqlBackendHandler {
             .collect::<Vec<sqlx::Result<u64>>>()
             .await
             .into_iter()
-            .collect::<Result<HashSet<u64>>>()
+            .collect::<sqlx::Result<HashSet<u64>>>()
             .map_err(|e| anyhow::anyhow!(e))
     }
 
     async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
-        use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
         use std::collections::hash_map::DefaultHasher;
         use std::hash::{Hash, Hasher};
         // TODO: Initialize the rng only once. Maybe Arc<Cell>?
-        let mut rng = SmallRng::from_entropy();
-        let refresh_token: String = std::iter::repeat(())
-            .map(|()| rng.sample(Alphanumeric))
-            .map(char::from)
-            .take(100)
-            .collect();
+        let refresh_token = gen_random_string(100);
         let refresh_token_hash = {
             let mut s = DefaultHasher::new();
             refresh_token.hash(&mut s);
@@ -71,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler {
             .await?
             .is_some())
     }
-    async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> {
+    async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>> {
         use sqlx::Result;
         let query = Query::select()
             .column(JwtStorage::JwtHash)
@@ -94,7 +97,7 @@ impl TcpBackendHandler for SqlBackendHandler {
         sqlx::query(&query).execute(&self.sql_pool).await?;
         Ok(result?)
     }
-    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
+    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
         let query = Query::delete()
             .from_table(JwtRefreshStorage::Table)
             .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
@@ -102,4 +105,50 @@ impl TcpBackendHandler for SqlBackendHandler {
         sqlx::query(&query).execute(&self.sql_pool).await?;
         Ok(())
     }
+
+    async fn start_password_reset(&self, user: &str) -> Result<Option<String>> {
+        let query = Query::select()
+            .column(Users::UserId)
+            .from(Users::Table)
+            .and_where(Expr::col(Users::UserId).eq(user))
+            .to_string(DbQueryBuilder {});
+
+        // Check that the user exists.
+        if sqlx::query(&query).fetch_one(&self.sql_pool).await.is_err() {
+            return Ok(None);
+        }
+
+        let token = gen_random_string(100);
+        let duration = chrono::Duration::minutes(10);
+
+        let query = Query::insert()
+            .into_table(PasswordResetTokens::Table)
+            .columns(vec![
+                PasswordResetTokens::Token,
+                PasswordResetTokens::UserId,
+                PasswordResetTokens::ExpiryDate,
+            ])
+            .values_panic(vec![
+                token.clone().into(),
+                user.into(),
+                (chrono::Utc::now() + duration).naive_utc().into(),
+            ])
+            .to_string(DbQueryBuilder {});
+        sqlx::query(&query).execute(&self.sql_pool).await?;
+        Ok(Some(token))
+    }
+
+    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String> {
+        let query = Query::select()
+            .column(PasswordResetTokens::UserId)
+            .from(PasswordResetTokens::Table)
+            .and_where(Expr::col(PasswordResetTokens::Token).eq(token))
+            .and_where(
+                Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
+            )
+            .to_string(DbQueryBuilder {});
+
+        let (user_id,) = sqlx::query_as(&query).fetch_one(&self.sql_pool).await?;
+        Ok(user_id)
+    }
 }
diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs
index 194001e..c5d3e02 100644
--- a/server/src/infra/tcp_backend_handler.rs
+++ b/server/src/infra/tcp_backend_handler.rs
@@ -1,15 +1,22 @@
 use async_trait::async_trait;
 use std::collections::HashSet;
 
-pub type DomainResult<T> = crate::domain::error::Result<T>;
+use crate::domain::error::Result;
 
 #[async_trait]
 pub trait TcpBackendHandler {
     async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
-    async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
-    async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
-    async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
-    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
+    async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
+    async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
+    async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
+    async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
+
+    /// Request a token to reset a user's password.
+    /// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`.
+    async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
+
+    /// Get the user ID associated with a password reset token.
+    async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
 }
 
 #[cfg(test)]
@@ -22,30 +29,32 @@ mockall::mock! {
     }
     #[async_trait]
     impl LoginHandler for TestTcpBackendHandler {
-        async fn bind(&self, request: BindRequest) -> DomainResult<()>;
+        async fn bind(&self, request: BindRequest) -> Result<()>;
     }
     #[async_trait]
     impl BackendHandler for TestTcpBackendHandler {
-        async fn list_users(&self, filters: Option<RequestFilter>) -> DomainResult<Vec<User>>;
-        async fn list_groups(&self) -> DomainResult<Vec<Group>>;
-        async fn get_user_details(&self, user_id: &str) -> DomainResult<User>;
-        async fn get_group_details(&self, group_id: GroupId) -> DomainResult<GroupIdAndName>;
-        async fn get_user_groups(&self, user: &str) -> DomainResult<HashSet<GroupIdAndName>>;
-        async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
-        async fn update_user(&self, request: UpdateUserRequest) -> DomainResult<()>;
-        async fn update_group(&self, request: UpdateGroupRequest) -> DomainResult<()>;
-        async fn delete_user(&self, user_id: &str) -> DomainResult<()>;
-        async fn create_group(&self, group_name: &str) -> DomainResult<GroupId>;
-        async fn delete_group(&self, group_id: GroupId) -> DomainResult<()>;
-        async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
-        async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
+        async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
+        async fn list_groups(&self) -> Result<Vec<Group>>;
+        async fn get_user_details(&self, user_id: &str) -> Result<User>;
+        async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
+        async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
+        async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
+        async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
+        async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
+        async fn delete_user(&self, user_id: &str) -> Result<()>;
+        async fn create_group(&self, group_name: &str) -> Result<GroupId>;
+        async fn delete_group(&self, group_id: GroupId) -> Result<()>;
+        async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
+        async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
     }
     #[async_trait]
     impl TcpBackendHandler for TestTcpBackendHandler {
         async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
-        async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
-        async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
-        async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
-        async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
+        async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
+        async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
+        async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
+        async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
+        async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
+        async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
     }
 }