Add a method to create a user

This commit is contained in:
Valentin Tolmer 2021-05-26 08:43:31 +02:00
parent d1a42b178a
commit 5a70f2ebc2
6 changed files with 127 additions and 44 deletions

View File

@ -20,6 +20,7 @@ actix-service = "2.0.0"
actix-web = "4.0.0-beta.6" actix-web = "4.0.0-beta.6"
actix-web-httpauth = "0.6.0-beta.1" actix-web-httpauth = "0.6.0-beta.1"
anyhow = "*" anyhow = "*"
rust-argon2 = "0.8"
async-trait = "0.1" async-trait = "0.1"
chrono = { version = "*", features = [ "serde" ]} chrono = { version = "*", features = [ "serde" ]}
clap = "3.0.0-beta.2" clap = "3.0.0-beta.2"

View File

@ -46,6 +46,17 @@ impl Default for User {
} }
} }
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
pub struct CreateUserRequest {
// Same fields as User, but no creation_date, and with password.
pub user_id: String,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
pub password: String,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Group { pub struct Group {
pub display_name: String, pub display_name: String,

View File

@ -9,6 +9,7 @@ pub trait BackendHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>; async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>;
} }
@ -23,6 +24,7 @@ mockall::mock! {
async fn bind(&self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>; async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>;
} }
} }

View File

