Add a handler for OPAQUE messages

This commit is contained in:
Valentin Tolmer 2021-06-16 22:04:11 +02:00 committed by nitnelave
parent f6372c7e02
commit 7be0e420d4
11 changed files with 476 additions and 79 deletions

View File

@ -8,6 +8,8 @@ pub enum Error {
DatabaseError(#[from] sqlx::Error), DatabaseError(#[from] sqlx::Error),
#[error("Authentication protocol error for `{0}`")] #[error("Authentication protocol error for `{0}`")]
AuthenticationProtocolError(#[from] lldap_model::opaque::AuthenticationError), AuthenticationProtocolError(#[from] lldap_model::opaque::AuthenticationError),
#[error("Internal error: `{0}`")]
InternalError(String),
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;

View File

@ -5,8 +5,12 @@ use std::collections::HashSet;
pub use lldap_model::*; pub use lldap_model::*;
#[async_trait] #[async_trait]
pub trait BackendHandler: Clone + Send { pub trait LoginHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
}
#[async_trait]
pub trait BackendHandler: Clone + Send {
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
@ -24,7 +28,6 @@ mockall::mock! {
} }
#[async_trait] #[async_trait]
impl BackendHandler for TestBackendHandler { impl BackendHandler for TestBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>; async fn list_groups(&self) -> Result<Vec<Group>>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>; async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
@ -33,4 +36,8 @@ mockall::mock! {
async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>; async fn get_user_groups(&self, user: String) -> Result<HashSet<String>>;
async fn add_user_to_group(&self, request: AddUserToGroupRequest) -> Result<()>; async fn add_user_to_group(&self, request: AddUserToGroupRequest) -> Result<()>;
} }
#[async_trait]
impl LoginHandler for TestBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
} }

View File

@ -1,4 +1,6 @@
pub mod error; pub mod error;
pub mod handler; pub mod handler;
pub mod opaque_handler;
pub mod sql_backend_handler; pub mod sql_backend_handler;
pub mod sql_opaque_handler;
pub mod sql_tables; pub mod sql_tables;

View File

@ -0,0 +1,36 @@
use super::error::*;
use async_trait::async_trait;
pub use lldap_model::{login, registration};
#[async_trait]
pub trait OpaqueHandler: Clone + Send {
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse>;
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()>;
}
#[cfg(test)]
mockall::mock! {
pub TestOpaqueHandler{}
impl Clone for TestOpaqueHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl OpaqueHandler for TestOpaqueHandler {
async fn login_start(&self, request: login::ClientLoginStartRequest) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>;
async fn registration_start(&self, request: registration::ClientRegistrationStartRequest) -> Result<registration::ServerRegistrationStartResponse>;
async fn registration_finish(&self, request: registration::ClientRegistrationFinishRequest ) -> Result<()>;
}
}

View File

@ -4,7 +4,6 @@ use async_trait::async_trait;
use futures_util::StreamExt; use futures_util::StreamExt;
use futures_util::TryStreamExt; use futures_util::TryStreamExt;
use lldap_model::opaque; use lldap_model::opaque;
use log::*;
use sea_query::{Expr, Iden, Order, Query, SimpleExpr, Value}; use sea_query::{Expr, Iden, Order, Query, SimpleExpr, Value};
use sqlx::Row; use sqlx::Row;
use std::collections::HashSet; use std::collections::HashSet;
@ -21,7 +20,7 @@ impl SqlBackendHandler {
} }
} }
fn get_password_file( pub fn get_password_file(
clear_password: &str, clear_password: &str,
server_public_key: &opaque::PublicKey, server_public_key: &opaque::PublicKey,
) -> Result<opaque::server::ServerRegistration> { ) -> Result<opaque::server::ServerRegistration> {
@ -48,30 +47,6 @@ fn get_password_file(
)?) )?)
} }
fn passwords_match(
password_file_bytes: &[u8],
clear_password: &str,
server_private_key: &opaque::PrivateKey,
) -> Result<()> {
use opaque::{client, server};
let mut rng = rand::rngs::OsRng;
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
.map_err(opaque::AuthenticationError::ProtocolError)?;
let server_login_start_result = server::login::start_login(
&mut rng,
password_file,
server_private_key,
client_login_start_result.message,
)?;
client::login::finish_login(
client_login_start_result.state,
server_login_start_result.message,
)?;
Ok(())
}
fn get_filter_expr(filter: RequestFilter) -> SimpleExpr { fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
use RequestFilter::*; use RequestFilter::*;
fn get_repeated_filter( fn get_repeated_filter(
@ -95,42 +70,6 @@ fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
#[async_trait] #[async_trait]
impl BackendHandler for SqlBackendHandler { impl BackendHandler for SqlBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass {
return Ok(());
} else {
debug!(r#"Invalid password for LDAP bind user"#);
return Err(Error::AuthenticationError(request.name));
}
}
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
.to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if let Some(password_hash) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
if let Err(e) = passwords_match(
&&password_hash,
&request.password,
self.config.get_server_keys().private(),
) {
debug!(r#"Invalid password for "{}": {}"#, request.name, e);
} else {
return Ok(());
}
} else {
debug!(r#"User "{}" has no password"#, request.name);
}
} else {
debug!(r#"No user found for "{}""#, request.name);
}
Err(Error::AuthenticationError(request.name))
}
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>> { async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>> {
let query = { let query = {
let mut query_builder = Query::select() let mut query_builder = Query::select()

View File

@ -0,0 +1,396 @@
use super::{
error::*, handler::LoginHandler, opaque_handler::*, sql_backend_handler::SqlBackendHandler,
sql_tables::*,
};
use async_trait::async_trait;
use lldap_model::{opaque, BindRequest};
use log::*;
use rand::{CryptoRng, RngCore};
use sea_query::{Expr, Iden, Query};
use sqlx::Row;
type SqlOpaqueHandler = SqlBackendHandler;
fn generate_random_id<R: RngCore + CryptoRng>(rng: &mut R) -> String {
use rand::{distributions::Alphanumeric, Rng};
std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(32)
.collect()
}
fn passwords_match(
password_file_bytes: &[u8],
clear_password: &str,
server_private_key: &opaque::PrivateKey,
) -> Result<()> {
use opaque::{client, server};
let mut rng = rand::rngs::OsRng;
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
.map_err(opaque::AuthenticationError::ProtocolError)?;
let server_login_start_result = server::login::start_login(
&mut rng,
password_file,
server_private_key,
client_login_start_result.message,
)?;
client::login::finish_login(
client_login_start_result.state,
server_login_start_result.message,
)?;
Ok(())
}
#[async_trait]
impl LoginHandler for SqlBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass {
return Ok(());
} else {
debug!(r#"Invalid password for LDAP bind user"#);
return Err(Error::AuthenticationError(request.name));
}
}
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
.to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if let Some(password_hash) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
if let Err(e) = passwords_match(
&&password_hash,
&request.password,
self.config.get_server_keys().private(),
) {
debug!(r#"Invalid password for "{}": {}"#, request.name, e);
} else {
return Ok(());
}
} else {
debug!(r#"User "{}" has no password"#, request.name);
}
} else {
debug!(r#"No user found for "{}""#, request.name);
}
Err(Error::AuthenticationError(request.name))
}
}
#[async_trait]
impl OpaqueHandler for SqlOpaqueHandler {
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> {
// Fetch the previously registered password file from the DB.
let password_file_bytes = {
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.username.as_str()))
.to_string(DbQueryBuilder {});
sqlx::query(&query)
.fetch_one(&self.sql_pool)
.await?
.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
// If no password, always fail.
.ok_or_else(|| Error::AuthenticationError(request.username.clone()))?
};
let password_file = opaque::server::ServerRegistration::deserialize(&password_file_bytes)
.map_err(|_| {
Error::InternalError(format!("Corrupted password file for {}", request.username))
})?;
let mut rng = rand::rngs::OsRng;
let start_response = opaque::server::login::start_login(
&mut rng,
password_file,
self.config.get_server_keys().private(),
request.login_start_request,
)?;
let login_attempt_id = generate_random_id(&mut rng);
{
// Insert the current login attempt in the DB.
let query = Query::insert()
.into_table(LoginAttempts::Table)
.columns(vec![
LoginAttempts::RandomId,
LoginAttempts::UserId,
LoginAttempts::ServerLoginData,
LoginAttempts::Timestamp,
])
.values_panic(vec![
login_attempt_id.as_str().into(),
request.username.as_str().into(),
start_response.state.serialize().into(),
chrono::Utc::now().naive_utc().into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
}
Ok(login::ServerLoginStartResponse {
login_key: login_attempt_id,
credential_response: start_response.message,
})
}
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> {
// Fetch the previous data from this login attempt.
let row = {
let query = Query::select()
.column(LoginAttempts::UserId)
.column(LoginAttempts::ServerLoginData)
.from(LoginAttempts::Table)
.and_where(Expr::col(LoginAttempts::RandomId).eq(request.login_key.as_str()))
.and_where(
Expr::col(LoginAttempts::Timestamp)
.gt(chrono::Utc::now().naive_utc() - chrono::Duration::minutes(5)),
)
.to_string(DbQueryBuilder {});
sqlx::query(&query).fetch_one(&self.sql_pool).await?
};
let username = row.get::<String, _>(&*LoginAttempts::UserId.to_string());
let login_data = opaque::server::login::ServerLogin::deserialize(
&row.get::<Vec<u8>, _>(&*LoginAttempts::ServerLoginData.to_string()),
)
.map_err(|_| {
Error::InternalError(format!(
"Corrupted login data for user `{}` [id `{}`]",
username, request.login_key
))
})?;
// Finish the login: this makes sure the client data is correct, and gives a session key we
// don't need.
let _session_key =
opaque::server::login::finish_login(login_data, request.credential_finalization)?
.session_key;
{
// Login was successful, we can delete the login attempt from the table.
let delete_query = Query::delete()
.from_table(LoginAttempts::Table)
.and_where(Expr::col(LoginAttempts::RandomId).eq(request.login_key))
.to_string(DbQueryBuilder {});
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
}
Ok(username)
}
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse> {
let mut rng = rand::rngs::OsRng;
// Generate the server-side key and derive the data to send back.
let start_response = opaque::server::registration::start_registration(
&mut rng,
request.registration_start_request,
self.config.get_server_keys().public(),
)?;
// Unique ID to identify the registration attempt.
let registration_attempt_id = generate_random_id(&mut rng);
{
// Write the registration attempt to the DB for the later turn.
let query = Query::insert()
.into_table(RegistrationAttempts::Table)
.columns(vec![
RegistrationAttempts::RandomId,
RegistrationAttempts::UserId,
RegistrationAttempts::ServerRegistrationData,
RegistrationAttempts::Timestamp,
])
.values_panic(vec![
registration_attempt_id.as_str().into(),
request.username.as_str().into(),
start_response.state.serialize().into(),
chrono::Utc::now().naive_utc().into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
}
Ok(registration::ServerRegistrationStartResponse {
registration_key: registration_attempt_id,
registration_response: start_response.message,
})
}
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()> {
// Fetch the previous state.
let row = {
let query = Query::select()
.column(RegistrationAttempts::UserId)
.column(RegistrationAttempts::ServerRegistrationData)
.from(RegistrationAttempts::Table)
.and_where(
Expr::col(RegistrationAttempts::RandomId).eq(request.registration_key.as_str()),
)
.and_where(
Expr::col(RegistrationAttempts::Timestamp)
.gt(chrono::Utc::now().naive_utc() - chrono::Duration::minutes(5)),
)
.to_string(DbQueryBuilder {});
sqlx::query(&query).fetch_one(&self.sql_pool).await?
};
let username = row.get::<String, _>(&*RegistrationAttempts::UserId.to_string());
let registration_data = opaque::server::registration::ServerRegistration::deserialize(
&row.get::<Vec<u8>, _>(&*RegistrationAttempts::ServerRegistrationData.to_string()),
)
.map_err(|_| {
Error::InternalError(format!(
"Corrupted registration data for user `{}` [id `{}`]",
username, request.registration_key
))
})?;
let password_file = opaque::server::registration::get_password_file(
registration_data,
request.registration_upload,
)?;
{
// Set the user password to the new password.
let update_query = Query::update()
.table(Users::Table)
.values(vec![(
Users::PasswordHash,
password_file.serialize().into(),
)])
.and_where(Expr::col(Users::UserId).eq(username))
.to_string(DbQueryBuilder {});
sqlx::query(&update_query).execute(&self.sql_pool).await?;
}
{
// Delete the registration attempt.
let delete_query = Query::delete()
.from_table(RegistrationAttempts::Table)
.and_where(Expr::col(RegistrationAttempts::RandomId).eq(request.registration_key))
.to_string(DbQueryBuilder {});
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
domain::{
handler::BackendHandler, sql_backend_handler::SqlBackendHandler, sql_tables::init_table,
},
infra::configuration::{Configuration, ConfigurationBuilder},
};
use lldap_model::*;
fn get_default_config() -> Configuration {
ConfigurationBuilder::default()
.verbose(true)
.build()
.unwrap()
}
async fn get_in_memory_db() -> Pool {
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
}
async fn get_initialized_db() -> Pool {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
}
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler
.create_user(CreateUserRequest {
user_id: name.to_string(),
email: "bob@bob.bob".to_string(),
..Default::default()
})
.await
.unwrap();
}
async fn attempt_login(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use login::*;
let login_start = opaque::client::login::start_login(password, &mut rng)?;
let start_response = opaque_handler
.login_start(ClientLoginStartRequest {
username: username.to_string(),
login_start_request: login_start.message,
})
.await?;
let login_finish = opaque::client::login::finish_login(
login_start.state,
start_response.credential_response,
)?;
opaque_handler
.login_finish(ClientLoginFinishRequest {
login_key: start_response.login_key,
credential_finalization: login_finish.message,
})
.await?;
Ok(())
}
async fn attempt_registration(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use registration::*;
let registration_start =
opaque::client::registration::start_registration(password, &mut rng)?;
let start_response = opaque_handler
.registration_start(ClientRegistrationStartRequest {
username: username.to_string(),
registration_start_request: registration_start.message,
})
.await?;
let registration_finish = opaque::client::registration::finish_registration(
registration_start.state,
start_response.registration_response,
&mut rng,
)?;
opaque_handler
.registration_finish(ClientRegistrationFinishRequest {
registration_key: start_response.registration_key,
registration_upload: registration_finish.message,
})
.await
}
#[tokio::test]
async fn test_flow() -> Result<()> {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
let opaque_handler = SqlOpaqueHandler::new(config, sql_pool);
insert_user_no_password(&backend_handler, "bob").await;
attempt_login(&opaque_handler, "bob", "bob00")
.await
.unwrap_err();
attempt_registration(&opaque_handler, "bob", "bob00").await?;
attempt_login(&opaque_handler, "bob", "wrong_password")
.await
.unwrap_err();
attempt_login(&opaque_handler, "bob", "bob00").await?;
Ok(())
}
}

View File

@ -1,10 +1,14 @@
use crate::{ use crate::{
domain::handler::*, domain::{
handler::{BackendHandler, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{ infra::{
tcp_backend_handler::*, tcp_backend_handler::*,
tcp_server::{error_to_http_response, AppState}, tcp_server::{error_to_http_response, AppState},
}, },
}; };
use lldap_model::{JWTClaims, BindRequest};
use actix_web::{ use actix_web::{
cookie::{Cookie, SameSite}, cookie::{Cookie, SameSite},
dev::{Service, ServiceRequest, ServiceResponse, Transform}, dev::{Service, ServiceRequest, ServiceResponse, Transform},
@ -166,7 +170,7 @@ async fn post_authorize<Backend>(
request: web::Json<BindRequest>, request: web::Json<BindRequest>,
) -> HttpResponse ) -> HttpResponse
where where
Backend: TcpBackendHandler + BackendHandler + 'static, Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
{ {
let req: BindRequest = request.clone(); let req: BindRequest = request.clone();
data.backend_handler data.backend_handler
@ -299,7 +303,7 @@ where
pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig) pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig)
where where
Backend: TcpBackendHandler + BackendHandler + 'static, Backend: TcpBackendHandler + LoginHandler + OpaqueHandler + BackendHandler + 'static,
{ {
cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>))) cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
.service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>))) .service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)))

View File

@ -1,4 +1,4 @@
use crate::domain::handler::{BackendHandler, ListUsersRequest, RequestFilter, User}; use crate::domain::handler::{BackendHandler, ListUsersRequest, LoginHandler, RequestFilter, User};
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use ldap3_server::simple::*; use ldap3_server::simple::*;
@ -147,7 +147,7 @@ fn convert_filter(filter: &LdapFilter) -> Result<RequestFilter> {
} }
} }
pub struct LdapHandler<Backend: BackendHandler> { pub struct LdapHandler<Backend: BackendHandler + LoginHandler> {
dn: String, dn: String,
backend_handler: Backend, backend_handler: Backend,
pub base_dn: Vec<(String, String)>, pub base_dn: Vec<(String, String)>,
@ -155,7 +155,7 @@ pub struct LdapHandler<Backend: BackendHandler> {
ldap_user_dn: String, ldap_user_dn: String,
} }
impl<Backend: BackendHandler> LdapHandler<Backend> { impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self { pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
Self { Self {
dn: "Unauthenticated".to_string(), dn: "Unauthenticated".to_string(),

View File

@ -1,4 +1,4 @@
use crate::domain::handler::BackendHandler; use crate::domain::handler::{BackendHandler, LoginHandler};
use crate::infra::configuration::Configuration; use crate::infra::configuration::Configuration;
use crate::infra::ldap_handler::LdapHandler; use crate::infra::ldap_handler::LdapHandler;
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
@ -12,11 +12,14 @@ use log::*;
use tokio::net::tcp::WriteHalf; use tokio::net::tcp::WriteHalf;
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
async fn handle_incoming_message<Backend: BackendHandler>( async fn handle_incoming_message<Backend>(
msg: Result<LdapMsg, std::io::Error>, msg: Result<LdapMsg, std::io::Error>,
resp: &mut FramedWrite<WriteHalf<'_>, LdapCodec>, resp: &mut FramedWrite<WriteHalf<'_>, LdapCodec>,
session: &mut LdapHandler<Backend>, session: &mut LdapHandler<Backend>,
) -> Result<bool> { ) -> Result<bool>
where
Backend: BackendHandler + LoginHandler,
{
use futures_util::SinkExt; use futures_util::SinkExt;
use std::convert::TryFrom; use std::convert::TryFrom;
let server_op = match msg.map_err(|_e| ()).and_then(ServerOps::try_from) { let server_op = match msg.map_err(|_e| ()).and_then(ServerOps::try_from) {
@ -56,7 +59,7 @@ pub fn build_ldap_server<Backend>(
server_builder: ServerBuilder, server_builder: ServerBuilder,
) -> Result<ServerBuilder> ) -> Result<ServerBuilder>
where where
Backend: BackendHandler + 'static, Backend: BackendHandler + LoginHandler + 'static,
{ {
use futures_util::StreamExt; use futures_util::StreamExt;

View File

@ -22,8 +22,11 @@ mockall::mock! {
fn clone(&self) -> Self; fn clone(&self) -> Self;
} }
#[async_trait] #[async_trait]
impl BackendHandler for TestTcpBackendHandler { impl LoginHandler for TestTcpBackendHandler {
async fn bind(&self, request: BindRequest) -> DomainResult<()>; async fn bind(&self, request: BindRequest) -> DomainResult<()>;
}
#[async_trait]
impl BackendHandler for TestTcpBackendHandler {
async fn list_users(&self, request: ListUsersRequest) -> DomainResult<Vec<User>>; async fn list_users(&self, request: ListUsersRequest) -> DomainResult<Vec<User>>;
async fn list_groups(&self) -> DomainResult<Vec<Group>>; async fn list_groups(&self) -> DomainResult<Vec<Group>>;
async fn get_user_groups(&self, user: String) -> DomainResult<HashSet<String>>; async fn get_user_groups(&self, user: String) -> DomainResult<HashSet<String>>;

View File

@ -1,5 +1,8 @@
use crate::{ use crate::{
domain::handler::*, domain::{
handler::{BackendHandler, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{auth_service, configuration::Configuration, tcp_api, tcp_backend_handler::*}, infra::{auth_service, configuration::Configuration, tcp_api, tcp_backend_handler::*},
}; };
use actix_files::{Files, NamedFile}; use actix_files::{Files, NamedFile};
@ -28,7 +31,9 @@ pub(crate) fn error_to_http_response(error: DomainError) -> HttpResponse {
DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => { DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => {
HttpResponse::Unauthorized() HttpResponse::Unauthorized()
} }
DomainError::DatabaseError(_) => HttpResponse::InternalServerError(), DomainError::DatabaseError(_) | DomainError::InternalError(_) => {
HttpResponse::InternalServerError()
}
} }
.body(error.to_string()) .body(error.to_string())
} }
@ -39,7 +44,7 @@ fn http_config<Backend>(
jwt_secret: String, jwt_secret: String,
jwt_blacklist: HashSet<u64>, jwt_blacklist: HashSet<u64>,
) where ) where
Backend: TcpBackendHandler + BackendHandler + 'static, Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + 'static,
{ {
cfg.data(AppState::<Backend> { cfg.data(AppState::<Backend> {
backend_handler, backend_handler,
@ -83,7 +88,7 @@ pub async fn build_tcp_server<Backend>(
server_builder: ServerBuilder, server_builder: ServerBuilder,
) -> Result<ServerBuilder> ) -> Result<ServerBuilder>
where where
Backend: TcpBackendHandler + BackendHandler + 'static, Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + 'static,
{ {
let jwt_secret = config.jwt_secret.clone(); let jwt_secret = config.jwt_secret.clone();
let jwt_blacklist = backend_handler.get_jwt_blacklist().await?; let jwt_blacklist = backend_handler.get_jwt_blacklist().await?;