diff --git a/Cargo.toml b/Cargo.toml index db3c9c6..3495f5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ http = "*" ldap3_server = "*" log = "*" serde = "*" +sqlx = { version = "0.5", features = [ "runtime-actix-native-tls", "sqlite", "macros" ] } thiserror = "*" tokio = { version = "1.2.0", features = ["full"] } tokio-util = "0.6.3" diff --git a/src/infra/ldap_server.rs b/src/infra/ldap_server.rs index b45f44e..a51eedb 100644 --- a/src/infra/ldap_server.rs +++ b/src/infra/ldap_server.rs @@ -2,18 +2,29 @@ use crate::infra::configuration::Configuration; use actix_rt::net::TcpStream; use actix_server::ServerBuilder; use actix_service::{fn_service, pipeline_factory}; +use anyhow::bail; use anyhow::Result; use futures_util::future::ok; use log::*; +use std::rc::Rc; +use tokio::net::tcp::WriteHalf; +use tokio_util::codec::{FramedRead, FramedWrite}; use ldap3_server::simple::*; use ldap3_server::LdapCodec; -pub struct LdapSession { +pub struct LdapSession +where + DB: sqlx::Database, +{ dn: String, + sql_pool: Rc>, } -impl LdapSession { +impl LdapSession +where + DB: sqlx::Database, +{ pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { if sbr.dn == "cn=Directory Manager" && sbr.pw == "password" { self.dn = sbr.dn.to_string(); @@ -61,51 +72,103 @@ impl LdapSession { pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg { wr.gen_success(format!("dn: {}", self.dn).as_str()) } + pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> + where + DB: sqlx::Database, + { + let result = match server_op { + ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)], + ServerOps::Search(sr) => self.do_search(&sr), + ServerOps::Unbind(_) => { + // No need to notify on unbind (per rfc4511) + return None; + } + ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)], + }; + Some(result) + } } -pub fn build_ldap_server( - config: &Configuration, - server_builder: ServerBuilder, -) -> Result { +async fn handle_incoming_message( + msg: Result, + resp: &mut FramedWrite, LdapCodec>, + session: &mut LdapSession, +) -> Result +where + DB: sqlx::Database, +{ use futures_util::SinkExt; - use futures_util::StreamExt; use std::convert::TryFrom; - use tokio_util::codec::{FramedRead, FramedWrite}; + let server_op = match msg + .map_err(|_e| ()) + .and_then(|msg| ServerOps::try_from(msg)) + { + Ok(a_value) => a_value, + Err(an_error) => { + let _err = resp + .send(DisconnectionNotice::gen( + LdapResultCode::Other, + "Internal Server Error", + )) + .await; + let _err = resp.flush().await; + bail!("Internal server error: {:?}", an_error); + } + }; + + match session.handle_ldap_message(server_op) { + None => return Ok(false), + Some(result) => { + for rmsg in result.into_iter() { + if let Err(e) = resp.send(rmsg).await { + bail!("Error while sending a response: {:?}", e); + } + } + + if let Err(e) = resp.flush().await { + bail!("Error while flushing responses: {:?}", e); + } + } + } + Ok(true) +} + +pub fn build_ldap_server( + config: &Configuration, + sql_pool: sqlx::Pool, + server_builder: ServerBuilder, +) -> Result +where + DB: sqlx::Database, +{ + use futures_util::StreamExt; Ok( server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || { - pipeline_factory(fn_service(move |mut stream: TcpStream| async { - // Configure the codec etc. - let (r, w) = stream.split(); - let mut reqs = FramedRead::new(r, LdapCodec); - let mut resp = FramedWrite::new(w, LdapCodec); + let sql_pool = std::rc::Rc::new(sql_pool.clone()); + pipeline_factory(fn_service(move |mut stream: TcpStream| { + let sql_pool = sql_pool.clone(); + async move { + // Configure the codec etc. + let (r, w) = stream.split(); + let mut requests = FramedRead::new(r, LdapCodec); + let mut resp = FramedWrite::new(w, LdapCodec); - let mut session = LdapSession { - dn: "Anonymous".to_string(), - }; - - while let Some(msg) = reqs.next().await { - let server_op = match msg - .map_err(|_e| ()) - .and_then(|msg| ServerOps::try_from(msg)) - { - Ok(a_value) => a_value, - Err(an_error) => { - let _err = resp - .send(DisconnectionNotice::gen( - LdapResultCode::Other, - "Internal Server Error", - )) - .await; - let _err = resp.flush().await; - return Err(format!("Internal server error: {:?}", an_error)); - } + let mut session = LdapSession { + dn: "Anonymous".to_string(), + sql_pool, }; - } - Ok(stream) + while let Some(msg) = requests.next().await { + if !handle_incoming_message(msg, &mut resp, &mut session).await? { + break; + } + } + + Ok(stream) + } })) - .map_err(|err| error!("Service Error: {:?}", err)) + .map_err(|err: anyhow::Error| error!("Service Error: {:?}", err)) // catch .and_then(move |_| { // finally diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index 8459d85..d63c865 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -7,10 +7,14 @@ use futures_util::future::ok; use log::*; use std::sync::Arc; -pub fn build_tcp_server( +pub fn build_tcp_server( config: &Configuration, + sql_pool: sqlx::Pool, server_builder: ServerBuilder, -) -> Result { +) -> Result +where + DB: sqlx::Database, +{ use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use tokio::io::AsyncReadExt; diff --git a/src/main.rs b/src/main.rs index e0de0cb..5408c99 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,13 +2,22 @@ use crate::infra::configuration::Configuration; use anyhow::Result; use futures_util::TryFutureExt; use log::*; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; +use std::str::FromStr; mod infra; async fn run_server(config: Configuration) -> Result<()> { - let server_builder = - infra::ldap_server::build_ldap_server(&config, actix_server::Server::build())?; - let server_builder = infra::tcp_server::build_tcp_server(&config, server_builder)?; + let sql_pool = SqlitePoolOptions::new() + .max_connections(5) + .connect_with(SqliteConnectOptions::from_str("sqlite://users.db")?.create_if_missing(true)) + .await?; + let server_builder = infra::ldap_server::build_ldap_server( + &config, + sql_pool.clone(), + actix_server::Server::build(), + )?; + let server_builder = infra::tcp_server::build_tcp_server(&config, sql_pool, server_builder)?; server_builder.workers(1).run().await?; Ok(()) }