diff --git a/Cargo.lock b/Cargo.lock index 2f3efda..173c595 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2138,6 +2138,7 @@ dependencies = [ "opaque-ke", "orion", "rand 0.8.5", + "reqwest", "rustls 0.20.6", "rustls-pemfile", "sea-query", @@ -2162,6 +2163,7 @@ dependencies = [ "tracing-log", "tracing-subscriber", "uuid", + "webpki-roots 0.21.1", ] [[package]] @@ -3018,9 +3020,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.11" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b75aa69a3f06bbcc66ede33af2af253c6f7a86b1ca0033f60c580a27074fbf92" +checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc" dependencies = [ "base64", "bytes", @@ -3034,9 +3036,9 @@ dependencies = [ "hyper-rustls", "ipnet", "js-sys", - "lazy_static", "log", "mime", + "once_cell", "percent-encoding", "pin-project-lite", "rustls 0.20.6", diff --git a/server/Cargo.toml b/server/Cargo.toml index 6e8af52..44a6a32 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -47,6 +47,7 @@ tracing-attributes = "^0.1.21" tracing-log = "*" rustls-pemfile = "1.0.0" serde_bytes = "0.11.7" +webpki-roots = "*" [dependencies.chrono] features = ["serde"] @@ -124,5 +125,10 @@ features = ["jpeg"] default-features = false version = "0.24" +[dependencies.reqwest] +version = "0.11" +default-features = false +features = ["rustls-tls-webpki-roots"] + [dev-dependencies] mockall = "0.9.1" diff --git a/server/src/infra/cli.rs b/server/src/infra/cli.rs index 202c3b8..031d67e 100644 --- a/server/src/infra/cli.rs +++ b/server/src/infra/cli.rs @@ -20,6 +20,9 @@ pub enum Command { /// Run the LDAP and GraphQL server. #[clap(name = "run")] Run(RunOpts), + /// Test whether the LDAP and GraphQL server are responsive. + #[clap(name = "healthcheck")] + HealthCheck(RunOpts), /// Send a test email. #[clap(name = "send_test_email")] SendTestEmail(TestEmailOpts), diff --git a/server/src/infra/healthcheck.rs b/server/src/infra/healthcheck.rs new file mode 100644 index 0000000..0fdd997 --- /dev/null +++ b/server/src/infra/healthcheck.rs @@ -0,0 +1,124 @@ +use crate::infra::configuration::LdapsOptions; +use anyhow::{anyhow, bail, ensure, Context, Result}; +use futures_util::SinkExt; +use ldap3_proto::{ + proto::{ + LdapDerefAliases, LdapFilter, LdapMsg, LdapOp, LdapSearchRequest, LdapSearchResultEntry, + LdapSearchScope, + }, + LdapCodec, +}; +use tokio::net::TcpStream; +use tokio_rustls::TlsConnector as RustlsTlsConnector; +use tokio_util::codec::{FramedRead, FramedWrite}; +use tracing::{debug, info, instrument}; + +async fn check_ldap_endpoint(stream: Stream) -> Result<()> +where + Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite, +{ + use tokio_stream::StreamExt; + let (r, w) = tokio::io::split(stream); + let mut requests = FramedRead::new(r, LdapCodec); + let mut resp = FramedWrite::new(w, LdapCodec); + + resp.send(LdapMsg { + msgid: 0, + op: LdapOp::SearchRequest(LdapSearchRequest { + base: "".to_string(), + scope: LdapSearchScope::Base, + aliases: LdapDerefAliases::Never, + sizelimit: 0, + timelimit: 0, + typesonly: false, + filter: LdapFilter::Present("objectClass".to_string()), + attrs: vec!["supportedExtension".to_string()], + }), + ctrl: vec![], + }) + .await?; + resp.flush().await?; + + let no_answer = || anyhow!("No answer from LDAP server"); + let invalid_answer = "Invalid answer from LDAP server"; + + let msg = requests + .next() + .await + .ok_or_else(no_answer)? + .context(invalid_answer)?; + debug!("Received message: {:?}", &msg); + match msg.op { + LdapOp::SearchResultEntry(LdapSearchResultEntry { dn, attributes }) => ensure!( + dn.is_empty() + && attributes + .into_iter() + .any(|a| a.atype == "objectClass" && a.vals == vec![b"top".to_vec()]), + invalid_answer + ), + _ => bail!(invalid_answer), + } + let msg = requests.next().await.ok_or_else(no_answer)??; + debug!("Received message: {:?}", &msg); + ensure!( + matches!(msg.op, LdapOp::SearchResultDone(_)), + invalid_answer + ); + info!("Success"); + Ok(()) +} + +#[instrument(skip_all, level = "info", err)] +pub async fn check_ldap(port: u16) -> Result<()> { + check_ldap_endpoint(TcpStream::connect(format!("localhost:{}", port)).await?).await +} + +fn get_root_certificates() -> rustls::RootCertStore { + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + root_store +} + +fn get_tls_connector() -> Result { + use rustls::ClientConfig; + let client_config = std::sync::Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(get_root_certificates()) + .with_no_client_auth(), + ); + Ok(client_config.into()) +} + +#[instrument(skip_all, level = "info", err)] +pub async fn check_ldaps(ldaps_options: &LdapsOptions) -> Result<()> { + if !ldaps_options.enabled { + return Ok(()); + }; + let tls_connector = get_tls_connector()?; + let url = format!("localhost:{}", ldaps_options.port); + check_ldap_endpoint( + tls_connector + .connect( + rustls::ServerName::try_from(url.as_str())?, + TcpStream::connect(&url).await?, + ) + .await?, + ) + .await +} + +#[instrument(skip_all, level = "info", err)] +pub async fn check_api(port: u16) -> Result<()> { + reqwest::get(format!("http://localhost:{}/health", port)) + .await? + .error_for_status()?; + info!("Success"); + Ok(()) +} diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index 34ea45b..9f788a5 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -633,6 +633,13 @@ impl LdapHandler Vec { + if request.base.is_empty() + && request.scope == LdapSearchScope::Base + && request.filter == LdapFilter::Present("objectClass".to_string()) + { + debug!("rootDSE request"); + return vec![root_dse_response(&self.base_dn_str), make_search_success()]; + } let user_info = match &self.user_info { None => { return vec![make_search_error( @@ -642,13 +649,6 @@ impl LdapHandler u, }; - if request.base.is_empty() - && request.scope == LdapSearchScope::Base - && request.filter == LdapFilter::Present("objectClass".to_string()) - { - debug!("rootDSE request"); - return vec![root_dse_response(&self.base_dn_str), make_search_success()]; - } let user_filter = if user_info.is_admin_or_readonly() { None } else { diff --git a/server/src/infra/logging.rs b/server/src/infra/logging.rs index 31c1056..f4a8dc2 100644 --- a/server/src/infra/logging.rs +++ b/server/src/infra/logging.rs @@ -37,9 +37,9 @@ impl RootSpanBuilder for CustomRootSpanBuilder { pub fn init(config: &Configuration) -> anyhow::Result<()> { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { EnvFilter::new(if config.verbose { - "sqlx=warn,debug" + "sqlx=warn,reqwest=warn,debug" } else { - "sqlx=warn,info" + "sqlx=warn,reqwest=warn,info" }) }); tracing_subscriber::registry() diff --git a/server/src/infra/mod.rs b/server/src/infra/mod.rs index 50cb9e0..f0b85f9 100644 --- a/server/src/infra/mod.rs +++ b/server/src/infra/mod.rs @@ -3,6 +3,7 @@ pub mod cli; pub mod configuration; pub mod db_cleaner; pub mod graphql; +pub mod healthcheck; pub mod jwt_sql_tables; pub mod ldap_handler; pub mod ldap_server; diff --git a/server/src/infra/tcp_server.rs b/server/src/infra/tcp_server.rs index 31d213e..76772ea 100644 --- a/server/src/infra/tcp_server.rs +++ b/server/src/infra/tcp_server.rs @@ -82,6 +82,7 @@ fn http_config( server_url, mail_options, })) + .route("/health", web::get().to(|| HttpResponse::Ok().finish())) .service(web::scope("/auth").configure(auth_service::configure_server::)) // API endpoint. .service( diff --git a/server/src/main.rs b/server/src/main.rs index 27eafae..d315799 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -2,6 +2,8 @@ #![forbid(non_ascii_idents)] #![allow(clippy::nonstandard_macro_braces)] +use std::time::Duration; + use crate::{ domain::{ handler::{BackendHandler, CreateUserRequest, GroupRequestFilter}, @@ -9,7 +11,7 @@ use crate::{ sql_opaque_handler::register_password, sql_tables::PoolOptions, }, - infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, mail}, + infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, healthcheck, mail}, }; use actix::Actor; use actix_server::ServerBuilder; @@ -132,11 +134,42 @@ fn send_test_email_command(opts: TestEmailOpts) -> Result<()> { mail::send_test_email(to, &config.smtp_options) } +fn run_healthcheck(opts: RunOpts) -> Result<()> { + debug!("CLI: {:#?}", &opts); + let config = infra::configuration::init(opts)?; + infra::logging::init(&config)?; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + use tokio::time::timeout; + let delay = Duration::from_millis(3000); + let (ldap, ldaps, api) = runtime.block_on(async { + tokio::join!( + timeout(delay, healthcheck::check_ldap(config.ldap_port)), + timeout(delay, healthcheck::check_ldaps(&config.ldaps_options)), + timeout(delay, healthcheck::check_api(config.http_port)), + ) + }); + + let mut failure = false; + [ldap, ldaps, api] + .into_iter() + .filter_map(Result::err) + .for_each(|e| { + failure = true; + error!("{:#}", e) + }); + std::process::exit(if failure { 1 } else { 0 }) +} + fn main() -> Result<()> { let cli_opts = infra::cli::init(); match cli_opts.command { Command::ExportGraphQLSchema(opts) => infra::graphql::api::export_schema(opts), Command::Run(opts) => run_server_command(opts), + Command::HealthCheck(opts) => run_healthcheck(opts), Command::SendTestEmail(opts) => send_test_email_command(opts), } }