server: implement healthcheck

This commit is contained in:
Valentin Tolmer 2022-10-12 14:52:19 +02:00 committed by nitnelave
parent 01d4b6e1fc
commit 3aaf53442b
9 changed files with 183 additions and 13 deletions

8
Cargo.lock generated
View File

@ -2138,6 +2138,7 @@ dependencies = [
"opaque-ke", "opaque-ke",
"orion", "orion",
"rand 0.8.5", "rand 0.8.5",
"reqwest",
"rustls 0.20.6", "rustls 0.20.6",
"rustls-pemfile", "rustls-pemfile",
"sea-query", "sea-query",
@ -2162,6 +2163,7 @@ dependencies = [
"tracing-log", "tracing-log",
"tracing-subscriber", "tracing-subscriber",
"uuid", "uuid",
"webpki-roots 0.21.1",
] ]
[[package]] [[package]]
@ -3018,9 +3020,9 @@ dependencies = [
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.11.11" version = "0.11.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b75aa69a3f06bbcc66ede33af2af253c6f7a86b1ca0033f60c580a27074fbf92" checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc"
dependencies = [ dependencies = [
"base64", "base64",
"bytes", "bytes",
@ -3034,9 +3036,9 @@ dependencies = [
"hyper-rustls", "hyper-rustls",
"ipnet", "ipnet",
"js-sys", "js-sys",
"lazy_static",
"log", "log",
"mime", "mime",
"once_cell",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rustls 0.20.6", "rustls 0.20.6",

View File

@ -47,6 +47,7 @@ tracing-attributes = "^0.1.21"
tracing-log = "*" tracing-log = "*"
rustls-pemfile = "1.0.0" rustls-pemfile = "1.0.0"
serde_bytes = "0.11.7" serde_bytes = "0.11.7"
webpki-roots = "*"
[dependencies.chrono] [dependencies.chrono]
features = ["serde"] features = ["serde"]
@ -124,5 +125,10 @@ features = ["jpeg"]
default-features = false default-features = false
version = "0.24" version = "0.24"
[dependencies.reqwest]
version = "0.11"
default-features = false
features = ["rustls-tls-webpki-roots"]
[dev-dependencies] [dev-dependencies]
mockall = "0.9.1" mockall = "0.9.1"

View File

@ -20,6 +20,9 @@ pub enum Command {
/// Run the LDAP and GraphQL server. /// Run the LDAP and GraphQL server.
#[clap(name = "run")] #[clap(name = "run")]
Run(RunOpts), Run(RunOpts),
/// Test whether the LDAP and GraphQL server are responsive.
#[clap(name = "healthcheck")]
HealthCheck(RunOpts),
/// Send a test email. /// Send a test email.
#[clap(name = "send_test_email")] #[clap(name = "send_test_email")]
SendTestEmail(TestEmailOpts), SendTestEmail(TestEmailOpts),

View File

@ -0,0 +1,124 @@
use crate::infra::configuration::LdapsOptions;
use anyhow::{anyhow, bail, ensure, Context, Result};
use futures_util::SinkExt;
use ldap3_proto::{
proto::{
LdapDerefAliases, LdapFilter, LdapMsg, LdapOp, LdapSearchRequest, LdapSearchResultEntry,
LdapSearchScope,
},
LdapCodec,
};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector as RustlsTlsConnector;
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, info, instrument};
async fn check_ldap_endpoint<Stream>(stream: Stream) -> Result<()>
where
Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite,
{
use tokio_stream::StreamExt;
let (r, w) = tokio::io::split(stream);
let mut requests = FramedRead::new(r, LdapCodec);
let mut resp = FramedWrite::new(w, LdapCodec);
resp.send(LdapMsg {
msgid: 0,
op: LdapOp::SearchRequest(LdapSearchRequest {
base: "".to_string(),
scope: LdapSearchScope::Base,
aliases: LdapDerefAliases::Never,
sizelimit: 0,
timelimit: 0,
typesonly: false,
filter: LdapFilter::Present("objectClass".to_string()),
attrs: vec!["supportedExtension".to_string()],
}),
ctrl: vec![],
})
.await?;
resp.flush().await?;
let no_answer = || anyhow!("No answer from LDAP server");
let invalid_answer = "Invalid answer from LDAP server";
let msg = requests
.next()
.await
.ok_or_else(no_answer)?
.context(invalid_answer)?;
debug!("Received message: {:?}", &msg);
match msg.op {
LdapOp::SearchResultEntry(LdapSearchResultEntry { dn, attributes }) => ensure!(
dn.is_empty()
&& attributes
.into_iter()
.any(|a| a.atype == "objectClass" && a.vals == vec![b"top".to_vec()]),
invalid_answer
),
_ => bail!(invalid_answer),
}
let msg = requests.next().await.ok_or_else(no_answer)??;
debug!("Received message: {:?}", &msg);
ensure!(
matches!(msg.op, LdapOp::SearchResultDone(_)),
invalid_answer
);
info!("Success");
Ok(())
}
#[instrument(skip_all, level = "info", err)]
pub async fn check_ldap(port: u16) -> Result<()> {
check_ldap_endpoint(TcpStream::connect(format!("localhost:{}", port)).await?).await
}
fn get_root_certificates() -> rustls::RootCertStore {
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
root_store
}
fn get_tls_connector() -> Result<RustlsTlsConnector> {
use rustls::ClientConfig;
let client_config = std::sync::Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(get_root_certificates())
.with_no_client_auth(),
);
Ok(client_config.into())
}
#[instrument(skip_all, level = "info", err)]
pub async fn check_ldaps(ldaps_options: &LdapsOptions) -> Result<()> {
if !ldaps_options.enabled {
return Ok(());
};
let tls_connector = get_tls_connector()?;
let url = format!("localhost:{}", ldaps_options.port);
check_ldap_endpoint(
tls_connector
.connect(
rustls::ServerName::try_from(url.as_str())?,
TcpStream::connect(&url).await?,
)
.await?,
)
.await
}
#[instrument(skip_all, level = "info", err)]
pub async fn check_api(port: u16) -> Result<()> {
reqwest::get(format!("http://localhost:{}/health", port))
.await?
.error_for_status()?;
info!("Success");
Ok(())
}

View File

@ -633,6 +633,13 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
} }
pub async fn do_search_or_dse(&mut self, request: &LdapSearchRequest) -> Vec<LdapOp> { pub async fn do_search_or_dse(&mut self, request: &LdapSearchRequest) -> Vec<LdapOp> {
if request.base.is_empty()
&& request.scope == LdapSearchScope::Base
&& request.filter == LdapFilter::Present("objectClass".to_string())
{
debug!("rootDSE request");
return vec![root_dse_response(&self.base_dn_str), make_search_success()];
}
let user_info = match &self.user_info { let user_info = match &self.user_info {
None => { None => {
return vec![make_search_error( return vec![make_search_error(
@ -642,13 +649,6 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
} }
Some(u) => u, Some(u) => u,
}; };
if request.base.is_empty()
&& request.scope == LdapSearchScope::Base
&& request.filter == LdapFilter::Present("objectClass".to_string())
{
debug!("rootDSE request");
return vec![root_dse_response(&self.base_dn_str), make_search_success()];
}
let user_filter = if user_info.is_admin_or_readonly() { let user_filter = if user_info.is_admin_or_readonly() {
None None
} else { } else {

View File

@ -37,9 +37,9 @@ impl RootSpanBuilder for CustomRootSpanBuilder {
pub fn init(config: &Configuration) -> anyhow::Result<()> { pub fn init(config: &Configuration) -> anyhow::Result<()> {
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
EnvFilter::new(if config.verbose { EnvFilter::new(if config.verbose {
"sqlx=warn,debug" "sqlx=warn,reqwest=warn,debug"
} else { } else {
"sqlx=warn,info" "sqlx=warn,reqwest=warn,info"
}) })
}); });
tracing_subscriber::registry() tracing_subscriber::registry()

View File

@ -3,6 +3,7 @@ pub mod cli;
pub mod configuration; pub mod configuration;
pub mod db_cleaner; pub mod db_cleaner;
pub mod graphql; pub mod graphql;
pub mod healthcheck;
pub mod jwt_sql_tables; pub mod jwt_sql_tables;
pub mod ldap_handler; pub mod ldap_handler;
pub mod ldap_server; pub mod ldap_server;

View File

@ -82,6 +82,7 @@ fn http_config<Backend>(
server_url, server_url,
mail_options, mail_options,
})) }))
.route("/health", web::get().to(|| HttpResponse::Ok().finish()))
.service(web::scope("/auth").configure(auth_service::configure_server::<Backend>)) .service(web::scope("/auth").configure(auth_service::configure_server::<Backend>))
// API endpoint. // API endpoint.
.service( .service(

View File

@ -2,6 +2,8 @@
#![forbid(non_ascii_idents)] #![forbid(non_ascii_idents)]
#![allow(clippy::nonstandard_macro_braces)] #![allow(clippy::nonstandard_macro_braces)]
use std::time::Duration;
use crate::{ use crate::{
domain::{ domain::{
handler::{BackendHandler, CreateUserRequest, GroupRequestFilter}, handler::{BackendHandler, CreateUserRequest, GroupRequestFilter},
@ -9,7 +11,7 @@ use crate::{
sql_opaque_handler::register_password, sql_opaque_handler::register_password,
sql_tables::PoolOptions, sql_tables::PoolOptions,
}, },
infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, mail}, infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, healthcheck, mail},
}; };
use actix::Actor; use actix::Actor;
use actix_server::ServerBuilder; use actix_server::ServerBuilder;
@ -132,11 +134,42 @@ fn send_test_email_command(opts: TestEmailOpts) -> Result<()> {
mail::send_test_email(to, &config.smtp_options) mail::send_test_email(to, &config.smtp_options)
} }
fn run_healthcheck(opts: RunOpts) -> Result<()> {
debug!("CLI: {:#?}", &opts);
let config = infra::configuration::init(opts)?;
infra::logging::init(&config)?;
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
use tokio::time::timeout;
let delay = Duration::from_millis(3000);
let (ldap, ldaps, api) = runtime.block_on(async {
tokio::join!(
timeout(delay, healthcheck::check_ldap(config.ldap_port)),
timeout(delay, healthcheck::check_ldaps(&config.ldaps_options)),
timeout(delay, healthcheck::check_api(config.http_port)),
)
});
let mut failure = false;
[ldap, ldaps, api]
.into_iter()
.filter_map(Result::err)
.for_each(|e| {
failure = true;
error!("{:#}", e)
});
std::process::exit(if failure { 1 } else { 0 })
}
fn main() -> Result<()> { fn main() -> Result<()> {
let cli_opts = infra::cli::init(); let cli_opts = infra::cli::init();
match cli_opts.command { match cli_opts.command {
Command::ExportGraphQLSchema(opts) => infra::graphql::api::export_schema(opts), Command::ExportGraphQLSchema(opts) => infra::graphql::api::export_schema(opts),
Command::Run(opts) => run_server_command(opts), Command::Run(opts) => run_server_command(opts),
Command::HealthCheck(opts) => run_healthcheck(opts),
Command::SendTestEmail(opts) => send_test_email_command(opts), Command::SendTestEmail(opts) => send_test_email_command(opts),
} }
} }