Merge branch 'main' into patch-2

This commit is contained in:
nitnelave 2022-06-30 17:43:54 +02:00 committed by GitHub
commit db1a4bf429
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 774 additions and 445 deletions

50
Cargo.lock generated
View File

@ -1998,6 +1998,8 @@ dependencies = [
"tokio-util",
"tracing",
"tracing-actix-web",
"tracing-attributes",
"tracing-forest",
"tracing-log",
"tracing-subscriber",
]
@ -2092,6 +2094,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata",
]
[[package]]
name = "matches"
version = "0.1.9"
@ -2838,6 +2849,15 @@ dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.25"
@ -3546,18 +3566,18 @@ checksum = "b1141d4d61095b28419e22cb0bbf02755f5e54e0526f97f1e3d1d160e60885fb"
[[package]]
name = "thiserror"
version = "1.0.30"
version = "1.0.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
checksum = "bd829fe32373d27f76265620b5309d0340cb8550f523c1dda251d6298069069a"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.30"
version = "1.0.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
checksum = "0396bc89e626244658bef819e22d0cc459e795a5ebe878e6ec336d1674a8d79a"
dependencies = [
"proc-macro2",
"quote",
@ -3744,9 +3764,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.15"
version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42e6fa53307c8a17e4ccd4dc81cf5ec38db9209f59b222210375b54ee40d1e2"
checksum = "cc6b8ad3567499f98a1db7a752b07a7c8c7c7c34c332ec00effb2b0027974b7c"
dependencies = [
"proc-macro2",
"quote",
@ -3762,6 +3782,20 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "tracing-forest"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5db74d83f3fcda3ca1355dd91294098df02cc03d54e6cce81e40a18671c3fd7a"
dependencies = [
"chrono",
"smallvec",
"thiserror",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "tracing-futures"
version = "0.2.5"
@ -3790,9 +3824,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80a4ddde70311d8da398062ecf6fc2c309337de6b0f77d6c27aff8d53f6fca52"
dependencies = [
"ansi_term",
"lazy_static",
"matchers",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]

View File

@ -211,8 +211,8 @@ impl Component for ChangePasswordForm {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -92,8 +92,8 @@ impl Component for CreateGroupForm {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -185,8 +185,8 @@ impl Component for CreateUserForm {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -190,8 +190,8 @@ impl Component for GroupDetails {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -75,8 +75,8 @@ impl Component for GroupTable {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -141,8 +141,8 @@ impl Component for LoginForm {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -55,8 +55,8 @@ impl Component for LogoutButton {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -81,8 +81,8 @@ impl Component for RemoveUserFromGroupComponent {
)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -76,8 +76,8 @@ impl Component for ResetPasswordStep1Form {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -145,8 +145,8 @@ impl Component for ResetPasswordStep2Form {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -185,8 +185,8 @@ impl Component for UserDetails {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -96,8 +96,8 @@ impl Component for UserDetailsForm {
)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -81,8 +81,8 @@ impl Component for UserTable {
CommonComponentParts::<Self>::update(self, msg)
}
fn change(&mut self, _: Self::Properties) -> ShouldRender {
false
fn change(&mut self, props: Self::Properties) -> ShouldRender {
self.common.change(props)
}
fn view(&self) -> Html {

View File

@ -4,6 +4,7 @@ This was achieved by using the docker [jasonbean/guacamole](https://registry.hub
## To setup LDAP
### Using `guacamole.properties`
Open and edit your Apache Guacamole properties files
Located at `guacamole/guacamole.properties`
@ -22,9 +23,26 @@ ldap-search-bind-password: replacewithyoursecret
ldap-user-search-filter: (memberof=cn=lldap_apacheguac,ou=groups,dc=example,dc=com)
```
* Exclude `ldap-user-search-filter` if you do not want to limit users based on a group(s)
### Using docker variables
```
LDAP_HOSTNAME: localhost
LDAP_PORT: 3890
LDAP_ENCRYPTION_METHOD: none
LDAP_USER_BASE_DN: ou=people,dc=example,dc=com
LDAP_USERNAME_ATTRIBUTE: uid
LDAP_SEARCH_BIND_DN: uid=admin,ou=people,dc=example,dc=com
LDAP_SEARCH_BIND_PASSWORD: replacewithyoursecret
LDAP_USER_SEARCH_FILTER: (memberof=cn=lldap_guacamole,ou=groups,dc=example,dc=com)
```
### Notes
* You set it either through `guacamole.properties` or docker variables, not both.
* Exclude `ldap-user-search-filter/LDAP_USER_SEARCH_FILTER` if you do not want to limit users based on a group(s)
* it is a filter that permits users with `lldap_guacamole` sample group.
* Replace `dc=example,dc=com` with your LLDAP configured domain for all occurances
* Apache Guacamole does not lock you out when enabling LDAP. Your `static` IDs still are able to log in.
* setting `LDAP_ENCRYPTION_METHOD` is disabling SSL
## To enable LDAP
Restart your Apache Guacamole app for changes to take effect

View File

@ -41,10 +41,9 @@ tokio = { version = "1.13.1", features = ["full"] }
tokio-native-tls = "0.3"
tokio-util = "0.6.3"
tokio-stream = "*"
tracing = "*"
tracing-actix-web = "0.4.0-beta.7"
tracing-attributes = "^0.1.21"
tracing-log = "*"
tracing-subscriber = "0.3"
rand = { version = "0.8", features = ["small_rng", "getrandom"] }
juniper_actix = "0.4.0"
juniper = "0.15.6"
@ -53,6 +52,10 @@ itertools = "0.10.1"
[dependencies.opaque-ke]
version = "0.6"
[dependencies.tracing-subscriber]
version = "0.3"
features = ["env-filter", "tracing-log"]
[dependencies.lettre]
version = "0.10.0-rc.3"
features = [
@ -95,5 +98,12 @@ version = "*"
features = ["vendored"]
version = "*"
[dependencies.tracing-forest]
features = ["smallvec", "chrono", "tokio"]
version = "^0.1.4"
[dependencies.tracing]
version = "*"
[dev-dependencies]
mockall = "0.9.1"

View File

@ -6,6 +6,7 @@ use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr};
use sea_query_binder::SqlxBinder;
use sqlx::{query_as_with, query_with, FromRow, Row};
use std::collections::HashSet;
use tracing::{debug, instrument};
#[derive(Debug, Clone)]
pub struct SqlBackendHandler {
@ -110,11 +111,13 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> SimpleExpr {
#[async_trait]
impl BackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", ret, err)]
async fn list_users(
&self,
filters: Option<UserRequestFilter>,
get_groups: bool,
) -> Result<Vec<UserAndGroups>> {
debug!(?filters, get_groups);
let (query, values) = {
let mut query_builder = Query::select()
.column((Users::Table, Users::UserId))
@ -167,7 +170,8 @@ impl BackendHandler for SqlBackendHandler {
query_builder.build_sqlx(DbQueryBuilder {})
};
log::error!("query: {}", &query);
debug!(%query);
// For group_by.
use itertools::Itertools;
@ -199,11 +203,12 @@ impl BackendHandler for SqlBackendHandler {
},
});
}
Ok(users)
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
debug!(?filters);
let (query, values) = {
let mut query_builder = Query::select()
.column((Groups::Table, Groups::GroupId))
@ -233,6 +238,7 @@ impl BackendHandler for SqlBackendHandler {
query_builder.build_sqlx(DbQueryBuilder {})
};
debug!(%query);
// For group_by.
use itertools::Itertools;
@ -264,7 +270,9 @@ impl BackendHandler for SqlBackendHandler {
Ok(groups)
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
debug!(?user_id);
let (query, values) = Query::select()
.column(Users::UserId)
.column(Users::Email)
@ -276,19 +284,23 @@ impl BackendHandler for SqlBackendHandler {
.from(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_as_with::<_, User, _>(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?)
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName> {
debug!(?group_id);
let (query, values) = Query::select()
.column(Groups::GroupId)
.column(Groups::DisplayName)
.from(Groups::Table)
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(
query_as_with::<_, GroupIdAndName, _>(query.as_str(), values)
@ -297,12 +309,9 @@ impl BackendHandler for SqlBackendHandler {
)
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
if *user_id == self.config.ldap_user_dn {
let mut groups = HashSet::new();
groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
return Ok(groups);
}
debug!(?user_id);
let (query, values) = Query::select()
.column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
@ -314,6 +323,7 @@ impl BackendHandler for SqlBackendHandler {
)
.cond_where(Expr::col(Memberships::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
// Extract the group id from the row.
@ -335,7 +345,9 @@ impl BackendHandler for SqlBackendHandler {
.map_err(DomainError::DatabaseError)
}
#[instrument(skip_all, level = "debug", err)]
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
debug!(user_id = ?request.user_id);
let columns = vec![
Users::UserId,
Users::Email,
@ -356,13 +368,16 @@ impl BackendHandler for SqlBackendHandler {
chrono::Utc::now().naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn update_user(&self, request: UpdateUserRequest) -> Result<()> {
debug!(user_id = ?request.user_id);
let mut values = Vec::new();
if let Some(email) = request.email {
values.push((Users::Email, email.into()));
@ -384,13 +399,16 @@ impl BackendHandler for SqlBackendHandler {
.values(values)
.cond_where(Expr::col(Users::UserId).eq(request.user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> {
debug!(?request.group_id);
let mut values = Vec::new();
if let Some(display_name) = request.display_name {
values.push((Groups::DisplayName, display_name.into()));
@ -403,29 +421,36 @@ impl BackendHandler for SqlBackendHandler {
.values(values)
.cond_where(Expr::col(Groups::GroupId).eq(request.group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn delete_user(&self, user_id: &UserId) -> Result<()> {
let (delete_query, values) = Query::delete()
debug!(?user_id);
let (query, values) = Query::delete()
.from_table(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
query_with(delete_query.as_str(), values)
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
debug!(?group_name);
let (query, values) = Query::insert()
.into_table(Groups::Table)
.columns(vec![Groups::DisplayName])
.values_panic(vec![group_name.into()])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
@ -434,36 +459,45 @@ impl BackendHandler for SqlBackendHandler {
.from(Groups::Table)
.cond_where(Expr::col(Groups::DisplayName).eq(group_name))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let row = query_with(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?;
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
}
#[instrument(skip_all, level = "debug", err)]
async fn delete_group(&self, group_id: GroupId) -> Result<()> {
let (delete_query, values) = Query::delete()
debug!(?group_id);
let (query, values) = Query::delete()
.from_table(Groups::Table)
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
query_with(delete_query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
let (query, values) = Query::insert()
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.into()])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
debug!(?user_id, ?group_id);
let (query, values) = Query::insert()
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.into()])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
debug!(?user_id, ?group_id);
let (query, values) = Query::delete()
.from_table(Memberships::Table)
.cond_where(
@ -472,6 +506,7 @@ impl BackendHandler for SqlBackendHandler {
.add(Expr::col(Memberships::UserId).eq(user_id)),
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
@ -566,24 +601,6 @@ mod tests {
.collect::<Vec<_>>()
}
#[tokio::test]
async fn test_bind_admin() {
let sql_pool = get_in_memory_db().await;
let config = ConfigurationBuilder::default()
.ldap_user_dn(UserId::new("admin"))
.ldap_user_pass(secstr::SecUtf8::from("test"))
.build()
.unwrap();
let handler = SqlBackendHandler::new(config, sql_pool);
handler
.bind(BindRequest {
name: UserId::new("admin"),
password: "test".to_string(),
})
.await
.unwrap();
}
#[tokio::test]
async fn test_bind_user() {
let sql_pool = get_initialized_db().await;

View File

@ -7,14 +7,15 @@ use super::{
};
use async_trait::async_trait;
use lldap_auth::opaque;
use log::*;
use sea_query::{Expr, Iden, Query};
use sea_query_binder::SqlxBinder;
use secstr::SecUtf8;
use sqlx::Row;
use tracing::{debug, instrument};
type SqlOpaqueHandler = SqlBackendHandler;
#[instrument(skip_all, level = "debug", err)]
fn passwords_match(
password_file_bytes: &[u8],
clear_password: &str,
@ -48,6 +49,7 @@ impl SqlBackendHandler {
)?)
}
#[instrument(skip_all, level = "debug", err)]
async fn get_password_file_for_user(
&self,
username: &str,
@ -86,18 +88,8 @@ impl SqlBackendHandler {
#[async_trait]
impl LoginHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)]
async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if SecUtf8::from(request.password) == self.config.ldap_user_pass {
return Ok(());
} else {
debug!(r#"Invalid password for LDAP bind user"#);
return Err(DomainError::AuthenticationError(format!(
" for user '{}'",
request.name
)));
}
}
let (query, values) = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
@ -135,6 +127,7 @@ impl LoginHandler for SqlBackendHandler {
#[async_trait]
impl OpaqueHandler for SqlOpaqueHandler {
#[instrument(skip_all, level = "debug", err)]
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
@ -163,6 +156,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
})
}
#[instrument(skip_all, level = "debug", err)]
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
let secret_key = self.get_orion_secret_key()?;
let login::ServerData {
@ -181,6 +175,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
Ok(UserId::new(&username))
}
#[instrument(skip_all, level = "debug", err)]
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
@ -202,6 +197,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
})
}
#[instrument(skip_all, level = "debug", err)]
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
@ -230,6 +226,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
}
/// Convenience function to set a user's password.
#[instrument(skip_all, level = "debug", err)]
pub(crate) async fn register_password(
opaque_handler: &SqlOpaqueHandler,
username: &UserId,

View File

@ -13,14 +13,14 @@ use actix_web_httpauth::extractors::bearer::BearerAuth;
use anyhow::Result;
use chrono::prelude::*;
use futures::future::{ok, Ready};
use futures_util::{FutureExt, TryFutureExt};
use futures_util::FutureExt;
use hmac::Hmac;
use jwt::{SignWithKey, VerifyWithKey};
use log::*;
use sha2::Sha512;
use time::ext::NumericalDuration;
use tracing::{debug, instrument, warn};
use lldap_auth::{login, opaque, password_reset, registration, JWTClaims};
use lldap_auth::{login, password_reset, registration, JWTClaims};
use crate::{
domain::{
@ -30,7 +30,7 @@ use crate::{
},
infra::{
tcp_backend_handler::*,
tcp_server::{error_to_http_response, AppState},
tcp_server::{error_to_http_response, AppState, TcpError, TcpResult},
},
};
@ -51,9 +51,9 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<GroupIdAndName>)
jwt::Token::new(header, claims).sign_with_key(key).unwrap()
}
fn parse_refresh_token(token: &str) -> std::result::Result<(u64, UserId), HttpResponse> {
fn parse_refresh_token(token: &str) -> TcpResult<(u64, UserId)> {
match token.split_once('+') {
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
None => Err(DomainError::AuthenticationError("Invalid refresh token".to_string()).into()),
Some((token, u)) => {
let refresh_token_hash = {
let mut s = DefaultHasher::new();
@ -65,86 +65,92 @@ fn parse_refresh_token(token: &str) -> std::result::Result<(u64, UserId), HttpRe
}
}
fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, UserId), HttpResponse> {
fn get_refresh_token(request: HttpRequest) -> TcpResult<(u64, UserId)> {
match (
request.cookie("refresh_token"),
request.headers().get("refresh-token"),
) {
(Some(c), _) => parse_refresh_token(c.value()),
(_, Some(t)) => parse_refresh_token(t.to_str().unwrap()),
(None, None) => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
(None, None) => {
Err(DomainError::AuthenticationError("Missing refresh token".to_string()).into())
}
}
}
#[instrument(skip_all, level = "debug")]
async fn get_refresh<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let backend_handler = &data.backend_handler;
let jwt_key = &data.jwt_key;
let (refresh_token_hash, user) = match get_refresh_token(request) {
Ok(t) => t,
Err(http_response) => return http_response,
};
let res_found = data
let (refresh_token_hash, user) = get_refresh_token(request)?;
let found = data
.backend_handler
.check_token(refresh_token_hash, &user)
.await;
// Async closures are not supported yet.
match res_found {
Ok(found) => {
if found {
backend_handler.get_user_groups(&user).await
} else {
Err(DomainError::AuthenticationError(
"Invalid refresh token".to_string(),
))
}
}
Err(e) => Err(e),
.await?;
if !found {
return Err(TcpError::DomainError(DomainError::AuthenticationError(
"Invalid refresh token".to_string(),
)));
}
.map(|groups| create_jwt(jwt_key, user.to_string(), groups))
.map(|token| {
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.json(&login::ServerLoginResponse {
token: token.as_str().to_owned(),
refresh_token: None,
})
})
.unwrap_or_else(error_to_http_response)
Ok(backend_handler
.get_user_groups(&user)
.await
.map(|groups| create_jwt(jwt_key, user.to_string(), groups))
.map(|token| {
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.json(&login::ServerLoginResponse {
token: token.as_str().to_owned(),
refresh_token: None,
})
})?)
}
async fn get_password_reset_step1<Backend>(
async fn get_refresh_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
get_refresh(data, request)
.await
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")]
async fn get_password_reset_step1<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> TcpResult<()>
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let user_id = match request.match_info().get("user_id") {
None => return HttpResponse::BadRequest().body("Missing user ID"),
None => return Err(TcpError::BadRequest("Missing user ID".to_string())),
Some(id) => UserId::new(id),
};
let token = match data.backend_handler.start_password_reset(&user_id).await {
Err(e) => return HttpResponse::InternalServerError().body(e.to_string()),
Ok(None) => return HttpResponse::Ok().finish(),
Ok(Some(token)) => token,
let token = match data.backend_handler.start_password_reset(&user_id).await? {
None => return Ok(()),
Some(token) => token,
};
let user = match data.backend_handler.get_user_details(&user_id).await {
Err(e) => {
warn!("Error getting used details: {:#?}", e);
return HttpResponse::Ok().finish();
return Ok(());
}
Ok(u) => u,
};
@ -156,37 +162,50 @@ where
&data.mail_options,
) {
warn!("Error sending email: {:#?}", e);
return HttpResponse::InternalServerError().body(format!("Could not send email: {}", e));
return Err(TcpError::InternalServerError(format!(
"Could not send email: {}",
e
)));
}
HttpResponse::Ok().finish()
Ok(())
}
async fn get_password_reset_step2<Backend>(
async fn get_password_reset_step1_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let token = match request.match_info().get("token") {
None => return HttpResponse::BadRequest().body("Missing token"),
Some(token) => token,
};
let user_id = match data
get_password_reset_step1(data, request)
.await
.map(|()| HttpResponse::Ok().finish())
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")]
async fn get_password_reset_step2<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let token = request
.match_info()
.get("token")
.ok_or_else(|| TcpError::BadRequest("Missing reset token".to_string()))?;
let user_id = data
.backend_handler
.get_user_id_for_password_reset_token(token)
.await
{
Err(_) => return HttpResponse::Unauthorized().body("Invalid or expired token"),
Ok(user_id) => user_id,
};
.await?;
let _ = data
.backend_handler
.delete_password_reset_token(token)
.await;
let groups = HashSet::new();
let token = create_jwt(&data.jwt_key, user_id.to_string(), groups);
HttpResponse::Ok()
Ok(HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(5.minutes())
@ -199,43 +218,39 @@ where
.json(&password_reset::ServerPasswordResetResponse {
user_id: user_id.to_string(),
token: token.as_str().to_owned(),
})
}))
}
async fn get_logout<Backend>(
async fn get_password_reset_step2_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let (refresh_token_hash, user) = match get_refresh_token(request) {
Ok(t) => t,
Err(http_response) => return http_response,
};
if let Err(response) = data
.backend_handler
get_password_reset_step2(data, request)
.await
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")]
async fn get_logout<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let (refresh_token_hash, user) = get_refresh_token(request)?;
data.backend_handler
.delete_refresh_token(refresh_token_hash)
.map_err(error_to_http_response)
.await
{
return response;
};
match data
.backend_handler
.blacklist_jwts(&user)
.map_err(error_to_http_response)
.await
{
Ok(new_blacklisted_jwts) => {
let mut jwt_blacklist = data.jwt_blacklist.write().unwrap();
for jwt in new_blacklisted_jwts {
jwt_blacklist.insert(jwt);
}
}
Err(response) => return response,
};
HttpResponse::Ok()
.await?;
let new_blacklisted_jwts = data.backend_handler.blacklist_jwts(&user).await?;
let mut jwt_blacklist = data.jwt_blacklist.write().unwrap();
for jwt in new_blacklisted_jwts {
jwt_blacklist.insert(jwt);
}
Ok(HttpResponse::Ok()
.cookie(
Cookie::build("token", "")
.max_age(0.days())
@ -252,15 +267,28 @@ where
.same_site(SameSite::Strict)
.finish(),
)
.finish()
.finish())
}
pub(crate) fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
ApiResult::Right(error_to_http_response(error))
async fn get_logout_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
get_logout(data, request)
.await
.unwrap_or_else(error_to_http_response)
}
pub(crate) fn error_to_api_response<T, E: Into<TcpError>>(error: E) -> ApiResult<T> {
ApiResult::Right(error_to_http_response(error.into()))
}
pub type ApiResult<M> = actix_web::Either<web::Json<M>, HttpResponse>;
#[instrument(skip_all, level = "debug")]
async fn opaque_login_start<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginStartRequest>,
@ -275,196 +303,201 @@ where
.unwrap_or_else(error_to_api_response)
}
#[instrument(skip_all, level = "debug")]
async fn get_login_successful_response<Backend>(
data: &web::Data<AppState<Backend>>,
name: &UserId,
) -> HttpResponse
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler,
{
// The authentication was successful, we need to fetch the groups to create the JWT
// token.
data.backend_handler
.get_user_groups(name)
.and_then(|g| async { Ok((g, data.backend_handler.create_refresh_token(name).await?)) })
.await
.map(|(groups, (refresh_token, max_age))| {
let token = create_jwt(&data.jwt_key, name.to_string(), groups);
let refresh_token_plus_name = refresh_token + "+" + name.as_str();
let groups = data.backend_handler.get_user_groups(name).await?;
let (refresh_token, max_age) = data.backend_handler.create_refresh_token(name).await?;
let token = create_jwt(&data.jwt_key, name.to_string(), groups);
let refresh_token_plus_name = refresh_token + "+" + name.as_str();
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.cookie(
Cookie::build("refresh_token", refresh_token_plus_name.clone())
.max_age(max_age.num_days().days())
.path("/auth")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.json(&login::ServerLoginResponse {
token: token.as_str().to_owned(),
refresh_token: Some(refresh_token_plus_name),
})
})
.unwrap_or_else(error_to_http_response)
Ok(HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.cookie(
Cookie::build("refresh_token", refresh_token_plus_name.clone())
.max_age(max_age.num_days().days())
.path("/auth")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.json(&login::ServerLoginResponse {
token: token.as_str().to_owned(),
refresh_token: Some(refresh_token_plus_name),
}))
}
#[instrument(skip_all, level = "debug")]
async fn opaque_login_finish<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginFinishRequest>,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
let name = data
.backend_handler
.login_finish(request.into_inner())
.await?;
get_login_successful_response(&data, &name).await
}
async fn opaque_login_finish_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginFinishRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
let name = match data
.backend_handler
.login_finish(request.into_inner())
opaque_login_finish(data, request)
.await
{
Ok(n) => n,
Err(e) => return error_to_http_response(e),
};
get_login_successful_response(&data, &name).await
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")]
async fn simple_login<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientSimpleLoginRequest>,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + LoginHandler + 'static,
{
let user_id = UserId::new(&request.username);
let bind_request = BindRequest {
name: user_id.clone(),
password: request.password.clone(),
};
data.backend_handler.bind(bind_request).await?;
get_login_successful_response(&data, &user_id).await
}
async fn simple_login_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientSimpleLoginRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + LoginHandler + 'static,
{
let password = &request.password;
let mut rng = rand::rngs::OsRng;
let opaque::client::login::ClientLoginStartResult { state, message } =
match opaque::client::login::start_login(password, &mut rng) {
Ok(n) => n,
Err(e) => {
return HttpResponse::InternalServerError()
.body(format!("Internal Server Error: {:#?}", e))
}
};
let username = request.username.clone();
let start_request = login::ClientLoginStartRequest {
username: username.clone(),
login_start_request: message,
};
let start_response = match data.backend_handler.login_start(start_request).await {
Ok(n) => n,
Err(e) => return error_to_http_response(e),
};
let login_finish =
match opaque::client::login::finish_login(state, start_response.credential_response) {
Err(_) => {
return error_to_http_response(DomainError::AuthenticationError(String::from(
"Invalid username or password",
)))
}
Ok(l) => l,
};
let finish_request = login::ClientLoginFinishRequest {
server_data: start_response.server_data,
credential_finalization: login_finish.message,
};
let name = match data.backend_handler.login_finish(finish_request).await {
Ok(n) => n,
Err(e) => return error_to_http_response(e),
};
simple_login(data, request)
.await
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")]
async fn post_authorize<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<BindRequest>,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
{
let name = request.name.clone();
debug!(%name);
data.backend_handler.bind(request.into_inner()).await?;
get_login_successful_response(&data, &name).await
}
async fn post_authorize<Backend>(
async fn post_authorize_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<BindRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
{
let name = request.name.clone();
if let Err(e) = data.backend_handler.bind(request.into_inner()).await {
return error_to_http_response(e);
}
get_login_successful_response(&data, &name).await
post_authorize(data, request)
.await
.unwrap_or_else(error_to_http_response)
}
#[instrument(skip_all, level = "debug")]
async fn opaque_register_start<Backend>(
request: actix_web::HttpRequest,
mut payload: actix_web::web::Payload,
data: web::Data<AppState<Backend>>,
) -> TcpResult<registration::ServerRegistrationStartResponse>
where
Backend: OpaqueHandler + 'static,
{
use actix_web::FromRequest;
let validation_result = BearerAuth::from_request(&request, &mut payload.0)
.await
.ok()
.and_then(|bearer| check_if_token_is_valid(&data, bearer.token()).ok())
.ok_or_else(|| {
TcpError::UnauthorizedError("Not authorized to change the user's password".to_string())
})?;
let registration_start_request =
web::Json::<registration::ClientRegistrationStartRequest>::from_request(
&request,
&mut payload.0,
)
.await
.map_err(|e| TcpError::BadRequest(format!("{:#?}", e)))?
.into_inner();
let user_id = &registration_start_request.username;
if !validation_result.can_write(user_id) {
return Err(TcpError::UnauthorizedError(
"Not authorized to change the user's password".to_string(),
));
}
Ok(data
.backend_handler
.registration_start(registration_start_request)
.await?)
}
async fn opaque_register_start_handler<Backend>(
request: actix_web::HttpRequest,
payload: actix_web::web::Payload,
data: web::Data<AppState<Backend>>,
) -> ApiResult<registration::ServerRegistrationStartResponse>
where
Backend: OpaqueHandler + 'static,
{
use actix_web::FromRequest;
let validation_result = match BearerAuth::from_request(&request, &mut payload.0)
.await
.ok()
.and_then(|bearer| check_if_token_is_valid(&data, bearer.token()).ok())
{
Some(t) => t,
None => {
return ApiResult::Right(
HttpResponse::Unauthorized().body("Not authorized to change the user's password"),
)
}
};
let registration_start_request =
match web::Json::<registration::ClientRegistrationStartRequest>::from_request(
&request,
&mut payload.0,
)
.await
{
Ok(r) => r,
Err(e) => {
return ApiResult::Right(
HttpResponse::BadRequest().body(format!("Bad request: {:#?}", e)),
)
}
}
.into_inner();
let user_id = &registration_start_request.username;
if !validation_result.can_write(user_id) {
return ApiResult::Right(
HttpResponse::Unauthorized().body("Not authorized to change the user's password"),
);
}
data.backend_handler
.registration_start(registration_start_request)
opaque_register_start(request, payload, data)
.await
.map(|res| ApiResult::Left(web::Json(res)))
.unwrap_or_else(error_to_api_response)
}
#[instrument(skip_all, level = "debug")]
async fn opaque_register_finish<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<registration::ClientRegistrationFinishRequest>,
) -> TcpResult<HttpResponse>
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
data.backend_handler
.registration_finish(request.into_inner())
.await?;
Ok(HttpResponse::Ok().finish())
}
async fn opaque_register_finish_handler<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<registration::ClientRegistrationFinishRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
if let Err(e) = data
.backend_handler
.registration_finish(request.into_inner())
opaque_register_finish(data, request)
.await
{
return error_to_http_response(e);
}
HttpResponse::Ok().finish()
.unwrap_or_else(error_to_http_response)
}
pub struct CookieToHeaderTranslatorFactory;
@ -530,6 +563,7 @@ pub enum Permission {
Regular,
}
#[derive(Debug)]
pub struct ValidationResults {
pub user: String,
pub permission: Permission,
@ -567,6 +601,7 @@ impl ValidationResults {
}
}
#[instrument(skip_all, level = "debug", err, ret)]
pub(crate) fn check_if_token_is_valid<Backend>(
state: &AppState<Backend>,
token_str: &str,
@ -607,35 +642,38 @@ pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig)
where
Backend: TcpBackendHandler + LoginHandler + OpaqueHandler + BackendHandler + 'static,
{
cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
cfg.service(web::resource("").route(web::post().to(post_authorize_handler::<Backend>)))
.service(
web::resource("/opaque/login/start")
.route(web::post().to(opaque_login_start::<Backend>)),
)
.service(
web::resource("/opaque/login/finish")
.route(web::post().to(opaque_login_finish::<Backend>)),
.route(web::post().to(opaque_login_finish_handler::<Backend>)),
)
.service(web::resource("/simple/login").route(web::post().to(simple_login::<Backend>)))
.service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)))
.service(
web::resource("/simple/login").route(web::post().to(simple_login_handler::<Backend>)),
)
.service(web::resource("/refresh").route(web::get().to(get_refresh_handler::<Backend>)))
.service(
web::resource("/reset/step1/{user_id}")
.route(web::get().to(get_password_reset_step1::<Backend>)),
.route(web::get().to(get_password_reset_step1_handler::<Backend>)),
)
.service(
web::resource("/reset/step2/{token}")
.route(web::get().to(get_password_reset_step2::<Backend>)),
.route(web::get().to(get_password_reset_step2_handler::<Backend>)),
)
.service(web::resource("/logout").route(web::get().to(get_logout::<Backend>)))
.service(web::resource("/logout").route(web::get().to(get_logout_handler::<Backend>)))
.service(
web::scope("/opaque/register")
.wrap(CookieToHeaderTranslatorFactory)
.service(
web::resource("/start").route(web::post().to(opaque_register_start::<Backend>)),
web::resource("/start")
.route(web::post().to(opaque_register_start_handler::<Backend>)),
)
.service(
web::resource("/finish")
.route(web::post().to(opaque_register_finish::<Backend>)),
.route(web::post().to(opaque_register_finish_handler::<Backend>)),
),
);
}

View File

@ -7,6 +7,7 @@ use chrono::Local;
use cron::Schedule;
use sea_query::{Expr, Query};
use std::{str::FromStr, time::Duration};
use tracing::{debug, error, info, instrument};
// Define actor
pub struct Scheduler {
@ -19,7 +20,7 @@ impl Actor for Scheduler {
type Context = Context<Self>;
fn started(&mut self, context: &mut Context<Self>) {
log::info!("DB Cleanup Cron started");
info!("DB Cleanup Cron started");
context.run_later(self.duration_until_next(), move |this, ctx| {
this.schedule_task(ctx)
@ -27,7 +28,7 @@ impl Actor for Scheduler {
}
fn stopped(&mut self, _ctx: &mut Context<Self>) {
log::info!("DB Cleanup stopped");
info!("DB Cleanup stopped");
}
}
@ -38,7 +39,6 @@ impl Scheduler {
}
fn schedule_task(&self, ctx: &mut Context<Self>) {
log::info!("Cleaning DB");
let future = actix::fut::wrap_future::<_, Self>(Self::cleanup_db(self.sql_pool.clone()));
ctx.spawn(future);
@ -47,17 +47,16 @@ impl Scheduler {
});
}
#[instrument(skip_all)]
async fn cleanup_db(sql_pool: Pool) {
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await
{
log::error!("DB error while cleaning up JWT refresh tokens: {}", e);
info!("Cleaning DB");
let query = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {});
debug!(%query);
if let Err(e) = sqlx::query(&query).execute(&sql_pool).await {
error!("DB error while cleaning up JWT refresh tokens: {}", e);
};
if let Err(e) = sqlx::query(
&Query::delete()
@ -68,9 +67,9 @@ impl Scheduler {
.execute(&sql_pool)
.await
{
log::error!("DB error while cleaning up JWT storage: {}", e);
error!("DB error while cleaning up JWT storage: {}", e);
};
log::info!("DB cleaned!");
info!("DB cleaned!");
}
fn duration_until_next(&self) -> Duration {

View File

@ -2,6 +2,7 @@ use crate::domain::handler::{
BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, UserId,
};
use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject};
use tracing::{debug, debug_span, Instrument};
use super::api::Context;
@ -63,7 +64,12 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
context: &Context<Handler>,
user: CreateUserInput,
) -> FieldResult<super::query::User<Handler>> {
let span = debug_span!("[GraphQL mutation] create_user");
span.in_scope(|| {
debug!(?user.id);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized user creation".into());
}
let user_id = UserId::new(&user.id);
@ -76,10 +82,12 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
first_name: user.first_name,
last_name: user.last_name,
})
.instrument(span.clone())
.await?;
Ok(context
.handler
.get_user_details(&user_id)
.instrument(span)
.await
.map(Into::into)?)
}
@ -88,13 +96,19 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
context: &Context<Handler>,
name: String,
) -> FieldResult<super::query::Group<Handler>> {
let span = debug_span!("[GraphQL mutation] create_group");
span.in_scope(|| {
debug!(?name);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized group creation".into());
}
let group_id = context.handler.create_group(&name).await?;
Ok(context
.handler
.get_group_details(group_id)
.instrument(span)
.await
.map(Into::into)?)
}
@ -103,7 +117,12 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
context: &Context<Handler>,
user: UpdateUserInput,
) -> FieldResult<Success> {
let span = debug_span!("[GraphQL mutation] update_user");
span.in_scope(|| {
debug!(?user.id);
});
if !context.validation_result.can_write(&user.id) {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized user update".into());
}
context
@ -115,6 +134,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
first_name: user.first_name,
last_name: user.last_name,
})
.instrument(span)
.await?;
Ok(Success::new())
}
@ -123,10 +143,16 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
context: &Context<Handler>,
group: UpdateGroupInput,
) -> FieldResult<Success> {
let span = debug_span!("[GraphQL mutation] update_group");
span.in_scope(|| {
debug!(?group.id);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized group update".into());
}
if group.id == 1 {
span.in_scope(|| debug!("Cannot change admin group details"));
return Err("Cannot change admin group details".into());
}
context
@ -135,6 +161,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
group_id: GroupId(group.id),
display_name: group.display_name,
})
.instrument(span)
.await?;
Ok(Success::new())
}
@ -144,12 +171,18 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
user_id: String,
group_id: i32,
) -> FieldResult<Success> {
let span = debug_span!("[GraphQL mutation] add_user_to_group");
span.in_scope(|| {
debug!(?user_id, ?group_id);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized group membership modification".into());
}
context
.handler
.add_user_to_group(&UserId::new(&user_id), GroupId(group_id))
.instrument(span)
.await?;
Ok(Success::new())
}
@ -159,38 +192,65 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
user_id: String,
group_id: i32,
) -> FieldResult<Success> {
let span = debug_span!("[GraphQL mutation] remove_user_from_group");
span.in_scope(|| {
debug!(?user_id, ?group_id);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized group membership modification".into());
}
if context.validation_result.user == user_id && group_id == 1 {
span.in_scope(|| debug!("Cannot remove admin rights for current user"));
return Err("Cannot remove admin rights for current user".into());
}
context
.handler
.remove_user_from_group(&UserId::new(&user_id), GroupId(group_id))
.instrument(span)
.await?;
Ok(Success::new())
}
async fn delete_user(context: &Context<Handler>, user_id: String) -> FieldResult<Success> {
let span = debug_span!("[GraphQL mutation] delete_user");
span.in_scope(|| {
debug!(?user_id);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized user deletion".into());
}
if context.validation_result.user == user_id {
span.in_scope(|| debug!("Cannot delete current user"));
return Err("Cannot delete current user".into());
}
context.handler.delete_user(&UserId::new(&user_id)).await?;
context
.handler
.delete_user(&UserId::new(&user_id))
.instrument(span)
.await?;
Ok(Success::new())
}
async fn delete_group(context: &Context<Handler>, group_id: i32) -> FieldResult<Success> {
let span = debug_span!("[GraphQL mutation] delete_group");
span.in_scope(|| {
debug!(?group_id);
});
if !context.validation_result.is_admin() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized group deletion".into());
}
if group_id == 1 {
span.in_scope(|| debug!("Cannot delete admin group"));
return Err("Cannot delete admin group".into());
}
context.handler.delete_group(GroupId(group_id)).await?;
context
.handler
.delete_group(GroupId(group_id))
.instrument(span)
.await?;
Ok(Success::new())
}
}

View File

@ -1,6 +1,7 @@
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName, UserId};
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize};
use tracing::{debug, debug_span, Instrument};
type DomainRequestFilter = crate::domain::handler::UserRequestFilter;
type DomainUser = crate::domain::handler::User;
@ -108,12 +109,18 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
}
pub async fn user(context: &Context<Handler>, user_id: String) -> FieldResult<User<Handler>> {
let span = debug_span!("[GraphQL query] user");
span.in_scope(|| {
debug!(?user_id);
});
if !context.validation_result.can_read(&user_id) {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized access to user data".into());
}
Ok(context
.handler
.get_user_details(&UserId::new(&user_id))
.instrument(span)
.await
.map(Into::into)?)
}
@ -122,34 +129,49 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
context: &Context<Handler>,
#[graphql(name = "where")] filters: Option<RequestFilter>,
) -> FieldResult<Vec<User<Handler>>> {
let span = debug_span!("[GraphQL query] users");
span.in_scope(|| {
debug!(?filters);
});
if !context.validation_result.is_admin_or_readonly() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized access to user list".into());
}
Ok(context
.handler
.list_users(filters.map(TryInto::try_into).transpose()?, false)
.instrument(span)
.await
.map(|v| v.into_iter().map(Into::into).collect())?)
}
async fn groups(context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
let span = debug_span!("[GraphQL query] groups");
if !context.validation_result.is_admin_or_readonly() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized access to group list".into());
}
Ok(context
.handler
.list_groups(None)
.instrument(span)
.await
.map(|v| v.into_iter().map(Into::into).collect())?)
}
async fn group(context: &Context<Handler>, group_id: i32) -> FieldResult<Group<Handler>> {
let span = debug_span!("[GraphQL query] group");
span.in_scope(|| {
debug!(?group_id);
});
if !context.validation_result.is_admin_or_readonly() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized access to group data".into());
}
Ok(context
.handler
.get_group_details(GroupId(group_id))
.instrument(span)
.await
.map(Into::into)?)
}
@ -199,9 +221,14 @@ impl<Handler: BackendHandler + Sync> User<Handler> {
/// The groups to which this user belongs.
async fn groups(&self, context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
let span = debug_span!("[GraphQL query] user::groups");
span.in_scope(|| {
debug!(user_id = ?self.user.user_id);
});
Ok(context
.handler
.get_user_groups(&self.user.user_id)
.instrument(span)
.await
.map(|set| set.into_iter().map(Into::into).collect())?)
}
@ -244,7 +271,12 @@ impl<Handler: BackendHandler + Sync> Group<Handler> {
}
/// The groups to which this user belongs.
async fn users(&self, context: &Context<Handler>) -> FieldResult<Vec<User<Handler>>> {
let span = debug_span!("[GraphQL query] group::users");
span.in_scope(|| {
debug!(name = %self.display_name);
});
if !context.validation_result.is_admin_or_readonly() {
span.in_scope(|| debug!("Unauthorized"));
return Err("Unauthorized access to group data".into());
}
Ok(context
@ -253,6 +285,7 @@ impl<Handler: BackendHandler + Sync> Group<Handler> {
Some(DomainRequestFilter::MemberOfId(GroupId(self.group_id))),
false,
)
.instrument(span)
.await
.map(|v| v.into_iter().map(Into::into).collect())?)
}

View File

@ -15,7 +15,7 @@ use ldap3_server::proto::{
LdapFilter, LdapOp, LdapPartialAttribute, LdapPasswordModifyRequest, LdapResult,
LdapResultCode, LdapSearchRequest, LdapSearchResultEntry, LdapSearchScope,
};
use log::{debug, warn};
use tracing::{debug, instrument, warn};
#[derive(Debug, PartialEq, Eq, Clone)]
struct LdapDn(String);
@ -198,22 +198,26 @@ fn get_user_attribute(
}))
}
fn expand_attribute_wildcards(attributes: &[String], all_attribute_keys: &[&str]) -> Vec<String> {
let mut attributes_out = attributes.to_owned();
#[instrument(skip_all, level = "debug")]
fn expand_attribute_wildcards(
ldap_attributes: &[String],
all_attribute_keys: &[&str],
) -> Vec<String> {
let mut attributes_out = ldap_attributes.to_owned();
if attributes_out.iter().any(|x| x == "*") || attributes_out.is_empty() {
debug!(r#"Expanding * / empty attrs:"#);
// Remove occurrences of '*'
attributes_out.retain(|x| x != "*");
// Splice in all non-operational attributes
attributes_out.extend(all_attribute_keys.iter().map(|s| s.to_string()));
}
debug!(r#"Expanded: "{:?}""#, &attributes_out);
// Deduplicate, preserving order
attributes_out.into_iter().unique().collect_vec()
let resolved_attributes = attributes_out.into_iter().unique().collect_vec();
debug!(?ldap_attributes, ?resolved_attributes);
resolved_attributes
}
const ALL_USER_ATTRIBUTE_KEYS: &[&str] = &[
"objectclass",
"dn",
@ -470,8 +474,9 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
}
}
#[instrument(skip_all, level = "debug")]
pub async fn do_bind(&mut self, request: &LdapBindRequest) -> (LdapResultCode, String) {
debug!(r#"Received bind request for "{}""#, &request.dn);
debug!("DN: {}", &request.dn);
let user_id = match get_user_id_from_distinguished_name(
&request.dn.to_ascii_lowercase(),
&self.base_dn,
@ -507,6 +512,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
Permission::Regular
},
));
debug!("Success!");
(LdapResultCode::Success, "".to_string())
}
Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()),
@ -597,8 +603,8 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
}
}
pub async fn do_search(&mut self, request: &LdapSearchRequest) -> Vec<LdapOp> {
let user_filter = match &self.user_info {
pub async fn do_search_or_dse(&mut self, request: &LdapSearchRequest) -> Vec<LdapOp> {
let user_filter = match self.user_info.clone() {
Some((_, Permission::Admin)) | Some((_, Permission::Readonly)) => None,
Some((user_id, Permission::Regular)) => Some(user_id),
None => {
@ -612,10 +618,19 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
&& request.scope == LdapSearchScope::Base
&& request.filter == LdapFilter::Present("objectClass".to_string())
{
debug!("Received rootDSE request");
debug!("rootDSE request");
return vec![root_dse_response(&self.base_dn_str), make_search_success()];
}
debug!("Received search request: {:?}", &request);
self.do_search(request, user_filter).await
}
#[instrument(skip_all, level = "debug")]
pub async fn do_search(
&mut self,
request: &LdapSearchRequest,
user_filter: Option<UserId>,
) -> Vec<LdapOp> {
let user_filter = user_filter.as_ref();
let dn_parts = match parse_distinguished_name(&request.base.to_ascii_lowercase()) {
Ok(dn) => dn,
Err(_) => {
@ -626,6 +641,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
}
};
let scope = get_search_scope(&self.base_dn, &dn_parts);
debug!(?request.base, ?scope);
let get_user_list = || async {
self.get_user_list(&request.filter, &request.attrs, &request.base, &user_filter)
.await
@ -676,14 +692,16 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
results
}
#[instrument(skip_all, level = "debug")]
async fn get_user_list(
&self,
filter: &LdapFilter,
ldap_filter: &LdapFilter,
attributes: &[String],
base: &str,
user_filter: &Option<&UserId>,
) -> Vec<LdapOp> {
let filters = match self.convert_user_filter(filter) {
debug!(?ldap_filter);
let filters = match self.convert_user_filter(ldap_filter) {
Ok(f) => f,
Err(e) => {
return vec![make_search_error(
@ -692,19 +710,20 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
)]
}
};
let filters = match user_filter {
let parsed_filters = match user_filter {
None => filters,
Some(u) => {
UserRequestFilter::And(vec![filters, UserRequestFilter::UserId((*u).clone())])
}
};
debug!(?parsed_filters);
let expanded_attributes = expand_attribute_wildcards(attributes, ALL_USER_ATTRIBUTE_KEYS);
let need_groups = expanded_attributes
.iter()
.any(|s| s.to_ascii_lowercase() == "memberof");
let users = match self
.backend_handler
.list_users(Some(filters), need_groups)
.list_users(Some(parsed_filters), need_groups)
.await
{
Ok(users) => users,
@ -737,14 +756,16 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
})
}
#[instrument(skip_all, level = "debug")]
async fn get_groups_list(
&self,
filter: &LdapFilter,
ldap_filter: &LdapFilter,
attributes: &[String],
base: &str,
user_filter: &Option<&UserId>,
) -> Vec<LdapOp> {
let filter = match self.convert_group_filter(filter) {
debug!(?ldap_filter);
let filter = match self.convert_group_filter(ldap_filter) {
Ok(f) => f,
Err(e) => {
return vec![make_search_error(
@ -753,14 +774,14 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
)]
}
};
let filter = match user_filter {
let parsed_filters = match user_filter {
None => filter,
Some(u) => {
GroupRequestFilter::And(vec![filter, GroupRequestFilter::Member((*u).clone())])
}
};
let groups = match self.backend_handler.list_groups(Some(filter)).await {
debug!(?parsed_filters);
let groups = match self.backend_handler.list_groups(Some(parsed_filters)).await {
Ok(groups) => groups,
Err(e) => {
return vec![make_search_error(
@ -805,7 +826,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
saslcreds: None,
})]
}
LdapOp::SearchRequest(request) => self.do_search(&request).await,
LdapOp::SearchRequest(request) => self.do_search_or_dse(&request).await,
LdapOp::UnbindRequest => {
self.user_info = None;
// No need to notify on unbind (per rfc4511)
@ -873,7 +894,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
.collect::<Result<_>>()?,
)),
LdapFilter::Not(filter) => Ok(GroupRequestFilter::Not(Box::new(
self.convert_group_filter(&*filter)?,
self.convert_group_filter(filter)?,
))),
LdapFilter::Present(field) => {
if ALL_GROUP_ATTRIBUTE_KEYS.contains(&field.to_ascii_lowercase().as_str()) {
@ -903,7 +924,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
.collect::<Result<_>>()?,
)),
LdapFilter::Not(filter) => Ok(UserRequestFilter::Not(Box::new(
self.convert_user_filter(&*filter)?,
self.convert_user_filter(filter)?,
))),
LdapFilter::Equality(field, value) => {
let field = &field.to_ascii_lowercase();
@ -1176,7 +1197,7 @@ mod tests {
let request =
make_user_search_request::<String>(LdapFilter::And(vec![]), vec!["1.1".to_string()]);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "uid=test,ou=people,dc=example,dc=com".to_string(),
@ -1199,7 +1220,7 @@ mod tests {
let request =
make_user_search_request::<String>(LdapFilter::And(vec![]), vec!["1.1".to_string()]);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_success()],
);
}
@ -1226,7 +1247,7 @@ mod tests {
vec!["memberOf".to_string()],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "uid=bob,ou=people,dc=example,dc=com".to_string(),
@ -1266,7 +1287,7 @@ mod tests {
attrs: vec!["1.1".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_success()],
);
}
@ -1397,7 +1418,7 @@ mod tests {
],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "uid=bob_1,ou=people,dc=example,dc=com".to_string(),
@ -1515,7 +1536,7 @@ mod tests {
vec!["objectClass", "dn", "cn", "uniqueMember"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(),
@ -1612,7 +1633,7 @@ mod tests {
vec!["1.1"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(),
@ -1650,7 +1671,7 @@ mod tests {
vec!["cn"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "cn=group_1,ou=groups,dc=example,dc=com".to_string(),
@ -1687,7 +1708,7 @@ mod tests {
attrs: vec!["1.1".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_success()],
);
}
@ -1717,7 +1738,7 @@ mod tests {
vec!["cn"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_error(
LdapResultCode::Other,
r#"Error while listing groups "ou=groups,dc=example,dc=com": Internal error: `Error getting groups`"#.to_string()
@ -1737,7 +1758,7 @@ mod tests {
vec!["cn"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_error(
LdapResultCode::UnwillingToPerform,
r#"Unsupported group filter: Unsupported group filter: Substring("whatever", LdapSubstringFilter { initial: None, any: [], final_: None })"#
@ -1785,7 +1806,7 @@ mod tests {
vec!["objectClass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_success()]
);
}
@ -1809,7 +1830,7 @@ mod tests {
vec!["objectClass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_success()]
);
let request = make_user_search_request(
@ -1817,7 +1838,7 @@ mod tests {
vec!["objectClass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_error(
LdapResultCode::UnwillingToPerform,
"Unsupported user filter: while parsing a group ID: Missing DN value".to_string()
@ -1831,7 +1852,7 @@ mod tests {
vec!["objectClass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_error(
LdapResultCode::UnwillingToPerform,
"Unsupported user filter: Unexpected group DN format. Got \"cn=mygroup,dc=example,dc=com\", expected: \"cn=groupname,ou=groups,dc=example,dc=com\"".to_string()
@ -1869,7 +1890,7 @@ mod tests {
vec!["objectclass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "uid=bob_1,ou=people,dc=example,dc=com".to_string(),
@ -1921,7 +1942,7 @@ mod tests {
vec!["objectClass", "dn", "cn"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
LdapOp::SearchResultEntry(LdapSearchResultEntry {
dn: "uid=bob_1,ou=people,dc=example,dc=com".to_string(),
@ -2086,12 +2107,18 @@ mod tests {
make_search_success(),
];
assert_eq!(ldap_handler.do_search(&request).await, expected_result);
assert_eq!(
ldap_handler.do_search_or_dse(&request).await,
expected_result
);
let request2 =
make_search_request("dc=example,dc=com", LdapFilter::And(vec![]), vec!["*", "*"]);
assert_eq!(ldap_handler.do_search(&request2).await, expected_result);
assert_eq!(
ldap_handler.do_search_or_dse(&request2).await,
expected_result
);
let request3 = make_search_request(
"dc=example,dc=com",
@ -2099,12 +2126,18 @@ mod tests {
vec!["*", "+", "+"],
);
assert_eq!(ldap_handler.do_search(&request3).await, expected_result);
assert_eq!(
ldap_handler.do_search_or_dse(&request3).await,
expected_result
);
let request4 =
make_search_request("dc=example,dc=com", LdapFilter::And(vec![]), vec![""; 0]);
assert_eq!(ldap_handler.do_search(&request4).await, expected_result);
assert_eq!(
ldap_handler.do_search_or_dse(&request4).await,
expected_result
);
let request5 = make_search_request(
"dc=example,dc=com",
@ -2112,7 +2145,10 @@ mod tests {
vec!["objectclass", "dn", "uid", "*"],
);
assert_eq!(ldap_handler.do_search(&request5).await, expected_result);
assert_eq!(
ldap_handler.do_search_or_dse(&request5).await,
expected_result
);
}
#[tokio::test]
@ -2124,7 +2160,7 @@ mod tests {
vec!["objectClass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_success()]
);
}
@ -2140,7 +2176,7 @@ mod tests {
vec!["objectClass"],
);
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![make_search_error(
LdapResultCode::UnwillingToPerform,
"Unsupported user filter: Unsupported user filter: Substring(\"uid\", LdapSubstringFilter { initial: None, any: [], final_: None })".to_string()
@ -2272,7 +2308,7 @@ mod tests {
attrs: vec!["supportedExtension".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
ldap_handler.do_search_or_dse(&request).await,
vec![
root_dse_response("dc=example,dc=com"),
make_search_success()

View File

@ -10,12 +10,13 @@ use actix_server::ServerBuilder;
use actix_service::{fn_service, ServiceFactoryExt};
use anyhow::{Context, Result};
use ldap3_server::{proto::LdapMsg, LdapCodec};
use log::*;
use native_tls::{Identity, TlsAcceptor};
use tokio_native_tls::TlsAcceptor as NativeTlsAcceptor;
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, error, info, instrument};
async fn handle_incoming_message<Backend, Writer>(
#[instrument(skip_all, level = "info", name = "LDAP request")]
async fn handle_ldap_message<Backend, Writer>(
msg: Result<LdapMsg, std::io::Error>,
resp: &mut Writer,
session: &mut LdapHandler<Backend>,
@ -27,18 +28,18 @@ where
{
use futures_util::SinkExt;
let msg = msg.context("while receiving LDAP op")?;
debug!("Received LDAP message: {:?}", &msg);
debug!(?msg);
match session.handle_ldap_message(msg.op).await {
None => return Ok(false),
Some(result) => {
if result.is_empty() {
debug!("No response");
}
for result_op in result.into_iter() {
debug!("Replying with LDAP op: {:?}", &result_op);
for response in result.into_iter() {
debug!(?response);
resp.send(LdapMsg {
msgid: msg.msgid,
op: result_op,
op: response,
ctrl: vec![],
})
.await
@ -66,6 +67,7 @@ fn get_file_as_byte_vec(filename: &str) -> Result<Vec<u8>> {
.context(format!("while reading file {}", filename))
}
#[instrument(skip_all, level = "info", name = "LDAP session")]
async fn handle_ldap_stream<Stream, Backend>(
stream: Stream,
backend_handler: Backend,
@ -91,7 +93,7 @@ where
);
while let Some(msg) = requests.next().await {
if !handle_incoming_message(msg, &mut resp, &mut session)
if !handle_ldap_message(msg, &mut resp, &mut session)
.await
.context("while handling incoming messages")?
{
@ -145,6 +147,7 @@ where
.map_err(|err: anyhow::Error| error!("[LDAP] Service Error: {:#}", err))
};
info!("Starting the LDAP server on port {}", config.ldap_port);
let server_builder = server_builder
.bind("ldap", ("0.0.0.0", config.ldap_port), binder)
.with_context(|| format!("while binding to the port {}", config.ldap_port));
@ -176,6 +179,10 @@ where
.map_err(|err: anyhow::Error| error!("[LDAPS] Service Error: {:#}", err))
};
info!(
"Starting the LDAPS server on port {}",
config.ldaps_options.port
);
server_builder.and_then(|s| {
s.bind("ldaps", ("0.0.0.0", config.ldaps_options.port), tls_binder)
.with_context(|| format!("while binding to the port {}", config.ldaps_options.port))

View File

@ -1,30 +1,50 @@
use crate::infra::configuration::Configuration;
use tracing_subscriber::prelude::*;
use actix_web::{
dev::{ServiceRequest, ServiceResponse},
Error,
};
use tracing::{error, info, Span};
use tracing_actix_web::{root_span, RootSpanBuilder};
use tracing_subscriber::{filter::EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
/// We will define a custom root span builder to capture additional fields, specific
/// to our application, on top of the ones provided by `DefaultRootSpanBuilder` out of the box.
pub struct CustomRootSpanBuilder;
impl RootSpanBuilder for CustomRootSpanBuilder {
fn on_request_start(request: &ServiceRequest) -> Span {
let span = root_span!(request);
span.in_scope(|| {
info!(uri = %request.uri());
});
span
}
fn on_request_end<B>(_: Span, outcome: &Result<ServiceResponse<B>, Error>) {
match &outcome {
Ok(response) => {
if let Some(error) = response.response().error() {
error!(?error);
} else {
info!(status_code = &response.response().status().as_u16());
}
}
Err(error) => error!(?error),
};
}
}
pub fn init(config: &Configuration) -> anyhow::Result<()> {
let max_log_level = log_level_from_config(config);
let sqlx_max_log_level = sqlx_log_level_from_config(config);
let filter = tracing_subscriber::filter::Targets::new()
.with_target("lldap", max_log_level)
.with_target("sqlx", sqlx_max_log_level);
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
EnvFilter::new(if config.verbose {
"sqlx=warn,debug"
} else {
"sqlx=warn,info"
})
});
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer().with_filter(filter))
.with(env_filter)
.with(tracing_forest::ForestLayer::default())
.init();
Ok(())
}
fn log_level_from_config(config: &Configuration) -> tracing::Level {
if config.verbose {
tracing::Level::DEBUG
} else {
tracing::Level::INFO
}
}
fn sqlx_log_level_from_config(config: &Configuration) -> tracing::Level {
if config.verbose {
tracing::Level::INFO
} else {
tracing::Level::WARN
}
}

View File

@ -4,7 +4,7 @@ use lettre::{
message::Mailbox, transport::smtp::authentication::Credentials, Message, SmtpTransport,
Transport,
};
use log::debug;
use tracing::debug;
fn send_email(to: Mailbox, subject: &str, body: String, options: &MailOptions) -> Result<()> {
let from = options

View File

@ -6,6 +6,7 @@ use sea_query::{Expr, Iden, Query, SimpleExpr};
use sea_query_binder::SqlxBinder;
use sqlx::{query_as_with, query_with, Row};
use std::collections::HashSet;
use tracing::{debug, instrument};
fn gen_random_string(len: usize) -> String {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
@ -19,12 +20,14 @@ fn gen_random_string(len: usize) -> String {
#[async_trait]
impl TcpBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug")]
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
let (query, values) = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
@ -35,7 +38,9 @@ impl TcpBackendHandler for SqlBackendHandler {
.map_err(|e| anyhow::anyhow!(e))
}
#[instrument(skip_all, level = "debug")]
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
debug!(?user);
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
@ -59,23 +64,30 @@ impl TcpBackendHandler for SqlBackendHandler {
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok((refresh_token, duration))
}
#[instrument(skip_all, level = "debug")]
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
debug!(?user);
let (query, values) = Query::select()
.expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_with(&query, values)
.fetch_optional(&self.sql_pool)
.await?
.is_some())
}
#[instrument(skip_all, level = "debug")]
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
debug!(?user);
use sqlx::Result;
let (query, values) = Query::select()
.column(JwtStorage::JwtHash)
@ -95,31 +107,39 @@ impl TcpBackendHandler for SqlBackendHandler {
.values(vec![(JwtStorage::Blacklisted, true.into())])
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(result?)
}
#[instrument(skip_all, level = "debug")]
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
let (query, values) = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(())
}
#[instrument(skip_all, level = "debug")]
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
debug!(?user);
let (query, values) = Query::select()
.column(Users::UserId)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
// Check that the user exists.
if query_with(&query, values)
.fetch_one(&self.sql_pool)
.await
.is_err()
{
debug!("User not found");
return Ok(None);
}
@ -139,10 +159,12 @@ impl TcpBackendHandler for SqlBackendHandler {
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(Some(token))
}
#[instrument(skip_all, level = "debug", ret)]
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
let (query, values) = Query::select()
.column(PasswordResetTokens::UserId)
@ -152,6 +174,7 @@ impl TcpBackendHandler for SqlBackendHandler {
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let (user_id,) = query_as_with(&query, values)
.fetch_one(&self.sql_pool)
@ -159,11 +182,13 @@ impl TcpBackendHandler for SqlBackendHandler {
Ok(user_id)
}
#[instrument(skip_all, level = "debug")]
async fn delete_password_reset_token(&self, token: &str) -> Result<()> {
let (query, values) = Query::delete()
.from_table(PasswordResetTokens::Table)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(())
}

View File

@ -7,6 +7,7 @@ use crate::{
infra::{
auth_service,
configuration::{Configuration, MailOptions},
logging::CustomRootSpanBuilder,
tcp_backend_handler::*,
},
};
@ -21,6 +22,7 @@ use sha2::Sha512;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::RwLock;
use tracing::info;
async fn index() -> actix_web::Result<NamedFile> {
let mut path = PathBuf::new();
@ -29,17 +31,36 @@ async fn index() -> actix_web::Result<NamedFile> {
Ok(NamedFile::open(path)?)
}
pub(crate) fn error_to_http_response(error: DomainError) -> HttpResponse {
#[derive(thiserror::Error, Debug)]
pub enum TcpError {
#[error("`{0}`")]
DomainError(#[from] DomainError),
#[error("Bad request: `{0}`")]
BadRequest(String),
#[error("Internal server error: `{0}`")]
InternalServerError(String),
#[error("Unauthorized: `{0}`")]
UnauthorizedError(String),
}
pub type TcpResult<T> = std::result::Result<T, TcpError>;
pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse {
match error {
DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => {
HttpResponse::Unauthorized()
}
DomainError::DatabaseError(_)
| DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
HttpResponse::BadRequest()
}
TcpError::DomainError(ref de) => match de {
DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => {
HttpResponse::Unauthorized()
}
DomainError::DatabaseError(_)
| DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
HttpResponse::BadRequest()
}
},
TcpError::BadRequest(_) => HttpResponse::BadRequest(),
TcpError::InternalServerError(_) => HttpResponse::InternalServerError(),
TcpError::UnauthorizedError(_) => HttpResponse::Unauthorized(),
}
.body(error.to_string())
}
@ -105,6 +126,7 @@ where
.context("while getting the jwt blacklist")?;
let server_url = config.http_url.clone();
let mail_options = config.smtp_options.clone();
info!("Starting the API/web server on port {}", config.http_port);
server_builder
.bind("http", ("0.0.0.0", config.http_port), move || {
let backend_handler = backend_handler.clone();
@ -114,16 +136,18 @@ where
let mail_options = mail_options.clone();
HttpServiceBuilder::new()
.finish(map_config(
App::new().configure(move |cfg| {
http_config(
cfg,
backend_handler,
jwt_secret,
jwt_blacklist,
server_url,
mail_options,
)
}),
App::new()
.wrap(tracing_actix_web::TracingLogger::<CustomRootSpanBuilder>::new())
.configure(move |cfg| {
http_config(
cfg,
backend_handler,
jwt_secret,
jwt_blacklist,
server_url,
mail_options,
)
}),
|_| AppConfig::default(),
))
.tcp()

View File

@ -12,9 +12,10 @@ use crate::{
infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, mail},
};
use actix::Actor;
use actix_server::ServerBuilder;
use anyhow::{anyhow, Context, Result};
use futures_util::TryFutureExt;
use log::*;
use tracing::*;
mod domain;
mod infra;
@ -45,7 +46,10 @@ async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration)
.context("Error adding admin user to group")
}
async fn run_server(config: Configuration) -> Result<()> {
#[instrument(skip_all)]
async fn set_up_server(config: Configuration) -> Result<ServerBuilder> {
info!("Starting LLDAP....");
let sql_pool = PoolOptions::new()
.max_connections(5)
.connect(&config.database_url)
@ -89,7 +93,12 @@ async fn run_server(config: Configuration) -> Result<()> {
// Run every hour.
let scheduler = Scheduler::new("0 0 * * * * *", sql_pool);
scheduler.start();
server_builder
Ok(server_builder)
}
async fn run_server(config: Configuration) -> Result<()> {
set_up_server(config)
.await?
.workers(1)
.run()
.await
@ -103,8 +112,6 @@ fn run_server_command(opts: RunOpts) -> Result<()> {
let config = infra::configuration::init(opts)?;
infra::logging::init(&config)?;
info!("Starting LLDAP....");
actix::run(
run_server(config).unwrap_or_else(|e| error!("Could not bring up the servers: {:#}", e)),
)?;