@ -20,8 +20,31 @@ impl SqlBackendHandler {
} }
} }
fn passwords_match(encrypted_password: &str, clear_password: &str) -> bool { fn get_password_config(pepper: &str) -> argon2::Config {
encrypted_password == clear_password argon2::Config {
secret: pepper.as_bytes(),
..Default::default()
}
}
fn hash_password(clear_password: &str, salt: &str, pepper: &str) -> String {
let config = get_password_config(pepper);
argon2::hash_encoded(clear_password.as_bytes(), salt.as_bytes(), &config)
.map_err(|e| anyhow::anyhow!("Error encoding password: {}", e))
.unwrap()
}
fn passwords_match(encrypted_password: &str, clear_password: &str, pepper: &str) -> bool {
argon2::verify_encoded_ext(
encrypted_password,
clear_password.as_bytes(),
pepper.as_bytes(),
/*additional_data=*/b"",
)
.unwrap_or_else(|e| {
log::error!("Error checking password: {}", e);
false
})
} }
fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
@ -57,14 +80,15 @@ impl BackendHandler for SqlBackendHandler {
} }
} }
let query = Query::select() let query = Query::select()
.column(Users::Password) .column(Users::PasswordHash)
.from(Users::Table) .from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str())) .and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
.to_string(DbQueryBuilder {}); .to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await { if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if passwords_match( if passwords_match(
&row.get::<String, _>(&*Users::PasswordHash.to_string()),
&request.password, &request.password,
&row.get::<String, _>(&*Users::Password.to_string()), &self.config.secret_pepper,
) { ) {
return Ok(()); return Ok(());
} else { } else {
@ -182,6 +206,42 @@ impl BackendHandler for SqlBackendHandler {
// Map the sqlx::Error into a domain::Error. // Map the sqlx::Error into a domain::Error.
.map_err(Error::DatabaseError) .map_err(Error::DatabaseError)
} }
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
let mut rng = SmallRng::from_entropy();
let salt: String = std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(32)
.collect();
// The salt is included in the password hash.
let password_hash = hash_password(&request.password, &salt, &self.config.secret_pepper);
let query = Query::insert()
.into_table(Users::Table)
.columns(vec![
Users::UserId,
Users::Email,
Users::DisplayName,
Users::FirstName,
Users::LastName,
Users::CreationDate,
Users::PasswordHash,
])
.values_panic(vec![
request.user_id.into(),
request.email.into(),
request.display_name.map(Into::into).unwrap_or(Value::Null),
request.first_name.map(Into::into).unwrap_or(Value::Null),
request.last_name.map(Into::into).unwrap_or(Value::Null),
chrono::Utc::now().naive_utc().into(),
password_hash.into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]
@ -199,23 +259,16 @@ mod tests {
sql_pool sql_pool
} }
async fn insert_user(sql_pool: &Pool, name: &str, pass: &str) { async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) {
let query = Query::insert() handler
.into_table(Users::Table) .create_user(CreateUserRequest {
.columns(vec![ user_id: name.to_string(),
Users::UserId, email: "bob@bob.bob".to_string(),
Users::Email, password: pass.to_string(),
Users::CreationDate, ..Default::default()
Users::Password, })
]) .await
.values_panic(vec![ .unwrap();
name.into(),
"bob@bob".into(),
chrono::NaiveDateTime::from_timestamp(0, 0).into(),
pass.into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(sql_pool).await.unwrap();
} }
async fn insert_group(sql_pool: &Pool, id: u32, name: &str) { async fn insert_group(sql_pool: &Pool, id: u32, name: &str) {
@ -254,12 +307,27 @@ mod tests {
.unwrap(); .unwrap();
} }
#[test]
fn test_argon() {
let password = b"password";
let salt = b"randomsalt";
let pepper = b"pepper";
let config = argon2::Config {
secret: pepper,
..Default::default()
};
let hash = argon2::hash_encoded(password, salt, &config).unwrap();
let matches = argon2::verify_encoded_ext(&hash, password, pepper, b"").unwrap();
assert!(matches);
}
#[tokio::test] #[tokio::test]
async fn test_bind_user() { async fn test_bind_user() {
let sql_pool = get_initialized_db().await; let sql_pool = get_initialized_db().await;
insert_user(&sql_pool, "bob", "bob00").await;
let config = Configuration::default(); let config = Configuration::default();
let handler = SqlBackendHandler::new(config, sql_pool); let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
handler handler
.bind(BindRequest { .bind(BindRequest {
name: "bob".to_string(), name: "bob".to_string(),
@ -286,11 +354,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_list_users() { async fn test_list_users() {
let sql_pool = get_initialized_db().await; let sql_pool = get_initialized_db().await;
insert_user(&sql_pool, "bob", "bob00").await;
insert_user(&sql_pool, "patrick", "pass").await;
insert_user(&sql_pool, "John", "Pa33w0rd!").await;
let config = Configuration::default(); let config = Configuration::default();
let handler = SqlBackendHandler::new(config, sql_pool); let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
{ {
let users = handler let users = handler
.list_users(ListUsersRequest { filters: None }) .list_users(ListUsersRequest { filters: None })
@ -351,17 +419,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_list_groups() { async fn test_list_groups() {
let sql_pool = get_initialized_db().await; let sql_pool = get_initialized_db().await;
insert_user(&sql_pool, "bob", "bob00").await; let config = Configuration::default();
insert_user(&sql_pool, "patrick", "pass").await; let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&sql_pool, "John", "Pa33w0rd!").await; insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
insert_group(&sql_pool, 1, "Best Group").await; insert_group(&sql_pool, 1, "Best Group").await;
insert_group(&sql_pool, 2, "Worst Group").await; insert_group(&sql_pool, 2, "Worst Group").await;
insert_membership(&sql_pool, 1, "bob").await; insert_membership(&sql_pool, 1, "bob").await;
insert_membership(&sql_pool, 1, "patrick").await; insert_membership(&sql_pool, 1, "patrick").await;
insert_membership(&sql_pool, 2, "patrick").await; insert_membership(&sql_pool, 2, "patrick").await;
insert_membership(&sql_pool, 2, "John").await; insert_membership(&sql_pool, 2, "John").await;
let config = Configuration::default();
let handler = SqlBackendHandler::new(config, sql_pool);
assert_eq!( assert_eq!(
handler.list_groups().await.unwrap(), handler.list_groups().await.unwrap(),
vec![ vec![
@ -380,16 +448,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_get_user_groups() { async fn test_get_user_groups() {
let sql_pool = get_initialized_db().await; let sql_pool = get_initialized_db().await;
insert_user(&sql_pool, "bob", "bob00").await; let config = Configuration::default();
insert_user(&sql_pool, "patrick", "pass").await; let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&sql_pool, "John", "Pa33w0rd!").await; insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
insert_group(&sql_pool, 1, "Group1").await; insert_group(&sql_pool, 1, "Group1").await;
insert_group(&sql_pool, 2, "Group2").await; insert_group(&sql_pool, 2, "Group2").await;
insert_membership(&sql_pool, 1, "bob").await; insert_membership(&sql_pool, 1, "bob").await;
insert_membership(&sql_pool, 1, "patrick").await; insert_membership(&sql_pool, 1, "patrick").await;
insert_membership(&sql_pool, 2, "patrick").await; insert_membership(&sql_pool, 2, "patrick").await;
let config = Configuration::default();
let handler = SqlBackendHandler::new(config, sql_pool);
let mut bob_groups = HashSet::new(); let mut bob_groups = HashSet::new();
bob_groups.insert("Group1".to_string()); bob_groups.insert("Group1".to_string());
let mut patrick_groups = HashSet::new(); let mut patrick_groups = HashSet::new();

View File

@ -15,7 +15,7 @@ pub enum Users {
LastName, LastName,
Avatar, Avatar,
CreationDate, CreationDate,
Password, PasswordHash,
TotpSecret, TotpSecret,
MfaType, MfaType,
} }
@ -49,16 +49,16 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.primary_key(), .primary_key(),
) )
.col(ColumnDef::new(Users::Email).string_len(255).not_null()) .col(ColumnDef::new(Users::Email).string_len(255).not_null())
.col(ColumnDef::new(Users::DisplayName).string_len(255))
.col(ColumnDef::new(Users::FirstName).string_len(255))
.col(ColumnDef::new(Users::LastName).string_len(255))
.col(ColumnDef::new(Users::Avatar).binary())
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col( .col(
ColumnDef::new(Users::DisplayName) ColumnDef::new(Users::PasswordHash)
.string_len(255) .string_len(255)
.not_null(), .not_null(),
) )
.col(ColumnDef::new(Users::FirstName).string_len(255).not_null())
.col(ColumnDef::new(Users::LastName).string_len(255).not_null())
.col(ColumnDef::new(Users::Avatar).binary())
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::Password).string_len(255).not_null())
.col(ColumnDef::new(Users::TotpSecret).string_len(64)) .col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64)) .col(ColumnDef::new(Users::MfaType).string_len(64))
.to_string(DbQueryBuilder {}), .to_string(DbQueryBuilder {}),
@ -129,7 +129,7 @@ mod tests {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap();
sqlx::query(r#"INSERT INTO users sqlx::query(r#"INSERT INTO users
(user_id, email, display_name, first_name, last_name, creation_date, password) (user_id, email, display_name, first_name, last_name, creation_date, password_hash)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00")"#).execute(&sql_pool).await.unwrap(); VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00")"#).execute(&sql_pool).await.unwrap();
let row = let row =
sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#) sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#)

View File

@ -27,6 +27,7 @@ mockall::mock! {
async fn list_users(&self, request: ListUsersRequest) -> DomainResult<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> DomainResult<Vec<User>>;
async fn list_groups(&self) -> DomainResult<Vec<Group>>; async fn list_groups(&self) -> DomainResult<Vec<Group>>;
async fn get_user_groups(&self, user: String) -> DomainResult<HashSet<String>>; async fn get_user_groups(&self, user: String) -> DomainResult<HashSet<String>>;
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
} }
#[async_trait] #[async_trait]
impl TcpBackendHandler for TestTcpBackendHandler { impl TcpBackendHandler for TestTcpBackendHandler {