Implement SQL connection

This commit is contained in:
Valentin Tolmer 2021-03-07 16:13:50 +01:00
parent c63c7105aa
commit dc6e8c8808
4 changed files with 118 additions and 41 deletions

View File

@ -17,6 +17,7 @@ http = "*"
ldap3_server = "*" ldap3_server = "*"
log = "*" log = "*"
serde = "*" serde = "*"
sqlx = { version = "0.5", features = [ "runtime-actix-native-tls", "sqlite", "macros" ] }
thiserror = "*" thiserror = "*"
tokio = { version = "1.2.0", features = ["full"] } tokio = { version = "1.2.0", features = ["full"] }
tokio-util = "0.6.3" tokio-util = "0.6.3"

View File

@ -2,18 +2,29 @@ use crate::infra::configuration::Configuration;
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_server::ServerBuilder; use actix_server::ServerBuilder;
use actix_service::{fn_service, pipeline_factory}; use actix_service::{fn_service, pipeline_factory};
use anyhow::bail;
use anyhow::Result; use anyhow::Result;
use futures_util::future::ok; use futures_util::future::ok;
use log::*; use log::*;
use std::rc::Rc;
use tokio::net::tcp::WriteHalf;
use tokio_util::codec::{FramedRead, FramedWrite};
use ldap3_server::simple::*; use ldap3_server::simple::*;
use ldap3_server::LdapCodec; use ldap3_server::LdapCodec;
pub struct LdapSession { pub struct LdapSession<DB>
where
DB: sqlx::Database,
{
dn: String, dn: String,
sql_pool: Rc<sqlx::Pool<DB>>,
} }
impl LdapSession { impl<DB> LdapSession<DB>
where
DB: sqlx::Database,
{
pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
if sbr.dn == "cn=Directory Manager" && sbr.pw == "password" { if sbr.dn == "cn=Directory Manager" && sbr.pw == "password" {
self.dn = sbr.dn.to_string(); self.dn = sbr.dn.to_string();
@ -61,51 +72,103 @@ impl LdapSession {
pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg { pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg {
wr.gen_success(format!("dn: {}", self.dn).as_str()) wr.gen_success(format!("dn: {}", self.dn).as_str())
} }
pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>>
where
DB: sqlx::Database,
{
let result = match server_op {
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)],
ServerOps::Search(sr) => self.do_search(&sr),
ServerOps::Unbind(_) => {
// No need to notify on unbind (per rfc4511)
return None;
}
ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)],
};
Some(result)
}
} }
pub fn build_ldap_server( async fn handle_incoming_message<DB>(
config: &Configuration, msg: Result<LdapMsg, std::io::Error>,
server_builder: ServerBuilder, resp: &mut FramedWrite<WriteHalf<'_>, LdapCodec>,
) -> Result<ServerBuilder> { session: &mut LdapSession<DB>,
) -> Result<bool>
where
DB: sqlx::Database,
{
use futures_util::SinkExt; use futures_util::SinkExt;
use futures_util::StreamExt;
use std::convert::TryFrom; use std::convert::TryFrom;
use tokio_util::codec::{FramedRead, FramedWrite}; let server_op = match msg
.map_err(|_e| ())
.and_then(|msg| ServerOps::try_from(msg))
{
Ok(a_value) => a_value,
Err(an_error) => {
let _err = resp
.send(DisconnectionNotice::gen(
LdapResultCode::Other,
"Internal Server Error",
))
.await;
let _err = resp.flush().await;
bail!("Internal server error: {:?}", an_error);
}
};
match session.handle_ldap_message(server_op) {
None => return Ok(false),
Some(result) => {
for rmsg in result.into_iter() {
if let Err(e) = resp.send(rmsg).await {
bail!("Error while sending a response: {:?}", e);
}
}
if let Err(e) = resp.flush().await {
bail!("Error while flushing responses: {:?}", e);
}
}
}
Ok(true)
}
pub fn build_ldap_server<DB>(
config: &Configuration,
sql_pool: sqlx::Pool<DB>,
server_builder: ServerBuilder,
) -> Result<ServerBuilder>
where
DB: sqlx::Database,
{
use futures_util::StreamExt;
Ok( Ok(
server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || { server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || {
pipeline_factory(fn_service(move |mut stream: TcpStream| async { let sql_pool = std::rc::Rc::new(sql_pool.clone());
// Configure the codec etc. pipeline_factory(fn_service(move |mut stream: TcpStream| {
let (r, w) = stream.split(); let sql_pool = sql_pool.clone();
let mut reqs = FramedRead::new(r, LdapCodec); async move {
let mut resp = FramedWrite::new(w, LdapCodec); // Configure the codec etc.
let (r, w) = stream.split();
let mut requests = FramedRead::new(r, LdapCodec);
let mut resp = FramedWrite::new(w, LdapCodec);
let mut session = LdapSession { let mut session = LdapSession {
dn: "Anonymous".to_string(), dn: "Anonymous".to_string(),
}; sql_pool,
while let Some(msg) = reqs.next().await {
let server_op = match msg
.map_err(|_e| ())
.and_then(|msg| ServerOps::try_from(msg))
{
Ok(a_value) => a_value,
Err(an_error) => {
let _err = resp
.send(DisconnectionNotice::gen(
LdapResultCode::Other,
"Internal Server Error",
))
.await;
let _err = resp.flush().await;
return Err(format!("Internal server error: {:?}", an_error));
}
}; };
}
Ok(stream) while let Some(msg) = requests.next().await {
if !handle_incoming_message(msg, &mut resp, &mut session).await? {
break;
}
}
Ok(stream)
}
})) }))
.map_err(|err| error!("Service Error: {:?}", err)) .map_err(|err: anyhow::Error| error!("Service Error: {:?}", err))
// catch // catch
.and_then(move |_| { .and_then(move |_| {
// finally // finally

View File

@ -7,10 +7,14 @@ use futures_util::future::ok;
use log::*; use log::*;
use std::sync::Arc; use std::sync::Arc;
pub fn build_tcp_server( pub fn build_tcp_server<DB>(
config: &Configuration, config: &Configuration,
sql_pool: sqlx::Pool<DB>,
server_builder: ServerBuilder, server_builder: ServerBuilder,
) -> Result<ServerBuilder> { ) -> Result<ServerBuilder>
where
DB: sqlx::Database,
{
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;

View File

@ -2,13 +2,22 @@ use crate::infra::configuration::Configuration;
use anyhow::Result; use anyhow::Result;
use futures_util::TryFutureExt; use futures_util::TryFutureExt;
use log::*; use log::*;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::str::FromStr;
mod infra; mod infra;
async fn run_server(config: Configuration) -> Result<()> { async fn run_server(config: Configuration) -> Result<()> {
let server_builder = let sql_pool = SqlitePoolOptions::new()
infra::ldap_server::build_ldap_server(&config, actix_server::Server::build())?; .max_connections(5)
let server_builder = infra::tcp_server::build_tcp_server(&config, server_builder)?; .connect_with(SqliteConnectOptions::from_str("sqlite://users.db")?.create_if_missing(true))
.await?;
let server_builder = infra::ldap_server::build_ldap_server(
&config,
sql_pool.clone(),
actix_server::Server::build(),
)?;
let server_builder = infra::tcp_server::build_tcp_server(&config, sql_pool, server_builder)?;
server_builder.workers(1).run().await?; server_builder.workers(1).run().await?;
Ok(()) Ok(())
} }