From 1a947358fa9a8dcdc08b1abc297c828bcbab0495 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Wed, 10 Mar 2021 12:06:32 +0100 Subject: [PATCH] Simplify DB handling with sqlx::Any --- Cargo.toml | 2 +- src/infra/ldap_server.rs | 40 +++++++++++++--------------------------- src/infra/tcp_server.rs | 10 ++++------ src/main.rs | 7 +++---- 4 files changed, 21 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3495f5f..fd8668f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ http = "*" ldap3_server = "*" log = "*" serde = "*" -sqlx = { version = "0.5", features = [ "runtime-actix-native-tls", "sqlite", "macros" ] } +sqlx = { version = "0.5", features = [ "runtime-actix-native-tls", "any", "sqlite", "mysql", "postgres", "macros" ] } thiserror = "*" tokio = { version = "1.2.0", features = ["full"] } tokio-util = "0.6.3" diff --git a/src/infra/ldap_server.rs b/src/infra/ldap_server.rs index a51eedb..6160174 100644 --- a/src/infra/ldap_server.rs +++ b/src/infra/ldap_server.rs @@ -6,25 +6,19 @@ use anyhow::bail; use anyhow::Result; use futures_util::future::ok; use log::*; -use std::rc::Rc; +use sqlx::any::AnyPool; use tokio::net::tcp::WriteHalf; use tokio_util::codec::{FramedRead, FramedWrite}; use ldap3_server::simple::*; use ldap3_server::LdapCodec; -pub struct LdapSession -where - DB: sqlx::Database, -{ +pub struct LdapSession { dn: String, - sql_pool: Rc>, + sql_pool: AnyPool, } -impl LdapSession -where - DB: sqlx::Database, -{ +impl LdapSession { pub fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg { if sbr.dn == "cn=Directory Manager" && sbr.pw == "password" { self.dn = sbr.dn.to_string(); @@ -72,10 +66,8 @@ where pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg { wr.gen_success(format!("dn: {}", self.dn).as_str()) } - pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> - where - DB: sqlx::Database, - { + + pub fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option> { let result = match server_op { ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr)], ServerOps::Search(sr) => self.do_search(&sr), @@ -89,14 +81,11 @@ where } } -async fn handle_incoming_message( +async fn handle_incoming_message( msg: Result, resp: &mut FramedWrite, LdapCodec>, - session: &mut LdapSession, -) -> Result -where - DB: sqlx::Database, -{ + session: &mut LdapSession, +) -> Result { use futures_util::SinkExt; use std::convert::TryFrom; let server_op = match msg @@ -133,19 +122,16 @@ where Ok(true) } -pub fn build_ldap_server( +pub fn build_ldap_server( config: &Configuration, - sql_pool: sqlx::Pool, + sql_pool: AnyPool, server_builder: ServerBuilder, -) -> Result -where - DB: sqlx::Database, -{ +) -> Result { use futures_util::StreamExt; Ok( server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || { - let sql_pool = std::rc::Rc::new(sql_pool.clone()); + let sql_pool = sql_pool.clone(); pipeline_factory(fn_service(move |mut stream: TcpStream| { let sql_pool = sql_pool.clone(); async move { diff --git a/src/infra/tcp_server.rs b/src/infra/tcp_server.rs index d63c865..de8efd6 100644 --- a/src/infra/tcp_server.rs +++ b/src/infra/tcp_server.rs @@ -5,16 +5,14 @@ use actix_service::pipeline_factory; use anyhow::{Context, Result}; use futures_util::future::ok; use log::*; +use sqlx::any::AnyPool; use std::sync::Arc; -pub fn build_tcp_server( +pub fn build_tcp_server( config: &Configuration, - sql_pool: sqlx::Pool, + sql_pool: AnyPool, server_builder: ServerBuilder, -) -> Result -where - DB: sqlx::Database, -{ +) -> Result { use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use tokio::io::AsyncReadExt; diff --git a/src/main.rs b/src/main.rs index 5408c99..3eacb6c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,15 +2,14 @@ use crate::infra::configuration::Configuration; use anyhow::Result; use futures_util::TryFutureExt; use log::*; -use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; -use std::str::FromStr; +use sqlx::any::AnyPoolOptions; mod infra; async fn run_server(config: Configuration) -> Result<()> { - let sql_pool = SqlitePoolOptions::new() + let sql_pool = AnyPoolOptions::new() .max_connections(5) - .connect_with(SqliteConnectOptions::from_str("sqlite://users.db")?.create_if_missing(true)) + .connect("sqlite://users.db?mode=rwc") .await?; let server_builder = infra::ldap_server::build_ldap_server( &config,