Add tests for bind in the handler

This commit is contained in:
Valentin Tolmer 2021-04-11 22:01:24 +02:00
parent 49404b24d7
commit 71045b08fe
3 changed files with 40 additions and 31 deletions

View File

@ -27,6 +27,7 @@ pub struct ListUsersRequest {
pub filters: Option<RequestFilter>, pub filters: Option<RequestFilter>,
} }
#[derive(sqlx::FromRow)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))] #[cfg_attr(test, derive(PartialEq, Eq, Debug))]
pub struct User { pub struct User {
pub user_id: String, pub user_id: String,
@ -35,29 +36,25 @@ pub struct User {
pub first_name: String, pub first_name: String,
pub last_name: String, pub last_name: String,
// pub avatar: ?, // pub avatar: ?,
pub creation_date: chrono::NaiveDateTime, // TODO: wait until supported for Any
// pub creation_date: chrono::NaiveDateTime,
} }
#[async_trait] #[async_trait]
pub trait BackendHandler: Clone + Send { pub trait BackendHandler: Clone + Send {
async fn bind(&mut self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SqlBackendHandler { pub struct SqlBackendHandler {
config: Configuration, config: Configuration,
sql_pool: AnyPool, sql_pool: AnyPool,
authenticated: bool,
} }
impl SqlBackendHandler { impl SqlBackendHandler {
pub fn new(config: Configuration, sql_pool: AnyPool) -> Self { pub fn new(config: Configuration, sql_pool: AnyPool) -> Self {
SqlBackendHandler { SqlBackendHandler { config, sql_pool }
config,
sql_pool,
authenticated: false,
}
} }
} }
@ -88,10 +85,9 @@ fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
#[async_trait] #[async_trait]
impl BackendHandler for SqlBackendHandler { impl BackendHandler for SqlBackendHandler {
async fn bind(&mut self, request: BindRequest) -> Result<()> { async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn { if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass { if request.password == self.config.ldap_user_pass {
self.authenticated = true;
return Ok(()); return Ok(());
} else { } else {
bail!(r#"Authentication error for "{}""#, request.name) bail!(r#"Authentication error for "{}""#, request.name)
@ -110,7 +106,7 @@ impl BackendHandler for SqlBackendHandler {
bail!(r#"Authentication error for "{}""#, request.name) bail!(r#"Authentication error for "{}""#, request.name)
} }
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>> { async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>> {
let query = { let query = {
let mut query_builder = Query::select() let mut query_builder = Query::select()
.column(Users::UserId) .column(Users::UserId)
@ -133,15 +129,7 @@ impl BackendHandler for SqlBackendHandler {
query_builder.to_string(MysqlQueryBuilder) query_builder.to_string(MysqlQueryBuilder)
}; };
let results = sqlx::query(&query) let results = sqlx::query_as::<_, User>(&query)
.map(|row: sqlx::any::AnyRow| User {
user_id: row.get::<String, _>("user_id"),
email: row.get::<String, _>("email"),
display_name: row.get::<String, _>("display_name"),
first_name: row.get::<String, _>("first_name"),
last_name: row.get::<String, _>("last_name"),
creation_date: chrono::NaiveDateTime::from_timestamp(0, 0), // TODO: wait until datetime is supported for Any.
})
.fetch(&self.sql_pool) .fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<User>>>() .collect::<Vec<sqlx::Result<User>>>()
.await; .await;
@ -158,7 +146,32 @@ mockall::mock! {
} }
#[async_trait] #[async_trait]
impl BackendHandler for TestBackendHandler { impl BackendHandler for TestBackendHandler {
async fn bind(&mut self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_bind_admin() {
let sql_pool = sqlx::any::AnyPoolOptions::new()
.connect("sqlite::memory:")
.await
.unwrap();
let mut config = Configuration::default();
config.ldap_user_dn = "admin".to_string();
config.ldap_user_pass = "test".to_string();
let handler = SqlBackendHandler::new(config, sql_pool);
assert!(true);
assert!(handler
.bind(BindRequest {
name: "admin".to_string(),
password: "test".to_string()
})
.await
.is_ok());
} }
} }

View File

@ -57,7 +57,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.col(ColumnDef::new(Users::Password).string_len(255).not_null()) .col(ColumnDef::new(Users::Password).string_len(255).not_null())
.col(ColumnDef::new(Users::TotpSecret).string_len(64)) .col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64)) .col(ColumnDef::new(Users::MfaType).string_len(64))
.to_string(MysqlQueryBuilder), .to_string(SqliteQueryBuilder),
) )
.execute(pool) .execute(pool)
.await?; .await?;
@ -69,7 +69,6 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
ColumnDef::new(Groups::GroupId) ColumnDef::new(Groups::GroupId)
.integer() .integer()
.not_null() .not_null()
.auto_increment()
.primary_key(), .primary_key(),
) )
.col( .col(
@ -77,7 +76,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.string_len(255) .string_len(255)
.not_null(), .not_null(),
) )
.to_string(MysqlQueryBuilder), .to_string(SqliteQueryBuilder),
) )
.execute(pool) .execute(pool)
.await?; .await?;
@ -94,8 +93,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.col( .col(
ColumnDef::new(Memberships::GroupId) ColumnDef::new(Memberships::GroupId)
.integer() .integer()
.not_null() .not_null(),
.auto_increment(),
) )
.foreign_key( .foreign_key(
ForeignKey::create() ForeignKey::create()
@ -113,7 +111,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
) )
.to_string(MysqlQueryBuilder), .to_string(SqliteQueryBuilder),
) )
.execute(pool) .execute(pool)
.await?; .await?;

View File

@ -358,7 +358,6 @@ mod tests {
display_name: "Bôb Böbberson".to_string(), display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(), first_name: "Bôb".to_string(),
last_name: "Böbberson".to_string(), last_name: "Böbberson".to_string(),
creation_date: NaiveDateTime::from_timestamp(1_000_000_000, 0),
}, },
User { User {
user_id: "jim".to_string(), user_id: "jim".to_string(),
@ -366,7 +365,6 @@ mod tests {
display_name: "Jimminy Cricket".to_string(), display_name: "Jimminy Cricket".to_string(),
first_name: "Jim".to_string(), first_name: "Jim".to_string(),
last_name: "Cricket".to_string(), last_name: "Cricket".to_string(),
creation_date: NaiveDateTime::from_timestamp(1_003_000_000, 0),
}, },
]) ])
}); });