server: implement haveibeenpwned endpoint

See #39.
This commit is contained in:
Valentin Tolmer 2022-05-19 15:34:01 +02:00
parent 86b2b5148d
commit 278fb1630d
15 changed files with 389 additions and 45 deletions

14
Cargo.lock generated
View File

@ -2404,7 +2404,7 @@ dependencies = [
"tracing-forest", "tracing-forest",
"tracing-log", "tracing-log",
"tracing-subscriber", "tracing-subscriber",
"uuid 0.8.2", "uuid 1.3.0",
"webpki-roots", "webpki-roots",
] ]
@ -2418,6 +2418,7 @@ dependencies = [
"gloo-console", "gloo-console",
"gloo-file", "gloo-file",
"gloo-net", "gloo-net",
"gloo-timers",
"graphql_client 0.10.0", "graphql_client 0.10.0",
"http", "http",
"image", "image",
@ -2427,6 +2428,7 @@ dependencies = [
"rand 0.8.5", "rand 0.8.5",
"serde", "serde",
"serde_json", "serde_json",
"sha1",
"url-escape", "url-escape",
"validator", "validator",
"validator_derive", "validator_derive",
@ -2530,12 +2532,6 @@ dependencies = [
"digest 0.10.6", "digest 0.10.6",
] ]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.5.0" version = "2.5.0"
@ -4404,9 +4400,6 @@ name = "uuid"
version = "0.8.2" version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
dependencies = [
"md5",
]
[[package]] [[package]]
name = "uuid" name = "uuid"
@ -4415,6 +4408,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79"
dependencies = [ dependencies = [
"getrandom 0.2.8", "getrandom 0.2.8",
"md-5",
] ]
[[package]] [[package]]

View File

@ -19,6 +19,7 @@ serde = "1"
serde_json = "1" serde_json = "1"
url-escape = "0.1.1" url-escape = "0.1.1"
validator = "=0.14" validator = "=0.14"
sha1 = "*"
validator_derive = "*" validator_derive = "*"
wasm-bindgen = "0.2" wasm-bindgen = "0.2"
wasm-bindgen-futures = "*" wasm-bindgen-futures = "*"
@ -27,6 +28,7 @@ yew-router = "0.16"
# Needed because of https://github.com/tkaitchuck/aHash/issues/95 # Needed because of https://github.com/tkaitchuck/aHash/issues/95
indexmap = "=1.6.2" indexmap = "=1.6.2"
gloo-timers = "0.2.6"
[dependencies.web-sys] [dependencies.web-sys]
version = "0.3" version = "0.3"

View File

@ -1,4 +1,5 @@
use crate::{ use crate::{
components::password_field::PasswordField,
components::router::{AppRoute, Link}, components::router::{AppRoute, Link},
infra::{ infra::{
api::HostService, api::HostService,
@ -254,14 +255,12 @@ impl Component for ChangePasswordForm {
{":"} {":"}
</label> </label>
<div class="col-sm-10"> <div class="col-sm-10">
<Field <PasswordField<FormModel>
form={&self.form} form={&self.form}
field_name="password" field_name="password"
input_type="password"
class="form-control" class="form-control"
class_invalid="is-invalid has-error" class_invalid="is-invalid has-error"
class_valid="has-success" class_valid="has-success"
autocomplete="new-password"
oninput={link.callback(|_| Msg::FormUpdate)} /> oninput={link.callback(|_| Msg::FormUpdate)} />
<div class="invalid-feedback"> <div class="invalid-feedback">
{&self.form.field_message("password")} {&self.form.field_message("password")}

View File

@ -149,9 +149,9 @@ impl Component for LoginForm {
let link = &ctx.link(); let link = &ctx.link();
if self.refreshing { if self.refreshing {
html! { html! {
<div> <div class="spinner-border" role="status">
<img src={"spinner.gif"} alt={"Loading"} /> <span class="sr-only">{"Loading..."}</span>
</div> </div>
} }
} else { } else {
html! { html! {

View File

@ -10,6 +10,7 @@ pub mod group_details;
pub mod group_table; pub mod group_table;
pub mod login; pub mod login;
pub mod logout; pub mod logout;
pub mod password_field;
pub mod remove_user_from_group; pub mod remove_user_from_group;
pub mod reset_password_step1; pub mod reset_password_step1;
pub mod reset_password_step2; pub mod reset_password_step2;

View File

@ -0,0 +1,152 @@
use crate::infra::{
api::{hash_password, HostService, PasswordHash, PasswordWasLeaked},
common_component::{CommonComponent, CommonComponentParts},
};
use anyhow::Result;
use gloo_timers::callback::Timeout;
use web_sys::{HtmlInputElement, InputEvent};
use yew::{html, Callback, Classes, Component, Context, Properties};
use yew_form::{Field, Form, Model};
pub enum PasswordFieldMsg {
OnInput(String),
OnInputIdle,
PasswordCheckResult(Result<(Option<PasswordWasLeaked>, PasswordHash)>),
}
#[derive(PartialEq)]
pub enum PasswordState {
// Whether the password was found in a leak.
Checked(PasswordWasLeaked),
// Server doesn't support checking passwords (TODO: move to config).
NotSupported,
// Requested a check, no response yet from the server.
Loading,
// User is still actively typing.
Typing,
}
pub struct PasswordField<FormModel: Model> {
common: CommonComponentParts<Self>,
timeout_task: Option<Timeout>,
password: String,
password_check_state: PasswordState,
_marker: std::marker::PhantomData<FormModel>,
}
impl<FormModel: Model> CommonComponent<PasswordField<FormModel>> for PasswordField<FormModel> {
fn handle_msg(
&mut self,
ctx: &Context<Self>,
msg: <Self as Component>::Message,
) -> anyhow::Result<bool> {
match msg {
PasswordFieldMsg::OnInput(password) => {
self.password = password;
if self.password_check_state != PasswordState::NotSupported {
self.password_check_state = PasswordState::Typing;
if self.password.len() >= 8 {
let link = ctx.link().clone();
self.timeout_task = Some(Timeout::new(500, move || {
link.send_message(PasswordFieldMsg::OnInputIdle)
}));
}
}
}
PasswordFieldMsg::PasswordCheckResult(result) => {
self.timeout_task = None;
// If there's an error from the backend, don't retry.
self.password_check_state = PasswordState::NotSupported;
if let (Some(check), hash) = result? {
if hash == hash_password(&self.password) {
self.password_check_state = PasswordState::Checked(check)
}
}
}
PasswordFieldMsg::OnInputIdle => {
self.timeout_task = None;
if self.password_check_state != PasswordState::NotSupported {
self.password_check_state = PasswordState::Loading;
self.common.call_backend(
ctx,
HostService::check_password_haveibeenpwned(hash_password(&self.password)),
PasswordFieldMsg::PasswordCheckResult,
);
}
}
}
Ok(true)
}
fn mut_common(&mut self) -> &mut CommonComponentParts<PasswordField<FormModel>> {
&mut self.common
}
}
#[derive(Properties, PartialEq, Clone)]
pub struct PasswordFieldProperties<FormModel: Model> {
pub field_name: String,
pub form: Form<FormModel>,
#[prop_or_else(|| { "form-control".into() })]
pub class: Classes,
#[prop_or_else(|| { "is-invalid".into() })]
pub class_invalid: Classes,
#[prop_or_else(|| { "is-valid".into() })]
pub class_valid: Classes,
#[prop_or_else(Callback::noop)]
pub oninput: Callback<String>,
}
impl<FormModel: Model> Component for PasswordField<FormModel> {
type Message = PasswordFieldMsg;
type Properties = PasswordFieldProperties<FormModel>;
fn create(_: &Context<Self>) -> Self {
Self {
common: CommonComponentParts::<Self>::create(),
timeout_task: None,
password: String::new(),
password_check_state: PasswordState::Typing,
_marker: std::marker::PhantomData,
}
}
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
CommonComponentParts::<Self>::update(self, ctx, msg)
}
fn view(&self, ctx: &Context<Self>) -> yew::Html {
let link = &ctx.link();
html! {
<div>
<Field<FormModel>
autocomplete={"new-password"}
input_type={"password"}
field_name={ctx.props().field_name.clone()}
form={ctx.props().form.clone()}
class={ctx.props().class.clone()}
class_invalid={ctx.props().class_invalid.clone()}
class_valid={ctx.props().class_valid.clone()}
oninput={link.callback(|e: InputEvent| {
use wasm_bindgen::JsCast;
let target = e.target().unwrap();
let input = target.dyn_into::<HtmlInputElement>().unwrap();
PasswordFieldMsg::OnInput(input.value())
})} />
{
match self.password_check_state {
PasswordState::Checked(PasswordWasLeaked(true)) => html! { <i class="bi bi-x"></i> },
PasswordState::Checked(PasswordWasLeaked(false)) => html! { <i class="bi bi-check"></i> },
PasswordState::NotSupported | PasswordState::Typing => html!{},
PasswordState::Loading =>
html! {
<div class="spinner-border spinner-border-sm" role="status">
<span class="sr-only">{"Loading..."}</span>
</div>
},
}
}
</div>
}
}
}

View File

@ -1,5 +1,8 @@
use crate::{ use crate::{
components::router::{AppRoute, Link}, components::{
password_field::PasswordField,
router::{AppRoute, Link},
},
infra::{ infra::{
api::HostService, api::HostService,
common_component::{CommonComponent, CommonComponentParts}, common_component::{CommonComponent, CommonComponentParts},
@ -176,14 +179,12 @@ impl Component for ResetPasswordStep2Form {
{"New password*:"} {"New password*:"}
</label> </label>
<div class="col-sm-10"> <div class="col-sm-10">
<Field <PasswordField<FormModel>
form={&self.form} form={&self.form}
field_name="password" field_name="password"
class="form-control" class="form-control"
class_invalid="is-invalid has-error" class_invalid="is-invalid has-error"
class_valid="has-success" class_valid="has-success"
autocomplete="new-password"
input_type="password"
oninput={link.callback(|_| Msg::FormUpdate)} /> oninput={link.callback(|_| Msg::FormUpdate)} />
<div class="invalid-feedback"> <div class="invalid-feedback">
{&self.form.field_message("password")} {&self.form.field_message("password")}

View File

@ -1,4 +1,4 @@
use super::cookies::set_cookie; use crate::infra::cookies::set_cookie;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use gloo_net::http::{Method, Request}; use gloo_net::http::{Method, Request};
use graphql_client::GraphQLQuery; use graphql_client::GraphQLQuery;
@ -74,6 +74,19 @@ fn set_cookies_from_jwt(response: login::ServerLoginResponse) -> Result<(String,
.context("Error setting cookie") .context("Error setting cookie")
} }
#[derive(PartialEq)]
pub struct PasswordHash(String);
#[derive(PartialEq)]
pub struct PasswordWasLeaked(pub bool);
pub fn hash_password(password: &str) -> PasswordHash {
use sha1::{Digest, Sha1};
let mut hasher = Sha1::new();
hasher.update(password);
PasswordHash(format!("{:X}", hasher.finalize()))
}
impl HostService { impl HostService {
pub async fn graphql_query<QueryType>( pub async fn graphql_query<QueryType>(
variables: QueryType::Variables, variables: QueryType::Variables,
@ -194,4 +207,35 @@ impl HostService {
!= http::StatusCode::NOT_FOUND, != http::StatusCode::NOT_FOUND,
) )
} }
pub async fn check_password_haveibeenpwned(
password_hash: PasswordHash,
) -> Result<(Option<PasswordWasLeaked>, PasswordHash)> {
use lldap_auth::password_reset::*;
let hash_prefix = &password_hash.0[0..5];
match call_server_json_with_error_message::<PasswordHashList, _>(
&format!("/auth/password/check/{}", hash_prefix),
NO_BODY,
"Could not validate token",
)
.await
{
Ok(r) => {
for PasswordHashCount { hash, count } in r.hashes {
if password_hash.0[5..] == hash && count != 0 {
return Ok((Some(PasswordWasLeaked(true)), password_hash));
}
}
Ok((Some(PasswordWasLeaked(false)), password_hash))
}
Err(e) => {
if e.to_string().contains("[501]:") {
// Unimplemented, no API key.
Ok((None, password_hash))
} else {
Err(e)
}
}
}
}
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 44 KiB

View File

@ -102,6 +102,17 @@ pub mod password_reset {
pub user_id: String, pub user_id: String,
pub token: String, pub token: String,
} }
#[derive(Serialize, Deserialize, Clone)]
pub struct PasswordHashCount {
pub hash: String,
pub count: u64,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct PasswordHashList {
pub hashes: Vec<PasswordHashCount>,
}
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]

View File

@ -59,7 +59,6 @@ version = "4"
[dependencies.figment] [dependencies.figment]
features = ["env", "toml"] features = ["env", "toml"]
version = "*" version = "*"
[dependencies.tracing-subscriber] [dependencies.tracing-subscriber]
version = "0.3" version = "0.3"
features = ["env-filter", "tracing-log"] features = ["env-filter", "tracing-log"]

View File

@ -1,21 +1,22 @@
use std::collections::{hash_map::DefaultHasher, HashSet}; use std::collections::{hash_map::DefaultHasher, HashSet};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::Poll;
use actix_web::{ use actix_web::{
cookie::{Cookie, SameSite}, cookie::{Cookie, SameSite},
dev::{Service, ServiceRequest, ServiceResponse, Transform}, dev::{Service, ServiceRequest, ServiceResponse, Transform},
error::{ErrorBadRequest, ErrorUnauthorized}, error::{ErrorBadRequest, ErrorUnauthorized},
web, HttpRequest, HttpResponse, web, FromRequest, HttpRequest, HttpResponse,
}; };
use actix_web_httpauth::extractors::bearer::BearerAuth; use actix_web_httpauth::extractors::bearer::BearerAuth;
use anyhow::Result; use anyhow::{bail, Context, Result};
use chrono::prelude::*; use chrono::prelude::*;
use futures::future::{ok, Ready}; use futures::future::{ok, Ready};
use futures_util::FutureExt; use futures_util::FutureExt;
use hmac::Hmac; use hmac::Hmac;
use jwt::{SignWithKey, VerifyWithKey}; use jwt::{SignWithKey, VerifyWithKey};
use secstr::SecUtf8;
use sha2::Sha512; use sha2::Sha512;
use time::ext::NumericalDuration; use time::ext::NumericalDuration;
use tracing::{debug, info, instrument, warn}; use tracing::{debug, info, instrument, warn};
@ -205,6 +206,24 @@ where
.unwrap_or_else(error_to_http_response) .unwrap_or_else(error_to_http_response)
} }
async fn check_password_reset_token<'a, Backend>(
backend_handler: &Backend,
token: &Option<&'a str>,
) -> TcpResult<Option<(&'a str, UserId)>>
where
Backend: TcpBackendHandler + 'static,
{
let token = match token {
None => return Ok(None),
Some(token) => token,
};
let user_id = backend_handler
.get_user_id_for_password_reset_token(token)
.await
.map_err(|_| TcpError::UnauthorizedError("Invalid or expired token".to_string()))?;
Ok(Some((token, user_id)))
}
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn get_password_reset_step2<Backend>( async fn get_password_reset_step2<Backend>(
data: web::Data<AppState<Backend>>, data: web::Data<AppState<Backend>>,
@ -213,22 +232,12 @@ async fn get_password_reset_step2<Backend>(
where where
Backend: TcpBackendHandler + BackendHandler + 'static, Backend: TcpBackendHandler + BackendHandler + 'static,
{ {
let token = request let tcp_handler = data.get_tcp_handler();
.match_info() let (token, user_id) =
.get("token") check_password_reset_token(tcp_handler, &request.match_info().get("token"))
.ok_or_else(|| TcpError::BadRequest("Missing reset token".to_owned()))?; .await?
let user_id = data .ok_or_else(|| TcpError::BadRequest("Missing token".to_string()))?;
.get_tcp_handler() let _ = tcp_handler.delete_password_reset_token(token).await;
.get_user_id_for_password_reset_token(token)
.await
.map_err(|e| {
debug!("Reset token error: {e:#}");
TcpError::NotFoundError("Wrong or expired reset token".to_owned())
})?;
let _ = data
.get_tcp_handler()
.delete_password_reset_token(token)
.await;
let groups = HashSet::new(); let groups = HashSet::new();
let token = create_jwt(&data.jwt_key, user_id.to_string(), groups); let token = create_jwt(&data.jwt_key, user_id.to_string(), groups);
Ok(HttpResponse::Ok() Ok(HttpResponse::Ok()
@ -403,6 +412,7 @@ where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + LoginHandler + 'static, Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + LoginHandler + 'static,
{ {
let user_id = UserId::new(&request.username); let user_id = UserId::new(&request.username);
debug!(?user_id);
let bind_request = BindRequest { let bind_request = BindRequest {
name: user_id.clone(), name: user_id.clone(),
password: request.password.clone(), password: request.password.clone(),
@ -449,6 +459,115 @@ where
.unwrap_or_else(error_to_http_response) .unwrap_or_else(error_to_http_response)
} }
// Parse the response from the HaveIBeenPwned API. Sample response:
//
// 0018A45C4D1DEF81644B54AB7F969B88D65:1
// 00D4F6E8FA6EECAD2A3AA415EEC418D38EC:2
// 011053FD0102E94D6AE2F8B83D76FAF94F6:13
fn parse_hash_list(response: &str) -> Result<password_reset::PasswordHashList> {
use password_reset::*;
let parse_line = |line: &str| -> Result<PasswordHashCount> {
let split = line.trim().split(':').collect::<Vec<_>>();
if let [hash, count] = &split[..] {
if hash.len() == 35 {
if let Ok(count) = str::parse::<u64>(count) {
return Ok(PasswordHashCount {
hash: hash.to_string(),
count,
});
}
}
}
bail!("Invalid password hash from API: {}", line)
};
Ok(PasswordHashList {
hashes: response
.split('\n')
.map(parse_line)
.collect::<Result<Vec<_>>>()?,
})
}
// TODO: Refactor that for testing.
async fn get_password_hash_list(
hash: &str,
api_key: &SecUtf8,
) -> Result<password_reset::PasswordHashList> {
use reqwest::*;
let client = Client::new();
let resp = client
.get(format!("https://api.pwnedpasswords.com/range/{}", hash))
.header(header::USER_AGENT, "LLDAP")
.header("hibp-api-key", api_key.unsecure())
.send()
.await
.context("Could not get response from HIPB")?
.text()
.await?;
parse_hash_list(&resp).context("Invalid HIPB response")
}
async fn check_password_pwned<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
payload: web::Payload,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
let has_reset_token = check_password_reset_token(
data.get_tcp_handler(),
&request
.headers()
.get("reset-token")
.map(|v| v.to_str().unwrap()),
)
.await?
.is_some();
let inner_payload = &mut payload.into_inner();
if !has_reset_token
&& BearerAuth::from_request(&request, inner_payload)
.await
.ok()
.and_then(|bearer| check_if_token_is_valid(&data, bearer.token()).ok())
.is_none()
{
return Err(TcpError::UnauthorizedError(
"No token or invalid token".to_string(),
));
}
if data.hipb_api_key.unsecure().is_empty() {
return Err(TcpError::NotImplemented("No HIPB API key".to_string()));
}
let hash = request
.match_info()
.get("hash")
.ok_or_else(|| TcpError::BadRequest("Missing hash".to_string()))?;
if hash.len() != 5 || !hash.chars().all(|c| c.is_ascii_hexdigit()) {
return Err(TcpError::BadRequest(format!(
"Bad request: invalid hash format \"{}\"",
hash
)));
}
get_password_hash_list(hash, &data.hipb_api_key)
.await
.map(|hashes| HttpResponse::Ok().json(hashes))
.map_err(|e| TcpError::InternalServerError(e.to_string()))
}
async fn check_password_pwned_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
payload: web::Payload,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
check_password_pwned(data, request, payload)
.await
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn opaque_register_start<Backend>( async fn opaque_register_start<Backend>(
request: actix_web::HttpRequest, request: actix_web::HttpRequest,
@ -565,7 +684,7 @@ where
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn core::future::Future<Output = Result<Self::Response, Self::Error>>>>; type Future = Pin<Box<dyn core::future::Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx) self.service.poll_ready(cx)
} }
@ -636,6 +755,11 @@ where
web::resource("/simple/login").route(web::post().to(simple_login_handler::<Backend>)), web::resource("/simple/login").route(web::post().to(simple_login_handler::<Backend>)),
) )
.service(web::resource("/refresh").route(web::get().to(get_refresh_handler::<Backend>))) .service(web::resource("/refresh").route(web::get().to(get_refresh_handler::<Backend>)))
.service(
web::resource("/password/check/{hash}")
.wrap(CookieToHeaderTranslatorFactory)
.route(web::get().to(check_password_pwned_handler::<Backend>)),
)
.service(web::resource("/logout").route(web::get().to(get_logout_handler::<Backend>))) .service(web::resource("/logout").route(web::get().to(get_logout_handler::<Backend>)))
.service( .service(
web::scope("/opaque/register") web::scope("/opaque/register")

View File

@ -81,6 +81,10 @@ pub struct RunOpts {
#[clap(short, long, env = "LLDAP_DATABASE_URL")] #[clap(short, long, env = "LLDAP_DATABASE_URL")]
pub database_url: Option<String>, pub database_url: Option<String>,
/// HaveIBeenPwned API key, to check passwords against leaks.
#[clap(long, env = "LLDAP_HIPB_API_KEY")]
pub hipb_api_key: Option<String>,
#[clap(flatten)] #[clap(flatten)]
pub smtp_opts: SmtpOpts, pub smtp_opts: SmtpOpts,

View File

@ -98,6 +98,8 @@ pub struct Configuration {
pub ldaps_options: LdapsOptions, pub ldaps_options: LdapsOptions,
#[builder(default = r#"String::from("http://localhost")"#)] #[builder(default = r#"String::from("http://localhost")"#)]
pub http_url: String, pub http_url: String,
#[builder(default = r#"SecUtf8::from("")"#)]
pub hipb_api_key: SecUtf8,
#[serde(skip)] #[serde(skip)]
#[builder(field(private), default = "None")] #[builder(field(private), default = "None")]
server_setup: Option<ServerSetup>, server_setup: Option<ServerSetup>,
@ -213,6 +215,10 @@ impl ConfigOverrider for RunOpts {
if let Some(database_url) = self.database_url.as_ref() { if let Some(database_url) = self.database_url.as_ref() {
config.database_url = database_url.to_string(); config.database_url = database_url.to_string();
} }
if let Some(api_key) = self.hipb_api_key.as_ref() {
config.hipb_api_key = SecUtf8::from(api_key.clone());
}
self.smtp_opts.override_config(config); self.smtp_opts.override_config(config);
self.ldaps_opts.override_config(config); self.ldaps_opts.override_config(config);
} }

View File

@ -19,6 +19,7 @@ use actix_service::map_config;
use actix_web::{dev::AppConfig, guard, web, App, HttpResponse, Responder}; use actix_web::{dev::AppConfig, guard, web, App, HttpResponse, Responder};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use hmac::Hmac; use hmac::Hmac;
use secstr::SecUtf8;
use sha2::Sha512; use sha2::Sha512;
use std::collections::HashSet; use std::collections::HashSet;
use std::path::PathBuf; use std::path::PathBuf;
@ -38,10 +39,10 @@ pub enum TcpError {
BadRequest(String), BadRequest(String),
#[error("Internal server error: `{0}`")] #[error("Internal server error: `{0}`")]
InternalServerError(String), InternalServerError(String),
#[error("Not found: `{0}`")]
NotFoundError(String),
#[error("Unauthorized: `{0}`")] #[error("Unauthorized: `{0}`")]
UnauthorizedError(String), UnauthorizedError(String),
#[error("Not implemented: `{0}`")]
NotImplemented(String),
} }
pub type TcpResult<T> = std::result::Result<T, TcpError>; pub type TcpResult<T> = std::result::Result<T, TcpError>;
@ -60,9 +61,9 @@ pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse {
| DomainError::EntityNotFound(_) => HttpResponse::BadRequest(), | DomainError::EntityNotFound(_) => HttpResponse::BadRequest(),
}, },
TcpError::BadRequest(_) => HttpResponse::BadRequest(), TcpError::BadRequest(_) => HttpResponse::BadRequest(),
TcpError::NotFoundError(_) => HttpResponse::NotFound(),
TcpError::InternalServerError(_) => HttpResponse::InternalServerError(), TcpError::InternalServerError(_) => HttpResponse::InternalServerError(),
TcpError::UnauthorizedError(_) => HttpResponse::Unauthorized(), TcpError::UnauthorizedError(_) => HttpResponse::Unauthorized(),
TcpError::NotImplemented(_) => HttpResponse::NotImplemented(),
} }
.body(error.to_string()) .body(error.to_string())
} }
@ -88,6 +89,7 @@ fn http_config<Backend>(
jwt_blacklist: HashSet<u64>, jwt_blacklist: HashSet<u64>,
server_url: String, server_url: String,
mail_options: MailOptions, mail_options: MailOptions,
hipb_api_key: SecUtf8,
) where ) where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Clone + 'static, Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Clone + 'static,
{ {
@ -98,6 +100,7 @@ fn http_config<Backend>(
jwt_blacklist: RwLock::new(jwt_blacklist), jwt_blacklist: RwLock::new(jwt_blacklist),
server_url, server_url,
mail_options, mail_options,
hipb_api_key,
})) }))
.route( .route(
"/health", "/health",
@ -133,6 +136,7 @@ pub(crate) struct AppState<Backend> {
pub jwt_blacklist: RwLock<HashSet<u64>>, pub jwt_blacklist: RwLock<HashSet<u64>>,
pub server_url: String, pub server_url: String,
pub mail_options: MailOptions, pub mail_options: MailOptions,
pub hipb_api_key: SecUtf8,
} }
impl<Backend: BackendHandler> AppState<Backend> { impl<Backend: BackendHandler> AppState<Backend> {
@ -173,6 +177,7 @@ where
let mail_options = config.smtp_options.clone(); let mail_options = config.smtp_options.clone();
let verbose = config.verbose; let verbose = config.verbose;
info!("Starting the API/web server on port {}", config.http_port); info!("Starting the API/web server on port {}", config.http_port);
let hipb_api_key = config.hipb_api_key.clone();
server_builder server_builder
.bind( .bind(
"http", "http",
@ -183,6 +188,7 @@ where
let jwt_blacklist = jwt_blacklist.clone(); let jwt_blacklist = jwt_blacklist.clone();
let server_url = server_url.clone(); let server_url = server_url.clone();
let mail_options = mail_options.clone(); let mail_options = mail_options.clone();
let hipb_api_key = hipb_api_key.clone();
HttpServiceBuilder::default() HttpServiceBuilder::default()
.finish(map_config( .finish(map_config(
App::new() App::new()
@ -198,6 +204,7 @@ where
jwt_blacklist, jwt_blacklist,
server_url, server_url,
mail_options, mail_options,
hipb_api_key,
) )
}), }),
|_| AppConfig::default(), |_| AppConfig::default(),