From 5a70f2ebc2a4264477001cdd2ec8242b703d589d Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Wed, 26 May 2021 08:43:31 +0200 Subject: [PATCH] Add a method to create a user --- Cargo.toml | 1 + model/src/lib.rs | 11 +++ src/domain/handler.rs | 2 + src/domain/sql_backend_handler.rs | 140 ++++++++++++++++++++++-------- src/domain/sql_tables.rs | 16 ++-- src/infra/tcp_backend_handler.rs | 1 + 6 files changed, 127 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7fa009b..c8f5bb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ actix-service = "2.0.0" actix-web = "4.0.0-beta.6" actix-web-httpauth = "0.6.0-beta.1" anyhow = "*" +rust-argon2 = "0.8" async-trait = "0.1" chrono = { version = "*", features = [ "serde" ]} clap = "3.0.0-beta.2" diff --git a/model/src/lib.rs b/model/src/lib.rs index cafcb4e..18175a2 100644 --- a/model/src/lib.rs +++ b/model/src/lib.rs @@ -46,6 +46,17 @@ impl Default for User { } } +#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)] +pub struct CreateUserRequest { + // Same fields as User, but no creation_date, and with password. + pub user_id: String, + pub email: String, + pub display_name: Option, + pub first_name: Option, + pub last_name: Option, + pub password: String, +} + #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] pub struct Group { pub display_name: String, diff --git a/src/domain/handler.rs b/src/domain/handler.rs index 1f03611..f0e65be 100644 --- a/src/domain/handler.rs +++ b/src/domain/handler.rs @@ -9,6 +9,7 @@ pub trait BackendHandler: Clone + Send { async fn bind(&self, request: BindRequest) -> Result<()>; async fn list_users(&self, request: ListUsersRequest) -> Result>; async fn list_groups(&self) -> Result>; + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn get_user_groups(&self, user: String) -> Result>; } @@ -23,6 +24,7 @@ mockall::mock! { async fn bind(&self, request: BindRequest) -> Result<()>; async fn list_users(&self, request: ListUsersRequest) -> Result>; async fn list_groups(&self) -> Result>; + async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn get_user_groups(&self, user: String) -> Result>; } } diff --git a/src/domain/sql_backend_handler.rs b/src/domain/sql_backend_handler.rs index dc1d16a..232f436 100644 --- a/src/domain/sql_backend_handler.rs +++ b/src/domain/sql_backend_handler.rs @@ -20,8 +20,31 @@ impl SqlBackendHandler { } } -fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool { - encrypted_password == clear_password +fn get_password_config(pepper: &str) -> argon2::Config { + argon2::Config { + secret: pepper.as_bytes(), + ..Default::default() + } +} + +fn hash_password(clear_password: &str, salt: &str, pepper: &str) -> String { + let config = get_password_config(pepper); + argon2::hash_encoded(clear_password.as_bytes(), salt.as_bytes(), &config) + .map_err(|e| anyhow::anyhow!("Error encoding password: {}", e)) + .unwrap() +} + +fn passwords_match(encrypted_password: &str, clear_password: &str, pepper: &str) -> bool { + argon2::verify_encoded_ext( + encrypted_password, + clear_password.as_bytes(), + pepper.as_bytes(), + /*additional_data=*/b"", + ) + .unwrap_or_else(|e| { + log::error!("Error checking password: {}", e); + false + }) } fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { @@ -57,14 +80,15 @@ impl BackendHandler for SqlBackendHandler { } } let query = Query::select() - .column(Users::Password) + .column(Users::PasswordHash) .from(Users::Table) .and_where(Expr::col(Users::UserId).eq(request.name.as_str())) .to_string(DbQueryBuilder {}); if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { if passwords_match( + &row.get::(&*Users::PasswordHash.to_string()), &request.password, - &row.get::(&*Users::Password.to_string()), + &self.config.secret_pepper, ) { return Ok(()); } else { @@ -182,6 +206,42 @@ impl BackendHandler for SqlBackendHandler { // Map the sqlx::Error into a domain::Error. .map_err(Error::DatabaseError) } + + async fn create_user(&self, request: CreateUserRequest) -> Result<()> { + use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; + // TODO: Initialize the rng only once. Maybe Arc? + let mut rng = SmallRng::from_entropy(); + let salt: String = std::iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .map(char::from) + .take(32) + .collect(); + // The salt is included in the password hash. + let password_hash = hash_password(&request.password, &salt, &self.config.secret_pepper); + let query = Query::insert() + .into_table(Users::Table) + .columns(vec![ + Users::UserId, + Users::Email, + Users::DisplayName, + Users::FirstName, + Users::LastName, + Users::CreationDate, + Users::PasswordHash, + ]) + .values_panic(vec![ + request.user_id.into(), + request.email.into(), + request.display_name.map(Into::into).unwrap_or(Value::Null), + request.first_name.map(Into::into).unwrap_or(Value::Null), + request.last_name.map(Into::into).unwrap_or(Value::Null), + chrono::Utc::now().naive_utc().into(), + password_hash.into(), + ]) + .to_string(DbQueryBuilder {}); + sqlx::query(&query).execute(&self.sql_pool).await?; + Ok(()) + } } #[cfg(test)] @@ -199,23 +259,16 @@ mod tests { sql_pool } - async fn insert_user(sql_pool: &Pool, name: &str, pass: &str) { - let query = Query::insert() - .into_table(Users::Table) - .columns(vec![ - Users::UserId, - Users::Email, - Users::CreationDate, - Users::Password, - ]) - .values_panic(vec![ - name.into(), - "bob@bob".into(), - chrono::NaiveDateTime::from_timestamp(0, 0).into(), - pass.into(), - ]) - .to_string(DbQueryBuilder {}); - sqlx::query(&query).execute(sql_pool).await.unwrap(); + async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) { + handler + .create_user(CreateUserRequest { + user_id: name.to_string(), + email: "bob@bob.bob".to_string(), + password: pass.to_string(), + ..Default::default() + }) + .await + .unwrap(); } async fn insert_group(sql_pool: &Pool, id: u32, name: &str) { @@ -254,12 +307,27 @@ mod tests { .unwrap(); } + #[test] + fn test_argon() { + let password = b"password"; + let salt = b"randomsalt"; + let pepper = b"pepper"; + let config = argon2::Config { + secret: pepper, + ..Default::default() + }; + let hash = argon2::hash_encoded(password, salt, &config).unwrap(); + let matches = argon2::verify_encoded_ext(&hash, password, pepper, b"").unwrap(); + assert!(matches); + } + #[tokio::test] async fn test_bind_user() { let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); + let handler = SqlBackendHandler::new(config, sql_pool.clone()); + insert_user(&handler, "bob", "bob00").await; + handler .bind(BindRequest { name: "bob".to_string(), @@ -286,11 +354,11 @@ mod tests { #[tokio::test] async fn test_list_users() { let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - insert_user(&sql_pool, "patrick", "pass").await; - insert_user(&sql_pool, "John", "Pa33w0rd!").await; let config = Configuration::default(); let handler = SqlBackendHandler::new(config, sql_pool); + insert_user(&handler, "bob", "bob00").await; + insert_user(&handler, "patrick", "pass").await; + insert_user(&handler, "John", "Pa33w0rd!").await; { let users = handler .list_users(ListUsersRequest { filters: None }) @@ -351,17 +419,17 @@ mod tests { #[tokio::test] async fn test_list_groups() { let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - insert_user(&sql_pool, "patrick", "pass").await; - insert_user(&sql_pool, "John", "Pa33w0rd!").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool.clone()); + insert_user(&handler, "bob", "bob00").await; + insert_user(&handler, "patrick", "pass").await; + insert_user(&handler, "John", "Pa33w0rd!").await; insert_group(&sql_pool, 1, "Best Group").await; insert_group(&sql_pool, 2, "Worst Group").await; insert_membership(&sql_pool, 1, "bob").await; insert_membership(&sql_pool, 1, "patrick").await; insert_membership(&sql_pool, 2, "patrick").await; insert_membership(&sql_pool, 2, "John").await; - let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); assert_eq!( handler.list_groups().await.unwrap(), vec![ @@ -380,16 +448,16 @@ mod tests { #[tokio::test] async fn test_get_user_groups() { let sql_pool = get_initialized_db().await; - insert_user(&sql_pool, "bob", "bob00").await; - insert_user(&sql_pool, "patrick", "pass").await; - insert_user(&sql_pool, "John", "Pa33w0rd!").await; + let config = Configuration::default(); + let handler = SqlBackendHandler::new(config, sql_pool.clone()); + insert_user(&handler, "bob", "bob00").await; + insert_user(&handler, "patrick", "pass").await; + insert_user(&handler, "John", "Pa33w0rd!").await; insert_group(&sql_pool, 1, "Group1").await; insert_group(&sql_pool, 2, "Group2").await; insert_membership(&sql_pool, 1, "bob").await; insert_membership(&sql_pool, 1, "patrick").await; insert_membership(&sql_pool, 2, "patrick").await; - let config = Configuration::default(); - let handler = SqlBackendHandler::new(config, sql_pool); let mut bob_groups = HashSet::new(); bob_groups.insert("Group1".to_string()); let mut patrick_groups = HashSet::new(); diff --git a/src/domain/sql_tables.rs b/src/domain/sql_tables.rs index d008fc7..9acd5e3 100644 --- a/src/domain/sql_tables.rs +++ b/src/domain/sql_tables.rs @@ -15,7 +15,7 @@ pub enum Users { LastName, Avatar, CreationDate, - Password, + PasswordHash, TotpSecret, MfaType, } @@ -49,16 +49,16 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { .primary_key(), ) .col(ColumnDef::new(Users::Email).string_len(255).not_null()) + .col(ColumnDef::new(Users::DisplayName).string_len(255)) + .col(ColumnDef::new(Users::FirstName).string_len(255)) + .col(ColumnDef::new(Users::LastName).string_len(255)) + .col(ColumnDef::new(Users::Avatar).binary()) + .col(ColumnDef::new(Users::CreationDate).date_time().not_null()) .col( - ColumnDef::new(Users::DisplayName) + ColumnDef::new(Users::PasswordHash) .string_len(255) .not_null(), ) - .col(ColumnDef::new(Users::FirstName).string_len(255).not_null()) - .col(ColumnDef::new(Users::LastName).string_len(255).not_null()) - .col(ColumnDef::new(Users::Avatar).binary()) - .col(ColumnDef::new(Users::CreationDate).date_time().not_null()) - .col(ColumnDef::new(Users::Password).string_len(255).not_null()) .col(ColumnDef::new(Users::TotpSecret).string_len(64)) .col(ColumnDef::new(Users::MfaType).string_len(64)) .to_string(DbQueryBuilder {}), @@ -129,7 +129,7 @@ mod tests { let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); init_table(&sql_pool).await.unwrap(); sqlx::query(r#"INSERT INTO users - (user_id, email, display_name, first_name, last_name, creation_date, password) + (user_id, email, display_name, first_name, last_name, creation_date, password_hash) VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00")"#).execute(&sql_pool).await.unwrap(); let row = sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#) diff --git a/src/infra/tcp_backend_handler.rs b/src/infra/tcp_backend_handler.rs index 5dba894..8520776 100644 --- a/src/infra/tcp_backend_handler.rs +++ b/src/infra/tcp_backend_handler.rs @@ -27,6 +27,7 @@ mockall::mock! { async fn list_users(&self, request: ListUsersRequest) -> DomainResult>; async fn list_groups(&self) -> DomainResult>; async fn get_user_groups(&self, user: String) -> DomainResult>; + async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>; } #[async_trait] impl TcpBackendHandler for TestTcpBackendHandler {