server,app: migrate to sea-orm

Valentin Tolmer 2022-11-21 09:13:25 +01:00 committed by nitnelave
parent a3a27f0049
commit e89b1538af
40 changed files with 2125 additions and 1390 deletions

Cargo.lock (generated): 586 changed lines, diff suppressed because it is too large.


@ -12,3 +12,7 @@ default-members = ["server"]
[patch.crates-io.ldap3_proto] [patch.crates-io.ldap3_proto]
git = 'https://github.com/nitnelave/ldap3_server/' git = 'https://github.com/nitnelave/ldap3_server/'
rev = '7b50b2b82c383f5f70e02e11072bb916629ed2bc' rev = '7b50b2b82c383f5f70e02e11072bb916629ed2bc'
[patch.crates-io.opaque-ke]
git = 'https://github.com/nitnelave/opaque-ke/'
branch = 'zeroize_1.5'


@ -3,6 +3,7 @@ name = "lldap_app"
version = "0.4.2-alpha" version = "0.4.2-alpha"
authors = ["Valentin Tolmer <valentin@tolmer.fr>"] authors = ["Valentin Tolmer <valentin@tolmer.fr>"]
edition = "2021" edition = "2021"
include = ["src/**/*", "queries/**/*", "Cargo.toml", "../schema.graphql"]
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"


@ -38,7 +38,6 @@ pub struct CreateUserModel {
username: String, username: String,
#[validate(email(message = "A valid email is required"))] #[validate(email(message = "A valid email is required"))]
email: String, email: String,
#[validate(length(min = 1, message = "Display name is required"))]
display_name: String, display_name: String,
first_name: String, first_name: String,
last_name: String, last_name: String,
@ -244,9 +243,7 @@ impl Component for CreateUserForm {
<div class="form-group row mb-3"> <div class="form-group row mb-3">
<label for="display-name" <label for="display-name"
class="form-label col-4 col-form-label"> class="form-label col-4 col-form-label">
{"Display name"} {"Display name:"}
<span class="text-danger">{"*"}</span>
{":"}
</label> </label>
<div class="col-8"> <div class="col-8">
<Field <Field


@ -89,7 +89,7 @@ impl GroupDetails {
{"Creation date: "} {"Creation date: "}
</label> </label>
<div class="col-8"> <div class="col-8">
<span id="creationDate" class="form-constrol-static">{g.creation_date.date().naive_local()}</span> <span id="creationDate" class="form-constrol-static">{g.creation_date.naive_local().date()}</span>
</div> </div>
</div> </div>
<div class="form-group row mb-3"> <div class="form-group row mb-3">


@ -124,7 +124,7 @@ impl GroupTable {
</Link> </Link>
</td> </td>
<td> <td>
{&group.creation_date.date().naive_local()} {&group.creation_date.naive_local().date()}
</td> </td>
<td> <td>
<DeleteGroup <DeleteGroup


@ -43,7 +43,6 @@ impl FromStr for JsFile {
pub struct UserModel { pub struct UserModel {
#[validate(email)] #[validate(email)]
email: String, email: String,
#[validate(length(min = 1, message = "Display name is required"))]
display_name: String, display_name: String,
first_name: String, first_name: String,
last_name: String, last_name: String,
@ -176,7 +175,10 @@ impl Component for UserDetailsForm {
type Field = yew_form::Field<UserModel>; type Field = yew_form::Field<UserModel>;
let avatar_base64 = maybe_to_base64(&self.avatar).unwrap_or_default(); let avatar_base64 = maybe_to_base64(&self.avatar).unwrap_or_default();
let avatar_string = avatar_base64.as_ref().unwrap_or(&self.common.user.avatar); let avatar_string = avatar_base64
.as_deref()
.or(self.common.user.avatar.as_deref())
.unwrap_or("");
html! { html! {
<div class="py-3"> <div class="py-3">
<form class="form"> <form class="form">
@ -195,7 +197,7 @@ impl Component for UserDetailsForm {
{"Creation date: "} {"Creation date: "}
</label> </label>
<div class="col-8"> <div class="col-8">
<span id="creationDate" class="form-control-static">{&self.common.user.creation_date.date().naive_local()}</span> <span id="creationDate" class="form-control-static">{&self.common.user.creation_date.naive_local().date()}</span>
</div> </div>
</div> </div>
<div class="form-group row mb-3"> <div class="form-group row mb-3">
@ -231,9 +233,7 @@ impl Component for UserDetailsForm {
<div class="form-group row mb-3"> <div class="form-group row mb-3">
<label for="display_name" <label for="display_name"
class="form-label col-4 col-form-label"> class="form-label col-4 col-form-label">
{"Display Name"} {"Display Name: "}
<span class="text-danger">{"*"}</span>
{":"}
</label> </label>
<div class="col-8"> <div class="col-8">
<Field <Field
@ -402,7 +402,7 @@ impl UserDetailsForm {
self.common.user.first_name = model.first_name; self.common.user.first_name = model.first_name;
self.common.user.last_name = model.last_name; self.common.user.last_name = model.last_name;
if let Some(avatar) = maybe_to_base64(&self.avatar)? { if let Some(avatar) = maybe_to_base64(&self.avatar)? {
self.common.user.avatar = avatar; self.common.user.avatar = Some(avatar);
} }
self.just_updated = true; self.just_updated = true;
} }
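
With avatar now an Option<String> on the user, the form above coalesces two optional sources, a freshly selected image and the stored one, into a single displayable string. The pattern in isolation (names are illustrative):

// Picks the freshly uploaded avatar if present, else the stored one,
// else an empty string so nothing is rendered.
fn avatar_to_display<'a>(uploaded: &'a Option<String>, stored: &'a Option<String>) -> &'a str {
    uploaded
        .as_deref() // Option<String> -> Option<&str>
        .or(stored.as_deref()) // fall back to the stored value
        .unwrap_or("") // neither is set
}

fn main() {
    assert_eq!(avatar_to_display(&None, &Some("base64data".to_owned())), "base64data");
    assert_eq!(avatar_to_display(&None, &None), "");
}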


@ -133,7 +133,7 @@ impl UserTable {
<td>{&user.display_name}</td> <td>{&user.display_name}</td>
<td>{&user.first_name}</td> <td>{&user.first_name}</td>
<td>{&user.last_name}</td> <td>{&user.last_name}</td>
<td>{&user.creation_date.date().naive_local()}</td> <td>{&user.creation_date.naive_local().date()}</td>
<td> <td>
<DeleteUser <DeleteUser
username=user.id.clone() username=user.id.clone()


@ -53,7 +53,11 @@ pub fn get_cookie(cookie_name: &str) -> Result<Option<String>> {
pub fn delete_cookie(cookie_name: &str) -> Result<()> { pub fn delete_cookie(cookie_name: &str) -> Result<()> {
if get_cookie(cookie_name)?.is_some() { if get_cookie(cookie_name)?.is_some() {
set_cookie(cookie_name, "", &Utc.ymd(1970, 1, 1).and_hms(0, 0, 0)) set_cookie(
cookie_name,
"",
&Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(),
)
} else { } else {
Ok(()) Ok(())
} }
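
Likewise, Utc.ymd(...).and_hms(...) is replaced by with_ymd_and_hms, which returns a LocalResult instead of panicking inside chrono, so the call site unwraps explicitly. A small sketch of the epoch timestamp used to expire the cookie (assuming chrono 0.4.23 or later):

use chrono::{DateTime, LocalResult, TimeZone, Utc};

// Builds the Unix-epoch timestamp that set_cookie uses to expire a cookie.
fn cookie_expiry_epoch() -> DateTime<Utc> {
    // LocalResult exists because some local times are ambiguous or skipped;
    // for UTC the result is always Single, so unwrapping it is safe.
    match Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0) {
        LocalResult::Single(dt) => dt,
        _ => unreachable!("UTC has no ambiguous or skipped times"),
    }
}

fn main() {
    assert_eq!(cookie_expiry_epoch().timestamp(), 0);
}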


@ -69,7 +69,7 @@ type User {
displayName: String! displayName: String!
firstName: String! firstName: String!
lastName: String! lastName: String!
avatar: String! avatar: String
creationDate: DateTimeUtc! creationDate: DateTimeUtc!
uuid: String! uuid: String!
"The groups to which this user belongs." "The groups to which this user belongs."


@ -35,7 +35,6 @@ rustls = "0.20"
serde = "*" serde = "*"
serde_json = "1" serde_json = "1"
sha2 = "0.9" sha2 = "0.9"
sqlx-core = "0.5.11"
thiserror = "*" thiserror = "*"
time = "0.2" time = "0.2"
tokio-rustls = "0.23" tokio-rustls = "0.23"
@ -70,28 +69,12 @@ features = ["builder", "serde", "smtp-transport", "tokio1-rustls-tls"]
default-features = false default-features = false
version = "0.10.0-rc.3" version = "0.10.0-rc.3"
[dependencies.sqlx]
version = "0.5.11"
features = [
"any",
"chrono",
"macros",
"mysql",
"postgres",
"runtime-actix-rustls",
"sqlite",
]
[dependencies.lldap_auth] [dependencies.lldap_auth]
path = "../auth" path = "../auth"
[dependencies.sea-query] [dependencies.sea-query]
version = "^0.25" version = "*"
features = ["with-chrono", "sqlx-sqlite"] features = ["with-chrono"]
[dependencies.sea-query-binder]
version = "0.1"
features = ["with-chrono", "sqlx-sqlite", "sqlx-any"]
[dependencies.opaque-ke] [dependencies.opaque-ke]
version = "0.6" version = "0.6"
@ -125,6 +108,11 @@ features = ["jpeg"]
default-features = false default-features = false
version = "0.24" version = "0.24"
[dependencies.sea-orm]
version= "0.10.3"
default-features = false
features = ["macros", "with-chrono", "with-uuid", "sqlx-all", "runtime-actix-rustls"]
[dependencies.reqwest] [dependencies.reqwest]
version = "0.11" version = "0.11"
default-features = false default-features = false


@ -6,7 +6,7 @@ pub enum DomainError {
#[error("Authentication error: `{0}`")] #[error("Authentication error: `{0}`")]
AuthenticationError(String), AuthenticationError(String),
#[error("Database error: `{0}`")] #[error("Database error: `{0}`")]
DatabaseError(#[from] sqlx::Error), DatabaseError(#[from] sea_orm::DbErr),
#[error("Authentication protocol error for `{0}`")] #[error("Authentication protocol error for `{0}`")]
AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError), AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError),
#[error("Unknown crypto error: `{0}`")] #[error("Unknown crypto error: `{0}`")]
@ -15,6 +15,8 @@ pub enum DomainError {
BinarySerializationError(#[from] bincode::Error), BinarySerializationError(#[from] bincode::Error),
#[error("Invalid base64: `{0}`")] #[error("Invalid base64: `{0}`")]
Base64DecodeError(#[from] base64::DecodeError), Base64DecodeError(#[from] base64::DecodeError),
#[error("Entity not found: `{0}`")]
EntityNotFound(String),
#[error("Internal error: `{0}`")] #[error("Internal error: `{0}`")]
InternalError(String), InternalError(String),
} }
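
DomainError now converts from sea_orm::DbErr through #[from], and the new EntityNotFound variant covers the fact that sea-orm's one() returns Ok(None) for a missing row rather than an error. A reduced sketch of how a handler maps both cases (the error enum is trimmed to the two relevant variants):

use thiserror::Error;

#[derive(Debug, Error)]
pub enum DomainError {
    #[error("Database error: `{0}`")]
    DatabaseError(#[from] sea_orm::DbErr),
    #[error("Entity not found: `{0}`")]
    EntityNotFound(String),
}

pub type Result<T> = std::result::Result<T, DomainError>;

// `?` converts the DbErr through the #[from] impl; ok_or_else turns the
// missing row into EntityNotFound, as get_group_details does further down.
fn require_found(
    query_result: std::result::Result<Option<String>, sea_orm::DbErr>,
    id: &str,
) -> Result<String> {
    query_result?.ok_or_else(|| DomainError::EntityNotFound(format!("No such group: '{}'", id)))
}

fn main() {
    assert!(matches!(
        require_found(Ok(None), "admins"),
        Err(DomainError::EntityNotFound(_))
    ));
}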


@ -1,13 +1,12 @@
use super::{error::*, sql_tables::UserColumn}; use super::error::*;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet; use std::collections::HashSet;
#[derive( pub use super::model::{GroupColumn, UserColumn};
PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::FromRow, sqlx::Type,
)] #[derive(PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(try_from = "&str")] #[serde(try_from = "&str")]
#[sqlx(transparent)]
pub struct Uuid(String); pub struct Uuid(String);
impl Uuid { impl Uuid {
@ -43,17 +42,26 @@ impl std::string::ToString for Uuid {
} }
} }
impl sea_orm::TryGetable for Uuid {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> std::result::Result<Self, sea_orm::TryGetError> {
Ok(Uuid(String::try_get(res, pre, col)?))
}
}
#[cfg(test)] #[cfg(test)]
#[macro_export] #[macro_export]
macro_rules! uuid { macro_rules! uuid {
($s:literal) => { ($s:literal) => {
$crate::domain::handler::Uuid::try_from($s).unwrap() <$crate::domain::handler::Uuid as std::convert::TryFrom<_>>::try_from($s).unwrap()
}; };
} }
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::Type)] #[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(from = "String")] #[serde(from = "String")]
#[sqlx(transparent)]
pub struct UserId(String); pub struct UserId(String);
impl UserId { impl UserId {
@ -82,17 +90,22 @@ impl From<String> for UserId {
} }
} }
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::Type)] #[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
#[sqlx(transparent)]
pub struct JpegPhoto(#[serde(with = "serde_bytes")] Vec<u8>); pub struct JpegPhoto(#[serde(with = "serde_bytes")] Vec<u8>);
impl From<JpegPhoto> for sea_query::Value { impl JpegPhoto {
pub fn null() -> Self {
Self(vec![])
}
}
impl From<JpegPhoto> for sea_orm::Value {
fn from(photo: JpegPhoto) -> Self { fn from(photo: JpegPhoto) -> Self {
photo.0.into() photo.0.into()
} }
} }
impl From<&JpegPhoto> for sea_query::Value { impl From<&JpegPhoto> for sea_orm::Value {
fn from(photo: &JpegPhoto) -> Self { fn from(photo: &JpegPhoto) -> Self {
photo.0.as_slice().into() photo.0.as_slice().into()
} }
@ -101,6 +114,9 @@ impl From<&JpegPhoto> for sea_query::Value {
impl TryFrom<&[u8]> for JpegPhoto { impl TryFrom<&[u8]> for JpegPhoto {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(bytes: &[u8]) -> anyhow::Result<Self> { fn try_from(bytes: &[u8]) -> anyhow::Result<Self> {
if bytes.is_empty() {
return Ok(JpegPhoto::null());
}
// Confirm that it's a valid Jpeg, then store only the bytes. // Confirm that it's a valid Jpeg, then store only the bytes.
image::io::Reader::with_format(std::io::Cursor::new(bytes), image::ImageFormat::Jpeg) image::io::Reader::with_format(std::io::Cursor::new(bytes), image::ImageFormat::Jpeg)
.decode()?; .decode()?;
@ -111,6 +127,9 @@ impl TryFrom<&[u8]> for JpegPhoto {
impl TryFrom<Vec<u8>> for JpegPhoto { impl TryFrom<Vec<u8>> for JpegPhoto {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(bytes: Vec<u8>) -> anyhow::Result<Self> { fn try_from(bytes: Vec<u8>) -> anyhow::Result<Self> {
if bytes.is_empty() {
return Ok(JpegPhoto::null());
}
// Confirm that it's a valid Jpeg, then store only the bytes. // Confirm that it's a valid Jpeg, then store only the bytes.
image::io::Reader::with_format( image::io::Reader::with_format(
std::io::Cursor::new(bytes.as_slice()), std::io::Cursor::new(bytes.as_slice()),
@ -160,14 +179,14 @@ impl JpegPhoto {
} }
} }
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, sea_orm::FromQueryResult)]
pub struct User { pub struct User {
pub user_id: UserId, pub user_id: UserId,
pub email: String, pub email: String,
pub display_name: String, pub display_name: Option<String>,
pub first_name: String, pub first_name: Option<String>,
pub last_name: String, pub last_name: Option<String>,
pub avatar: JpegPhoto, pub avatar: Option<JpegPhoto>,
pub creation_date: chrono::DateTime<chrono::Utc>, pub creation_date: chrono::DateTime<chrono::Utc>,
pub uuid: Uuid, pub uuid: Uuid,
} }
@ -176,14 +195,14 @@ pub struct User {
impl Default for User { impl Default for User {
fn default() -> Self { fn default() -> Self {
use chrono::TimeZone; use chrono::TimeZone;
let epoch = chrono::Utc.timestamp(0, 0); let epoch = chrono::Utc.timestamp_opt(0, 0).unwrap();
User { User {
user_id: UserId::default(), user_id: UserId::default(),
email: String::new(), email: String::new(),
display_name: String::new(), display_name: None,
first_name: String::new(), first_name: None,
last_name: String::new(), last_name: None,
avatar: JpegPhoto::default(), avatar: None,
creation_date: epoch, creation_date: epoch,
uuid: Uuid::from_name_and_date("", &epoch), uuid: Uuid::from_name_and_date("", &epoch),
} }
@ -263,11 +282,10 @@ pub trait LoginHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>; async fn bind(&self, request: BindRequest) -> Result<()>;
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[sqlx(transparent)]
pub struct GroupId(pub i32); pub struct GroupId(pub i32);
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sea_orm::FromQueryResult)]
pub struct GroupDetails { pub struct GroupDetails {
pub group_id: GroupId, pub group_id: GroupId,
pub display_name: String, pub display_name: String,
@ -349,8 +367,8 @@ mod tests {
fn test_uuid_time() { fn test_uuid_time() {
use chrono::prelude::*; use chrono::prelude::*;
let user_id = "bob"; let user_id = "bob";
let date1 = Utc.ymd(2014, 7, 8).and_hms(9, 10, 11); let date1 = Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11).unwrap();
let date2 = Utc.ymd(2014, 7, 8).and_hms(9, 10, 12); let date2 = Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 12).unwrap();
assert_ne!( assert_ne!(
Uuid::from_name_and_date(user_id, &date1), Uuid::from_name_and_date(user_id, &date1),
Uuid::from_name_and_date(user_id, &date2) Uuid::from_name_and_date(user_id, &date2)
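
The wrapper types (UserId, GroupId, Uuid, JpegPhoto) keep being used directly as entity columns, which with sea-orm means providing conversions in both directions: From<T> for sea_orm::Value for binding parameters, TryGetable for reading rows, and ValueType when the type appears in a derived model. The commit spreads these impls across handler.rs and sql_tables.rs; collected on a single hypothetical newtype, the pattern looks like this (a sketch, not code from the commit):

use sea_orm::sea_query::value::ValueType;
use sea_orm::sea_query::{ArrayType, ColumnType, ValueTypeErr};
use sea_orm::{QueryResult, TryGetError, TryGetable, Value};

// A hypothetical string newtype used as an entity column.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tag(String);

// Binding direction: lets sea-orm turn a Tag into a query parameter.
impl From<Tag> for Value {
    fn from(tag: Tag) -> Self {
        tag.0.into()
    }
}

// Reading direction: lets sea-orm extract a Tag from a row.
impl TryGetable for Tag {
    fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> {
        Ok(Tag(String::try_get(res, pre, col)?))
    }
}

// Required when the type is used as a column of a derived entity model.
impl ValueType for Tag {
    fn try_from(v: Value) -> Result<Self, ValueTypeErr> {
        Ok(Tag(<String as ValueType>::try_from(v)?))
    }
    fn type_name() -> String {
        "Tag".to_owned()
    }
    fn array_type() -> ArrayType {
        ArrayType::String
    }
    fn column_type() -> ColumnType {
        ColumnType::String(Some(255))
    }
}

// Primary-key types (GroupId, for instance) additionally implement
// sea_orm::TryFromU64, as shown in sql_tables.rs further down.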


@ -4,9 +4,8 @@ use ldap3_proto::{
use tracing::{debug, info, instrument, warn}; use tracing::{debug, info, instrument, warn};
use crate::domain::{ use crate::domain::{
handler::{BackendHandler, Group, GroupRequestFilter, UserId, Uuid}, handler::{BackendHandler, Group, GroupColumn, GroupRequestFilter, UserId, Uuid},
ldap::error::LdapError, ldap::error::LdapError,
sql_tables::GroupColumn,
}; };
use super::{ use super::{


@ -4,9 +4,8 @@ use ldap3_proto::{
use tracing::{debug, info, instrument, warn}; use tracing::{debug, info, instrument, warn};
use crate::domain::{ use crate::domain::{
handler::{BackendHandler, GroupDetails, User, UserId, UserRequestFilter}, handler::{BackendHandler, GroupDetails, User, UserColumn, UserId, UserRequestFilter},
ldap::{error::LdapError, utils::expand_attribute_wildcards}, ldap::{error::LdapError, utils::expand_attribute_wildcards},
sql_tables::UserColumn,
}; };
use super::{ use super::{
@ -34,9 +33,9 @@ fn get_user_attribute(
"uid" => vec![user.user_id.to_string().into_bytes()], "uid" => vec![user.user_id.to_string().into_bytes()],
"entryuuid" => vec![user.uuid.to_string().into_bytes()], "entryuuid" => vec![user.uuid.to_string().into_bytes()],
"mail" => vec![user.email.clone().into_bytes()], "mail" => vec![user.email.clone().into_bytes()],
"givenname" => vec![user.first_name.clone().into_bytes()], "givenname" => vec![user.first_name.clone()?.into_bytes()],
"sn" => vec![user.last_name.clone().into_bytes()], "sn" => vec![user.last_name.clone()?.into_bytes()],
"jpegphoto" => vec![user.avatar.clone().into_bytes()], "jpegphoto" => vec![user.avatar.clone()?.into_bytes()],
"memberof" => groups "memberof" => groups
.into_iter() .into_iter()
.flatten() .flatten()
@ -48,7 +47,7 @@ fn get_user_attribute(
.into_bytes() .into_bytes()
}) })
.collect(), .collect(),
"cn" | "displayname" => vec![user.display_name.clone().into_bytes()], "cn" | "displayname" => vec![user.display_name.clone()?.into_bytes()],
"createtimestamp" | "modifytimestamp" => vec![user.creation_date.to_rfc3339().into_bytes()], "createtimestamp" | "modifytimestamp" => vec![user.creation_date.to_rfc3339().into_bytes()],
"1.1" => return None, "1.1" => return None,
// We ignore the operational attribute wildcard. // We ignore the operational attribute wildcard.
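
Since the user fields are now optional, get_user_attribute can apply ? directly to each clone(): the function returns an Option, so a user without a first name, last name, display name or avatar simply yields no value for that LDAP attribute instead of an empty one. The control flow in isolation (simplified signature):

// Simplified attribute lookup: None means "this user has no such attribute".
fn get_attribute(
    attribute: &str,
    first_name: &Option<String>,
    last_name: &Option<String>,
) -> Option<Vec<u8>> {
    Some(match attribute {
        // `.clone()?` returns None from the whole function when the field is unset.
        "givenname" => first_name.clone()?.into_bytes(),
        "sn" => last_name.clone()?.into_bytes(),
        _ => return None,
    })
}

fn main() {
    let first = Some("Jane".to_owned());
    let last: Option<String> = None;
    assert_eq!(get_attribute("givenname", &first, &last), Some(b"Jane".to_vec()));
    // An unset last name produces no attribute rather than an empty string.
    assert_eq!(get_attribute("sn", &first, &last), None);
}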


@ -2,10 +2,7 @@ use itertools::Itertools;
use ldap3_proto::LdapResultCode; use ldap3_proto::LdapResultCode;
use tracing::{debug, instrument, warn}; use tracing::{debug, instrument, warn};
use crate::domain::{ use crate::domain::handler::{GroupColumn, UserColumn, UserId};
handler::UserId,
sql_tables::{GroupColumn, UserColumn},
};
use super::error::{LdapError, LdapResult}; use super::error::{LdapError, LdapResult};


@ -1,6 +1,7 @@
pub mod error; pub mod error;
pub mod handler; pub mod handler;
pub mod ldap; pub mod ldap;
pub mod model;
pub mod opaque_handler; pub mod opaque_handler;
pub mod sql_backend_handler; pub mod sql_backend_handler;
pub mod sql_group_backend_handler; pub mod sql_group_backend_handler;


@ -0,0 +1,53 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::{GroupId, Uuid};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "groups")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub group_id: GroupId,
pub display_name: String,
pub creation_date: chrono::DateTime<chrono::Utc>,
pub uuid: Uuid,
}
impl From<Model> for crate::domain::handler::Group {
fn from(group: Model) -> Self {
Self {
id: group.group_id,
display_name: group.display_name,
creation_date: group.creation_date,
uuid: group.uuid,
users: vec![],
}
}
}
impl From<Model> for crate::domain::handler::GroupDetails {
fn from(group: Model) -> Self {
Self {
group_id: group.group_id,
display_name: group.display_name,
creation_date: group.creation_date,
uuid: group.uuid,
}
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::memberships::Entity")]
Memberships,
}
impl Related<super::memberships::Entity> for Entity {
fn to() -> RelationDef {
Relation::Memberships.def()
}
}
impl ActiveModelBehavior for ActiveModel {}


@ -0,0 +1,35 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::UserId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "jwt_refresh_storage")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub refresh_token_hash: i64,
pub user_id: UserId,
pub expiry_date: chrono::DateTime<chrono::Utc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}


@ -0,0 +1,36 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::UserId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "jwt_storage")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub jwt_hash: i64,
pub user_id: UserId,
pub expiry_date: chrono::DateTime<chrono::Utc>,
pub blacklisted: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}


@ -0,0 +1,73 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::{GroupId, UserId};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "memberships")]
pub struct Model {
#[sea_orm(primary_key)]
pub user_id: UserId,
#[sea_orm(primary_key)]
pub group_id: GroupId,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::groups::Entity",
from = "Column::GroupId",
to = "super::groups::Column::GroupId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Groups,
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::groups::Entity> for Entity {
fn to() -> RelationDef {
Relation::Groups.def()
}
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
#[derive(Debug)]
pub struct UserToGroup;
impl Linked for UserToGroup {
type FromEntity = super::User;
type ToEntity = super::Group;
fn link(&self) -> Vec<RelationDef> {
vec![Relation::Users.def().rev(), Relation::Groups.def()]
}
}
#[derive(Debug)]
pub struct GroupToUser;
impl Linked for GroupToUser {
type FromEntity = super::Group;
type ToEntity = super::User;
fn link(&self) -> Vec<RelationDef> {
vec![Relation::Groups.def().rev(), Relation::Users.def()]
}
}
impl ActiveModelBehavior for ActiveModel {}
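
The two Linked definitions describe the user-to-group path through the memberships table in both directions; they are what powers find_also_linked in the group handler below, and they also allow navigating from an already-loaded model. A usage sketch, assuming the module paths introduced by this commit:

use sea_orm::{DatabaseConnection, DbErr, ModelTrait};

use crate::domain::model::{self, memberships::UserToGroup};

// Fetches the groups of an already-loaded user by walking
// users -> memberships -> groups through the link defined above.
async fn groups_of_user(
    db: &DatabaseConnection,
    user: &model::users::Model,
) -> Result<Vec<model::groups::Model>, DbErr> {
    user.find_linked(UserToGroup).all(db).await
}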


@ -0,0 +1,12 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
pub mod prelude;
pub mod groups;
pub mod jwt_refresh_storage;
pub mod jwt_storage;
pub mod memberships;
pub mod password_reset_tokens;
pub mod users;
pub use prelude::*;


@ -0,0 +1,35 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::UserId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "password_reset_tokens")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub token: String,
pub user_id: UserId,
pub expiry_date: chrono::DateTime<chrono::Utc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}


@ -0,0 +1,14 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
pub use super::groups::Column as GroupColumn;
pub use super::groups::Entity as Group;
pub use super::jwt_refresh_storage::Column as JwtRefreshStorageColumn;
pub use super::jwt_refresh_storage::Entity as JwtRefreshStorage;
pub use super::jwt_storage::Column as JwtStorageColumn;
pub use super::jwt_storage::Entity as JwtStorage;
pub use super::memberships::Column as MembershipColumn;
pub use super::memberships::Entity as Membership;
pub use super::password_reset_tokens::Column as PasswordResetTokensColumn;
pub use super::password_reset_tokens::Entity as PasswordResetTokens;
pub use super::users::Column as UserColumn;
pub use super::users::Entity as User;


@ -0,0 +1,134 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::{JpegPhoto, UserId, Uuid};
#[derive(Copy, Clone, Default, Debug, DeriveEntity)]
pub struct Entity;
#[derive(Clone, Debug, PartialEq, DeriveModel, Eq, Serialize, Deserialize, DeriveActiveModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: UserId,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
pub avatar: Option<JpegPhoto>,
pub creation_date: chrono::DateTime<chrono::Utc>,
pub password_hash: Option<Vec<u8>>,
pub totp_secret: Option<String>,
pub mfa_type: Option<String>,
pub uuid: Uuid,
}
impl EntityName for Entity {
fn table_name(&self) -> &str {
"users"
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn, PartialEq, Eq, Serialize, Deserialize)]
pub enum Column {
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
Uuid,
}
impl ColumnTrait for Column {
type EntityName = Entity;
fn def(&self) -> ColumnDef {
match self {
Column::UserId => ColumnType::String(Some(255)),
Column::Email => ColumnType::String(Some(255)),
Column::DisplayName => ColumnType::String(Some(255)),
Column::FirstName => ColumnType::String(Some(255)),
Column::LastName => ColumnType::String(Some(255)),
Column::Avatar => ColumnType::Binary,
Column::CreationDate => ColumnType::DateTime,
Column::PasswordHash => ColumnType::Binary,
Column::TotpSecret => ColumnType::String(Some(64)),
Column::MfaType => ColumnType::String(Some(64)),
Column::Uuid => ColumnType::String(Some(36)),
}
.def()
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::memberships::Entity")]
Memberships,
#[sea_orm(has_many = "super::jwt_refresh_storage::Entity")]
JwtRefreshStorage,
#[sea_orm(has_many = "super::jwt_storage::Entity")]
JwtStorage,
#[sea_orm(has_many = "super::password_reset_tokens::Entity")]
PasswordResetTokens,
}
#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)]
pub enum PrimaryKey {
UserId,
}
impl PrimaryKeyTrait for PrimaryKey {
type ValueType = UserId;
fn auto_increment() -> bool {
false
}
}
impl Related<super::memberships::Entity> for Entity {
fn to() -> RelationDef {
Relation::Memberships.def()
}
}
impl Related<super::jwt_refresh_storage::Entity> for Entity {
fn to() -> RelationDef {
Relation::JwtRefreshStorage.def()
}
}
impl Related<super::jwt_storage::Entity> for Entity {
fn to() -> RelationDef {
Relation::JwtStorage.def()
}
}
impl Related<super::password_reset_tokens::Entity> for Entity {
fn to() -> RelationDef {
Relation::PasswordResetTokens.def()
}
}
impl ActiveModelBehavior for ActiveModel {}
impl From<Model> for crate::domain::handler::User {
fn from(user: Model) -> Self {
Self {
user_id: user.user_id,
email: user.email,
display_name: user.display_name,
first_name: user.first_name,
last_name: user.last_name,
creation_date: user.creation_date,
uuid: user.uuid,
avatar: user.avatar,
}
}
}
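
The From<Model> impl is what lets the backend handlers return the domain User straight from an entity query. Roughly (a sketch assuming the module layout of this commit, with model::User being the prelude alias for users::Entity):

use sea_orm::{DatabaseConnection, DbErr, EntityTrait};

use crate::domain::{handler, model};

// Loads one user row by primary key and converts it into the domain type.
async fn load_user(
    db: &DatabaseConnection,
    user_id: handler::UserId,
) -> Result<Option<handler::User>, DbErr> {
    // find_by_id accepts UserId directly because PrimaryKey::ValueType = UserId.
    Ok(model::User::find_by_id(user_id)
        .one(db)
        .await?
        .map(handler::User::from))
}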


@ -5,11 +5,11 @@ use async_trait::async_trait;
#[derive(Clone)] #[derive(Clone)]
pub struct SqlBackendHandler { pub struct SqlBackendHandler {
pub(crate) config: Configuration, pub(crate) config: Configuration,
pub(crate) sql_pool: Pool, pub(crate) sql_pool: DbConnection,
} }
impl SqlBackendHandler { impl SqlBackendHandler {
pub fn new(config: Configuration, sql_pool: Pool) -> Self { pub fn new(config: Configuration, sql_pool: DbConnection) -> Self {
SqlBackendHandler { config, sql_pool } SqlBackendHandler { config, sql_pool }
} }
} }
@ -23,16 +23,23 @@ pub mod tests {
use crate::domain::sql_tables::init_table; use crate::domain::sql_tables::init_table;
use crate::infra::configuration::ConfigurationBuilder; use crate::infra::configuration::ConfigurationBuilder;
use lldap_auth::{opaque, registration}; use lldap_auth::{opaque, registration};
use sea_orm::Database;
pub fn get_default_config() -> Configuration { pub fn get_default_config() -> Configuration {
ConfigurationBuilder::for_tests() ConfigurationBuilder::for_tests()
} }
pub async fn get_in_memory_db() -> Pool { pub async fn get_in_memory_db() -> DbConnection {
PoolOptions::new().connect("sqlite::memory:").await.unwrap() crate::infra::logging::init_for_tests();
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
sql_opt
.max_connections(1)
.sqlx_logging(true)
.sqlx_logging_level(log::LevelFilter::Debug);
Database::connect(sql_opt).await.unwrap()
} }
pub async fn get_initialized_db() -> Pool { pub async fn get_initialized_db() -> DbConnection {
let sql_pool = get_in_memory_db().await; let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap();
sql_pool sql_pool
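
The tests now open an in-memory SQLite database through sea-orm's ConnectOptions; connecting to the real database URL follows the same shape. A minimal sketch (the URL and pool size are placeholders):

use sea_orm::{ConnectOptions, Database, DatabaseConnection, DbErr};

// Opens a connection pool with sqlx query logging at debug level,
// mirroring the get_in_memory_db test helper above.
async fn connect(database_url: &str) -> Result<DatabaseConnection, DbErr> {
    let mut options = ConnectOptions::new(database_url.to_owned());
    options
        .max_connections(5)
        .sqlx_logging(true)
        .sqlx_logging_level(log::LevelFilter::Debug);
    Database::connect(options).await
}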


@ -1,21 +1,22 @@
use crate::domain::handler::Uuid;
use super::{ use super::{
error::Result, error::{DomainError, Result},
handler::{ handler::{
Group, GroupBackendHandler, GroupDetails, GroupId, GroupRequestFilter, UpdateGroupRequest, Group, GroupBackendHandler, GroupDetails, GroupId, GroupRequestFilter, UpdateGroupRequest,
UserId,
}, },
model::{self, GroupColumn, MembershipColumn},
sql_backend_handler::SqlBackendHandler, sql_backend_handler::SqlBackendHandler,
sql_tables::{DbQueryBuilder, Groups, Memberships},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use sea_query::{Cond, Expr, Iden, Order, Query, SimpleExpr}; use sea_orm::{
use sea_query_binder::SqlxBinder; ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect,
use sqlx::{query_as_with, query_with, FromRow, Row}; QueryTrait,
};
use sea_query::{Cond, IntoCondition, SimpleExpr};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
// Returns the condition for the SQL query, and whether it requires joining with the groups table.
fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
use sea_query::IntoCondition;
use GroupRequestFilter::*; use GroupRequestFilter::*;
match filter { match filter {
And(fs) => { And(fs) => {
@ -35,23 +36,17 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
} }
} }
Not(f) => get_group_filter_expr(*f).not(), Not(f) => get_group_filter_expr(*f).not(),
DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)) DisplayName(name) => GroupColumn::DisplayName.eq(name).into_condition(),
.eq(name) GroupId(id) => GroupColumn::GroupId.eq(id.0).into_condition(),
.into_condition(), Uuid(uuid) => GroupColumn::Uuid.eq(uuid.to_string()).into_condition(),
GroupId(id) => Expr::col((Groups::Table, Groups::GroupId))
.eq(id.0)
.into_condition(),
Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid))
.eq(uuid.to_string())
.into_condition(),
// WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user)) // WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user))
Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)) Member(user) => GroupColumn::GroupId
.in_subquery( .in_subquery(
Query::select() model::Membership::find()
.column(Memberships::GroupId) .select_only()
.from(Memberships::Table) .column(MembershipColumn::GroupId)
.cond_where(Expr::col(Memberships::UserId).eq(user)) .filter(MembershipColumn::UserId.eq(user))
.take(), .into_query(),
) )
.into_condition(), .into_condition(),
} }
@ -62,94 +57,67 @@ impl GroupBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", ret, err)] #[instrument(skip_all, level = "debug", ret, err)]
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> { async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
debug!(?filters); debug!(?filters);
let (query, values) = { let results = model::Group::find()
let mut query_builder = Query::select() // The order_by must be before find_with_related otherwise the primary order is by group_id.
.column((Groups::Table, Groups::GroupId)) .order_by_asc(GroupColumn::DisplayName)
.column(Groups::DisplayName) .find_with_related(model::Membership)
.column(Groups::CreationDate) .filter(
.column(Groups::Uuid) filters
.column(Memberships::UserId) .map(|f| {
.from(Groups::Table) GroupColumn::GroupId
.left_join( .in_subquery(
Memberships::Table, model::Group::find()
Expr::tbl(Groups::Table, Groups::GroupId) .find_also_linked(model::memberships::GroupToUser)
.equals(Memberships::Table, Memberships::GroupId), .select_only()
.column(GroupColumn::GroupId)
.filter(get_group_filter_expr(f))
.into_query(),
) )
.order_by(Groups::DisplayName, Order::Asc) .into_condition()
.order_by(Memberships::UserId, Order::Asc) })
.to_owned(); .unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()),
)
if let Some(filter) = filters { .all(&self.sql_pool)
query_builder.cond_where(get_group_filter_expr(filter)); .await?;
} Ok(results
query_builder.build_sqlx(DbQueryBuilder {})
};
debug!(%query);
// For group_by.
use itertools::Itertools;
let mut groups = Vec::new();
// The rows are returned sorted by display_name, equivalent to group_id. We group them by
// this key which gives us one element (`rows`) per group.
for (group_details, rows) in &query_with(&query, values)
.fetch_all(&self.sql_pool)
.await?
.into_iter() .into_iter()
.group_by(|row| GroupDetails::from_row(row).unwrap()) .map(|(group, users)| {
{ let users: Vec<_> = users.into_iter().map(|u| u.user_id).collect();
groups.push(Group { Group {
id: group_details.group_id, users,
display_name: group_details.display_name, ..group.into()
creation_date: group_details.creation_date,
uuid: group_details.uuid,
users: rows
.map(|row| row.get::<UserId, _>(&*Memberships::UserId.to_string()))
// If a group has no users, an empty string is returned because of the left
// join.
.filter(|s| !s.as_str().is_empty())
.collect(),
});
} }
Ok(groups) })
.collect())
} }
#[instrument(skip_all, level = "debug", ret, err)] #[instrument(skip_all, level = "debug", ret, err)]
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupDetails> { async fn get_group_details(&self, group_id: GroupId) -> Result<GroupDetails> {
debug!(?group_id); debug!(?group_id);
let (query, values) = Query::select() model::Group::find_by_id(group_id)
.column(Groups::GroupId) .into_model::<GroupDetails>()
.column(Groups::DisplayName) .one(&self.sql_pool)
.column(Groups::CreationDate) .await?
.column(Groups::Uuid) .ok_or_else(|| DomainError::EntityNotFound(format!("{:?}", group_id)))
.from(Groups::Table)
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_as_with::<_, GroupDetails, _>(&query, values)
.fetch_one(&self.sql_pool)
.await?)
} }
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> { async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> {
debug!(?request.group_id); debug!(?request.group_id);
let mut values = Vec::new(); let update_group = model::groups::ActiveModel {
if let Some(display_name) = request.display_name { display_name: request
values.push((Groups::DisplayName, display_name.into())); .display_name
} .map(ActiveValue::Set)
if values.is_empty() { .unwrap_or_default(),
return Ok(()); ..Default::default()
} };
let (query, values) = Query::update() model::Group::update_many()
.table(Groups::Table) .set(update_group)
.values(values) .filter(sea_orm::ColumnTrait::eq(
.cond_where(Expr::col(Groups::GroupId).eq(request.group_id)) &GroupColumn::GroupId,
.build_sqlx(DbQueryBuilder {}); request.group_id,
debug!(%query); ))
query_with(query.as_str(), values) .exec(&self.sql_pool)
.execute(&self.sql_pool)
.await?; .await?;
Ok(()) Ok(())
} }
@ -157,30 +125,29 @@ impl GroupBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", ret, err)] #[instrument(skip_all, level = "debug", ret, err)]
async fn create_group(&self, group_name: &str) -> Result<GroupId> { async fn create_group(&self, group_name: &str) -> Result<GroupId> {
debug!(?group_name); debug!(?group_name);
crate::domain::sql_tables::create_group(group_name, &self.sql_pool).await?; let now = chrono::Utc::now();
let (query, values) = Query::select() let uuid = Uuid::from_name_and_date(group_name, &now);
.column(Groups::GroupId) let new_group = model::groups::ActiveModel {
.from(Groups::Table) display_name: ActiveValue::Set(group_name.to_owned()),
.cond_where(Expr::col(Groups::DisplayName).eq(group_name)) creation_date: ActiveValue::Set(now),
.build_sqlx(DbQueryBuilder {}); uuid: ActiveValue::Set(uuid),
debug!(%query); ..Default::default()
let row = query_with(query.as_str(), values) };
.fetch_one(&self.sql_pool) Ok(new_group.insert(&self.sql_pool).await?.group_id)
.await?;
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
} }
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn delete_group(&self, group_id: GroupId) -> Result<()> { async fn delete_group(&self, group_id: GroupId) -> Result<()> {
debug!(?group_id); debug!(?group_id);
let (query, values) = Query::delete() let res = model::Group::delete_by_id(group_id)
.from_table(Groups::Table) .exec(&self.sql_pool)
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?; .await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such group: '{:?}'",
group_id
)));
}
Ok(()) Ok(())
} }
} }
@ -188,7 +155,7 @@ impl GroupBackendHandler for SqlBackendHandler {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::sql_backend_handler::tests::*; use crate::domain::{handler::UserId, sql_backend_handler::tests::*};
async fn get_group_ids( async fn get_group_ids(
handler: &SqlBackendHandler, handler: &SqlBackendHandler,
@ -203,12 +170,29 @@ mod tests {
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
async fn get_group_names(
handler: &SqlBackendHandler,
filters: Option<GroupRequestFilter>,
) -> Vec<String> {
handler
.list_groups(filters)
.await
.unwrap()
.into_iter()
.map(|g| g.display_name)
.collect::<Vec<_>>()
}
#[tokio::test] #[tokio::test]
async fn test_list_groups_no_filter() { async fn test_list_groups_no_filter() {
let fixture = TestFixture::new().await; let fixture = TestFixture::new().await;
assert_eq!( assert_eq!(
get_group_ids(&fixture.handler, None).await, get_group_names(&fixture.handler, None).await,
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]] vec![
"Best Group".to_owned(),
"Empty Group".to_owned(),
"Worst Group".to_owned()
]
); );
} }
@ -216,15 +200,15 @@ mod tests {
async fn test_list_groups_simple_filter() { async fn test_list_groups_simple_filter() {
let fixture = TestFixture::new().await; let fixture = TestFixture::new().await;
assert_eq!( assert_eq!(
get_group_ids( get_group_names(
&fixture.handler, &fixture.handler,
Some(GroupRequestFilter::Or(vec![ Some(GroupRequestFilter::Or(vec![
GroupRequestFilter::DisplayName("Empty Group".to_string()), GroupRequestFilter::DisplayName("Empty Group".to_owned()),
GroupRequestFilter::Member(UserId::new("bob")), GroupRequestFilter::Member(UserId::new("bob")),
])) ]))
) )
.await, .await,
vec![fixture.groups[0], fixture.groups[2]] vec!["Best Group".to_owned(), "Empty Group".to_owned()]
); );
} }
@ -236,7 +220,7 @@ mod tests {
&fixture.handler, &fixture.handler,
Some(GroupRequestFilter::And(vec![ Some(GroupRequestFilter::And(vec![
GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName(
"value".to_string() "value".to_owned()
))), ))),
GroupRequestFilter::GroupId(fixture.groups[0]), GroupRequestFilter::GroupId(fixture.groups[0]),
])) ]))
@ -273,7 +257,7 @@ mod tests {
.handler .handler
.update_group(UpdateGroupRequest { .update_group(UpdateGroupRequest {
group_id: fixture.groups[0], group_id: fixture.groups[0],
display_name: Some("Awesomest Group".to_string()), display_name: Some("Awesomest Group".to_owned()),
}) })
.await .await
.unwrap(); .unwrap();
@ -288,6 +272,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_delete_group() { async fn test_delete_group() {
let fixture = TestFixture::new().await; let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(&fixture.handler, None).await,
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
);
fixture fixture
.handler .handler
.delete_group(fixture.groups[0]) .delete_group(fixture.groups[0])
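
In update_group above, only the columns wrapped in ActiveValue::Set end up in the UPDATE statement; everything left at the default (NotSet) is skipped, which is how the optional display_name is applied. The same pattern in compact form (a sketch assuming the groups entity from this commit):

use sea_orm::{ActiveValue, ColumnTrait, DatabaseConnection, DbErr, EntityTrait, QueryFilter};

use crate::domain::model::{self, GroupColumn};

// Renames a group when a new name is given; columns left as
// ActiveValue::NotSet (the Default) are omitted from the UPDATE.
async fn rename_group(
    db: &DatabaseConnection,
    group_id: i32,
    new_name: Option<String>,
) -> Result<(), DbErr> {
    let update = model::groups::ActiveModel {
        display_name: new_name.map(ActiveValue::Set).unwrap_or_default(),
        ..Default::default()
    };
    model::Group::update_many()
        .set(update)
        .filter(GroupColumn::GroupId.eq(group_id))
        .exec(db)
        .await?;
    Ok(())
}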


@ -1,55 +1,87 @@
use super::{ use super::{
handler::{GroupId, UserId, Uuid}, handler::{GroupId, UserId, Uuid},
sql_tables::{ sql_tables::{DbConnection, SchemaVersion},
DbQueryBuilder, DbRow, Groups, Memberships, Metadata, Pool, SchemaVersion, Users,
},
}; };
use sea_query::*; use sea_orm::{ConnectionTrait, FromQueryResult, Statement};
use sea_query_binder::SqlxBinder; use sea_query::{ColumnDef, Expr, ForeignKey, ForeignKeyAction, Iden, Query, Table, Value};
use sqlx::Row; use serde::{Deserialize, Serialize};
use tracing::{debug, warn}; use tracing::{instrument, warn};
pub async fn create_group(group_name: &str, pool: &Pool) -> sqlx::Result<()> { #[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
let now = chrono::Utc::now(); pub enum Users {
let (query, values) = Query::insert() Table,
.into_table(Groups::Table) UserId,
.columns(vec![ Email,
Groups::DisplayName, DisplayName,
Groups::CreationDate, FirstName,
Groups::Uuid, LastName,
]) Avatar,
.values_panic(vec![ CreationDate,
group_name.into(), PasswordHash,
now.naive_utc().into(), TotpSecret,
Uuid::from_name_and_date(group_name, &now).into(), MfaType,
]) Uuid,
.build_sqlx(DbQueryBuilder {});
debug!(%query);
sqlx::query_with(query.as_str(), values)
.execute(pool)
.await
.map(|_| ())
} }
pub async fn get_schema_version(pool: &Pool) -> Option<SchemaVersion> { #[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
sqlx::query( pub enum Groups {
&Query::select() Table,
GroupId,
DisplayName,
CreationDate,
Uuid,
}
#[derive(Iden)]
pub enum Memberships {
Table,
UserId,
GroupId,
}
// Metadata about the SQL DB.
#[derive(Iden)]
pub enum Metadata {
Table,
// Which version of the schema we're at.
Version,
}
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
pub struct JustSchemaVersion {
pub version: SchemaVersion,
}
#[instrument(skip_all, level = "debug", ret)]
pub async fn get_schema_version(pool: &DbConnection) -> Option<SchemaVersion> {
JustSchemaVersion::find_by_statement(
pool.get_database_backend().build(
Query::select()
.from(Metadata::Table) .from(Metadata::Table)
.column(Metadata::Version) .column(Metadata::Version),
.to_string(DbQueryBuilder {}), ),
) )
.map(|row: DbRow| row.get::<SchemaVersion, _>(&*Metadata::Version.to_string())) .one(pool)
.fetch_one(pool)
.await .await
.ok() .ok()
.flatten()
.map(|j| j.version)
} }
pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> { pub async fn upgrade_to_v1(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> {
let builder = pool.get_database_backend();
// SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the // SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the
// error. // error.
let _ = sqlx::query("PRAGMA foreign_keys = ON").execute(pool).await; let _ = pool
sqlx::query( .execute(Statement::from_string(
&Table::create() builder,
"PRAGMA foreign_keys = ON".to_owned(),
))
.await;
pool.execute(
builder.build(
Table::create()
.table(Users::Table) .table(Users::Table)
.if_not_exists() .if_not_exists()
.col( .col(
@ -64,21 +96,21 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
.string_len(255) .string_len(255)
.not_null(), .not_null(),
) )
.col(ColumnDef::new(Users::FirstName).string_len(255).not_null()) .col(ColumnDef::new(Users::FirstName).string_len(255))
.col(ColumnDef::new(Users::LastName).string_len(255).not_null()) .col(ColumnDef::new(Users::LastName).string_len(255))
.col(ColumnDef::new(Users::Avatar).binary()) .col(ColumnDef::new(Users::Avatar).binary())
.col(ColumnDef::new(Users::CreationDate).date_time().not_null()) .col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::PasswordHash).binary()) .col(ColumnDef::new(Users::PasswordHash).binary())
.col(ColumnDef::new(Users::TotpSecret).string_len(64)) .col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64)) .col(ColumnDef::new(Users::MfaType).string_len(64))
.col(ColumnDef::new(Users::Uuid).string_len(36).not_null()) .col(ColumnDef::new(Users::Uuid).string_len(36).not_null()),
.to_string(DbQueryBuilder {}), ),
) )
.execute(pool)
.await?; .await?;
sqlx::query( pool.execute(
&Table::create() builder.build(
Table::create()
.table(Groups::Table) .table(Groups::Table)
.if_not_exists() .if_not_exists()
.col( .col(
@ -94,25 +126,23 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
.not_null(), .not_null(),
) )
.col(ColumnDef::new(Users::CreationDate).date_time().not_null()) .col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::Uuid).string_len(36).not_null()) .col(ColumnDef::new(Users::Uuid).string_len(36).not_null()),
.to_string(DbQueryBuilder {}), ),
) )
.execute(pool)
.await?; .await?;
// If the creation_date column doesn't exist, add it. // If the creation_date column doesn't exist, add it.
if sqlx::query( if pool
&Table::alter() .execute(
.table(Groups::Table) builder.build(
.add_column( Table::alter().table(Groups::Table).add_column(
ColumnDef::new(Groups::CreationDate) ColumnDef::new(Groups::CreationDate)
.date_time() .date_time()
.not_null() .not_null()
.default(chrono::Utc::now().naive_utc()), .default(chrono::Utc::now().naive_utc()),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await .await
.is_ok() .is_ok()
{ {
@ -120,107 +150,109 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
} }
// If the uuid column doesn't exist, add it. // If the uuid column doesn't exist, add it.
if sqlx::query( if pool
&Table::alter() .execute(
.table(Groups::Table) builder.build(
.add_column( Table::alter().table(Groups::Table).add_column(
ColumnDef::new(Groups::Uuid) ColumnDef::new(Groups::Uuid)
.string_len(36) .string_len(36)
.not_null() .not_null()
.default(""), .default(""),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await .await
.is_ok() .is_ok()
{ {
warn!("`uuid` column not found in `groups`, creating it"); warn!("`uuid` column not found in `groups`, creating it");
for row in sqlx::query( #[derive(FromQueryResult)]
&Query::select() struct ShortGroupDetails {
group_id: GroupId,
display_name: String,
creation_date: chrono::DateTime<chrono::Utc>,
}
for result in ShortGroupDetails::find_by_statement(
builder.build(
Query::select()
.from(Groups::Table) .from(Groups::Table)
.column(Groups::GroupId) .column(Groups::GroupId)
.column(Groups::DisplayName) .column(Groups::DisplayName)
.column(Groups::CreationDate) .column(Groups::CreationDate),
.to_string(DbQueryBuilder {}), ),
) )
.fetch_all(pool) .all(pool)
.await? .await?
{ {
sqlx::query( pool.execute(
&Query::update() builder.build(
Query::update()
.table(Groups::Table) .table(Groups::Table)
.value( .value(
Groups::Uuid, Groups::Uuid,
Uuid::from_name_and_date( Value::from(Uuid::from_name_and_date(
&row.get::<String, _>(&*Groups::DisplayName.to_string()), &result.display_name,
&row.get::<chrono::DateTime<chrono::Utc>, _>( &result.creation_date,
&*Groups::CreationDate.to_string(), )),
)
.and_where(Expr::col(Groups::GroupId).eq(result.group_id)),
), ),
) )
.into(),
)
.and_where(
Expr::col(Groups::GroupId)
.eq(row.get::<GroupId, _>(&*Groups::GroupId.to_string())),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?; .await?;
} }
} }
if sqlx::query( if pool
&Table::alter() .execute(
.table(Users::Table) builder.build(
.add_column( Table::alter().table(Users::Table).add_column(
ColumnDef::new(Users::Uuid) ColumnDef::new(Users::Uuid)
.string_len(36) .string_len(36)
.not_null() .not_null()
.default(""), .default(""),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await .await
.is_ok() .is_ok()
{ {
warn!("`uuid` column not found in `users`, creating it"); warn!("`uuid` column not found in `users`, creating it");
for row in sqlx::query( #[derive(FromQueryResult)]
&Query::select() struct ShortUserDetails {
user_id: UserId,
creation_date: chrono::DateTime<chrono::Utc>,
}
for result in ShortUserDetails::find_by_statement(
builder.build(
Query::select()
.from(Users::Table) .from(Users::Table)
.column(Users::UserId) .column(Users::UserId)
.column(Users::CreationDate) .column(Users::CreationDate),
.to_string(DbQueryBuilder {}), ),
) )
.fetch_all(pool) .all(pool)
.await? .await?
{ {
let user_id = row.get::<UserId, _>(&*Users::UserId.to_string()); pool.execute(
sqlx::query( builder.build(
&Query::update() Query::update()
.table(Users::Table) .table(Users::Table)
.value( .value(
Users::Uuid, Users::Uuid,
Uuid::from_name_and_date( Value::from(Uuid::from_name_and_date(
user_id.as_str(), result.user_id.as_str(),
&row.get::<chrono::DateTime<chrono::Utc>, _>( &result.creation_date,
&*Users::CreationDate.to_string(), )),
)
.and_where(Expr::col(Users::UserId).eq(result.user_id)),
), ),
) )
.into(),
)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?; .await?;
} }
} }
sqlx::query( pool.execute(
&Table::create() builder.build(
Table::create()
.table(Memberships::Table) .table(Memberships::Table)
.if_not_exists() .if_not_exists()
.col( .col(
@ -244,59 +276,63 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
.to(Groups::Table, Groups::GroupId) .to(Groups::Table, Groups::GroupId)
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?; .await?;
if sqlx::query( if pool
&Query::select() .query_one(
builder.build(
Query::select()
.from(Groups::Table) .from(Groups::Table)
.column(Groups::DisplayName) .column(Groups::DisplayName)
.cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")) .cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")),
.to_string(DbQueryBuilder {}), ),
) )
.fetch_one(pool)
.await .await
.is_ok() .is_ok()
{ {
sqlx::query( pool.execute(
&Query::update() builder.build(
Query::update()
.table(Groups::Table) .table(Groups::Table)
.values(vec![(Groups::DisplayName, "lldap_password_manager".into())]) .values(vec![(Groups::DisplayName, "lldap_password_manager".into())])
.cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")) .cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")),
.to_string(DbQueryBuilder {}), ),
) )
.execute(pool)
.await?; .await?;
create_group("lldap_strict_readonly", pool).await?
} }
sqlx::query( pool.execute(
&Table::create() builder.build(
Table::create()
.table(Metadata::Table) .table(Metadata::Table)
.if_not_exists() .if_not_exists()
.col(ColumnDef::new(Metadata::Version).tiny_integer().not_null()) .col(ColumnDef::new(Metadata::Version).tiny_integer()),
.to_string(DbQueryBuilder {}), ),
) )
.execute(pool)
.await?; .await?;
sqlx::query( pool.execute(
&Query::insert() builder.build(
Query::insert()
.into_table(Metadata::Table) .into_table(Metadata::Table)
.columns(vec![Metadata::Version]) .columns(vec![Metadata::Version])
.values_panic(vec![SchemaVersion(1).into()]) .values_panic(vec![SchemaVersion(1).into()]),
.to_string(DbQueryBuilder {}), ),
) )
.execute(pool)
.await?; .await?;
assert_eq!(get_schema_version(pool).await.unwrap().0, 1);
Ok(()) Ok(())
} }
pub async fn migrate_from_version(_pool: &Pool, version: SchemaVersion) -> anyhow::Result<()> { pub async fn migrate_from_version(
_pool: &DbConnection,
version: SchemaVersion,
) -> anyhow::Result<()> {
if version.0 > 1 { if version.0 > 1 {
anyhow::bail!("DB version downgrading is not supported"); anyhow::bail!("DB version downgrading is not supported");
} }
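
The migrations keep building SQL with sea-query, but everything now runs through the sea-orm connection: get_database_backend() picks the dialect, build() turns a sea-query statement into a sea_orm::Statement, execute() runs it, and FromQueryResult::find_by_statement reads rows back into a plain struct. A condensed sketch of that round trip (table and column names follow the Metadata enum above, with the version read as a plain integer instead of SchemaVersion):

use sea_orm::sea_query::{Alias, Query};
use sea_orm::{ConnectionTrait, DatabaseConnection, DbErr, FromQueryResult, Statement};

#[derive(FromQueryResult)]
struct JustVersion {
    version: u8,
}

// Reads the schema version through a hand-built SELECT, the way
// get_schema_version does above.
async fn read_version(pool: &DatabaseConnection) -> Result<Option<u8>, DbErr> {
    // The backend (SQLite, MySQL or Postgres) decides how statements render.
    let builder = pool.get_database_backend();

    // Raw strings are still possible, e.g. for the foreign_keys pragma.
    let _ = pool
        .execute(Statement::from_string(
            builder,
            "PRAGMA foreign_keys = ON".to_owned(),
        ))
        .await;

    let select = Query::select()
        .from(Alias::new("metadata"))
        .column(Alias::new("version"))
        .to_owned();
    Ok(JustVersion::find_by_statement(builder.build(&select))
        .one(pool)
        .await?
        .map(|row| row.version))
}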


@ -1,16 +1,14 @@
use super::{ use super::{
error::{DomainError, Result}, error::{DomainError, Result},
handler::{BindRequest, LoginHandler, UserId}, handler::{BindRequest, LoginHandler, UserId},
model::{self, UserColumn},
opaque_handler::{login, registration, OpaqueHandler}, opaque_handler::{login, registration, OpaqueHandler},
sql_backend_handler::SqlBackendHandler, sql_backend_handler::SqlBackendHandler,
sql_tables::{DbQueryBuilder, Users},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use lldap_auth::opaque; use lldap_auth::opaque;
use sea_query::{Expr, Iden, Query}; use sea_orm::{ActiveValue, EntityTrait, FromQueryResult, QuerySelect};
use sea_query_binder::SqlxBinder;
use secstr::SecUtf8; use secstr::SecUtf8;
use sqlx::Row;
use tracing::{debug, instrument}; use tracing::{debug, instrument};
type SqlOpaqueHandler = SqlBackendHandler; type SqlOpaqueHandler = SqlBackendHandler;
@ -50,39 +48,19 @@ impl SqlBackendHandler {
} }
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn get_password_file_for_user( async fn get_password_file_for_user(&self, user_id: UserId) -> Result<Option<Vec<u8>>> {
&self, #[derive(FromQueryResult)]
username: &str, struct OnlyPasswordHash {
) -> Result<Option<opaque::server::ServerRegistration>> { password_hash: Option<Vec<u8>>,
}
// Fetch the previously registered password file from the DB. // Fetch the previously registered password file from the DB.
let password_file_bytes = { Ok(model::User::find_by_id(user_id)
let (query, values) = Query::select() .select_only()
.column(Users::PasswordHash) .column(UserColumn::PasswordHash)
.from(Users::Table) .into_model::<OnlyPasswordHash>()
.cond_where(Expr::col(Users::UserId).eq(username)) .one(&self.sql_pool)
.build_sqlx(DbQueryBuilder {});
if let Some(row) = sqlx::query_with(query.as_str(), values)
.fetch_optional(&self.sql_pool)
.await? .await?
{ .and_then(|u| u.password_hash))
if let Some(bytes) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
bytes
} else {
// No password set.
return Ok(None);
}
} else {
// No such user.
return Ok(None);
}
};
opaque::server::ServerRegistration::deserialize(&password_file_bytes)
.map(Option::Some)
.map_err(|_| {
DomainError::InternalError(format!("Corrupted password file for {}", username))
})
} }
} }
@ -90,17 +68,9 @@ impl SqlBackendHandler {
impl LoginHandler for SqlBackendHandler { impl LoginHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn bind(&self, request: BindRequest) -> Result<()> { async fn bind(&self, request: BindRequest) -> Result<()> {
let (query, values) = Query::select() if let Some(password_hash) = self
.column(Users::PasswordHash) .get_password_file_for_user(request.name.clone())
.from(Users::Table) .await?
.cond_where(Expr::col(Users::UserId).eq(&request.name))
.build_sqlx(DbQueryBuilder {});
if let Ok(row) = sqlx::query_with(&query, values)
.fetch_one(&self.sql_pool)
.await
{
if let Some(password_hash) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{ {
if let Err(e) = passwords_match( if let Err(e) = passwords_match(
&password_hash, &password_hash,
@ -113,10 +83,10 @@ impl LoginHandler for SqlBackendHandler {
return Ok(()); return Ok(());
} }
} else { } else {
debug!(r#"User "{}" has no password"#, &request.name); debug!(
} r#"User "{}" doesn't exist or has no password"#,
} else { &request.name
debug!(r#"No user found for "{}""#, &request.name); );
} }
Err(DomainError::AuthenticationError(format!( Err(DomainError::AuthenticationError(format!(
" for user '{}'", " for user '{}'",
@ -132,7 +102,18 @@ impl OpaqueHandler for SqlOpaqueHandler {
&self, &self,
request: login::ClientLoginStartRequest, request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> { ) -> Result<login::ServerLoginStartResponse> {
let maybe_password_file = self.get_password_file_for_user(&request.username).await?; let maybe_password_file = self
.get_password_file_for_user(UserId::new(&request.username))
.await?
.map(|bytes| {
opaque::server::ServerRegistration::deserialize(&bytes).map_err(|_| {
DomainError::InternalError(format!(
"Corrupted password file for {}",
&request.username
))
})
})
.transpose()?;
let mut rng = rand::rngs::OsRng; let mut rng = rand::rngs::OsRng;
// Get the CredentialResponse for the user, or a dummy one if no user/no password. // Get the CredentialResponse for the user, or a dummy one if no user/no password.
@ -210,17 +191,16 @@ impl OpaqueHandler for SqlOpaqueHandler {
let password_file = let password_file =
opaque::server::registration::get_password_file(request.registration_upload); opaque::server::registration::get_password_file(request.registration_upload);
{
// Set the user password to the new password. // Set the user password to the new password.
let (update_query, values) = Query::update() let user_update = model::users::ActiveModel {
.table(Users::Table) user_id: ActiveValue::Set(UserId::new(&username)),
.value(Users::PasswordHash, password_file.serialize().into()) password_hash: ActiveValue::Set(Some(password_file.serialize())),
.cond_where(Expr::col(Users::UserId).eq(username)) ..Default::default()
.build_sqlx(DbQueryBuilder {}); };
sqlx::query_with(update_query.as_str(), values) model::User::update_many()
.execute(&self.sql_pool) .set(user_update)
.exec(&self.sql_pool)
.await?; .await?;
}
Ok(()) Ok(())
} }
} }
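The handlers above go through entity types (model::User, model::users::ActiveModel, UserColumn) that are defined in the new model module and are not part of this hunk. As a hedged sketch only, a sea-orm entity backing these calls would take roughly the following shape; the fields and attributes are inferred from the columns this commit touches, not copied from the real model/users.rs.

use crate::domain::handler::{JpegPhoto, UserId, Uuid};
use sea_orm::entity::prelude::*;

// Hypothetical shape of the users entity; the real definition may differ.
// Only the columns exercised in this section are listed.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
    // Non-integer primary key, so auto-increment is disabled.
    #[sea_orm(primary_key, auto_increment = false)]
    pub user_id: UserId,
    pub email: String,
    pub display_name: Option<String>,
    pub first_name: Option<String>,
    pub last_name: Option<String>,
    pub avatar: Option<JpegPhoto>,
    pub creation_date: ChronoDateTimeUtc,
    pub password_hash: Option<Vec<u8>>,
    pub uuid: Uuid,
}

// Group membership is reached through a Linked path rather than a direct relation.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}

With definitions along these lines, model::User is the generated Entity, UserColumn the generated Column enum, and model::users::ActiveModel the change set consumed by update_many().set(...) above.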
View File
@ -1,39 +1,116 @@
use super::{ use super::{
handler::{GroupId, UserId, Uuid}, handler::{GroupId, JpegPhoto, UserId, Uuid},
sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1}, sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1},
}; };
use sea_query::*; use sea_orm::{DbErr, Value};
use serde::{Deserialize, Serialize};
pub use super::sql_migrations::create_group; pub type DbConnection = sea_orm::DatabaseConnection;
pub type Pool = sqlx::sqlite::SqlitePool; #[derive(Copy, PartialEq, Eq, Debug, Clone)]
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
pub type DbRow = sqlx::sqlite::SqliteRow;
pub type DbQueryBuilder = SqliteQueryBuilder;
#[derive(Copy, PartialEq, Eq, Debug, Clone, sqlx::FromRow, sqlx::Type)]
#[sqlx(transparent)]
pub struct SchemaVersion(pub u8); pub struct SchemaVersion(pub u8);
impl From<GroupId> for Value { impl sea_orm::TryGetable for SchemaVersion {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(SchemaVersion(u8::try_get(res, pre, col)?))
}
}
impl From<GroupId> for sea_orm::Value {
fn from(group_id: GroupId) -> Self { fn from(group_id: GroupId) -> Self {
group_id.0.into() group_id.0.into()
} }
} }
impl From<UserId> for sea_query::Value { impl sea_orm::TryGetable for GroupId {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(GroupId(i32::try_get(res, pre, col)?))
}
}
impl sea_orm::sea_query::value::ValueType for GroupId {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
Ok(GroupId(<i32 as sea_orm::sea_query::ValueType>::try_from(
v,
)?))
}
fn type_name() -> String {
"GroupId".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::Int
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::Integer(None)
}
}
impl sea_orm::TryFromU64 for GroupId {
fn try_from_u64(n: u64) -> Result<Self, sea_orm::DbErr> {
Ok(GroupId(i32::try_from_u64(n)?))
}
}
impl From<UserId> for sea_orm::Value {
fn from(user_id: UserId) -> Self { fn from(user_id: UserId) -> Self {
user_id.into_string().into() user_id.into_string().into()
} }
} }
impl From<&UserId> for sea_query::Value { impl From<&UserId> for sea_orm::Value {
fn from(user_id: &UserId) -> Self { fn from(user_id: &UserId) -> Self {
user_id.as_str().into() user_id.as_str().into()
} }
} }
impl sea_orm::TryGetable for UserId {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(UserId::new(&String::try_get(res, pre, col)?))
}
}
impl sea_orm::TryFromU64 for UserId {
fn try_from_u64(_n: u64) -> Result<Self, sea_orm::DbErr> {
Err(sea_orm::DbErr::ConvertFromU64(
"UserId cannot be constructed from u64",
))
}
}
impl sea_orm::sea_query::value::ValueType for UserId {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
Ok(UserId::new(
<String as sea_orm::sea_query::ValueType>::try_from(v)?.as_str(),
))
}
fn type_name() -> String {
"UserId".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::String
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::String(Some(255))
}
}
impl From<Uuid> for sea_query::Value { impl From<Uuid> for sea_query::Value {
fn from(uuid: Uuid) -> Self { fn from(uuid: Uuid) -> Self {
uuid.as_str().into() uuid.as_str().into()
@ -46,57 +123,84 @@ impl From<&Uuid> for sea_query::Value {
} }
} }
impl sea_orm::TryGetable for JpegPhoto {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
<JpegPhoto as std::convert::TryFrom<Vec<_>>>::try_from(Vec::<u8>::try_get(res, pre, col)?)
.map_err(|e| {
sea_orm::TryGetError::DbErr(DbErr::TryIntoErr {
from: "[u8]",
into: "JpegPhoto",
source: e.into(),
})
})
}
}
impl sea_orm::sea_query::value::ValueType for JpegPhoto {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
<JpegPhoto as std::convert::TryFrom<_>>::try_from(
<Vec<u8> as sea_orm::sea_query::ValueType>::try_from(v)?.as_slice(),
)
.map_err(|_| sea_orm::sea_query::ValueTypeErr {})
}
fn type_name() -> String {
"JpegPhoto".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::Bytes
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::Binary(sea_orm::sea_query::BlobSize::Long)
}
}
impl sea_orm::sea_query::Nullable for JpegPhoto {
fn null() -> sea_orm::Value {
JpegPhoto::null().into()
}
}
impl sea_orm::entity::IntoActiveValue<JpegPhoto> for JpegPhoto {
fn into_active_value(self) -> sea_orm::ActiveValue<JpegPhoto> {
sea_orm::ActiveValue::Set(self)
}
}
impl sea_orm::sea_query::value::ValueType for Uuid {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
<super::handler::Uuid as std::convert::TryFrom<_>>::try_from(
<std::string::String as sea_orm::sea_query::ValueType>::try_from(v)?.as_str(),
)
.map_err(|_| sea_orm::sea_query::ValueTypeErr {})
}
fn type_name() -> String {
"Uuid".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::String
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::String(Some(36))
}
}
impl From<SchemaVersion> for Value { impl From<SchemaVersion> for Value {
fn from(version: SchemaVersion) -> Self { fn from(version: SchemaVersion) -> Self {
version.0.into() version.0.into()
} }
} }
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> {
pub enum Users {
Table,
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
Uuid,
}
pub type UserColumn = Users;
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Groups {
Table,
GroupId,
DisplayName,
CreationDate,
Uuid,
}
pub type GroupColumn = Groups;
#[derive(Iden)]
pub enum Memberships {
Table,
UserId,
GroupId,
}
// Metadata about the SQL DB.
#[derive(Iden)]
pub enum Metadata {
Table,
// Which version of the schema we're at.
Version,
}
pub async fn init_table(pool: &Pool) -> anyhow::Result<()> {
let version = { let version = {
if let Some(version) = get_schema_version(pool).await { if let Some(version) = get_schema_version(pool).await {
version version
@ -111,33 +215,55 @@ pub async fn init_table(pool: &Pool) -> anyhow::Result<()> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::domain::sql_migrations;
use super::*; use super::*;
use chrono::prelude::*; use chrono::prelude::*;
use sqlx::{Column, Row}; use sea_orm::{ConnectionTrait, Database, DbBackend, FromQueryResult};
async fn get_in_memory_db() -> DbConnection {
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
sql_opt.max_connections(1).sqlx_logging(false);
Database::connect(sql_opt).await.unwrap()
}
fn raw_statement(sql: &str) -> sea_orm::Statement {
sea_orm::Statement::from_string(DbBackend::Sqlite, sql.to_owned())
}
#[tokio::test] #[tokio::test]
async fn test_init_table() { async fn test_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap();
sqlx::query(r#"INSERT INTO users sql_pool.execute(raw_statement(
r#"INSERT INTO users
(user_id, email, display_name, first_name, last_name, creation_date, password_hash, uuid) (user_id, email, display_name, first_name, last_name, creation_date, password_hash, uuid)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#).execute(&sql_pool).await.unwrap(); VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#)).await.unwrap();
let row = #[derive(FromQueryResult, PartialEq, Eq, Debug)]
sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#) struct ShortUserDetails {
.fetch_one(&sql_pool) display_name: String,
creation_date: chrono::DateTime<chrono::Utc>,
}
let result = ShortUserDetails::find_by_statement(raw_statement(
r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#,
))
.one(&sql_pool)
.await .await
.unwrap()
.unwrap(); .unwrap();
assert_eq!(row.column(0).name(), "display_name");
assert_eq!(row.get::<String, _>("display_name"), "Bob Bobbersön");
assert_eq!( assert_eq!(
row.get::<DateTime<Utc>, _>("creation_date"), result,
Utc.timestamp(0, 0), ShortUserDetails {
display_name: "Bob Bobbersön".to_owned(),
creation_date: Utc.timestamp_opt(0, 0).unwrap()
}
); );
} }
#[tokio::test] #[tokio::test]
async fn test_already_init_table() { async fn test_already_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); crate::infra::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap();
init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap();
} }
@ -145,87 +271,109 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_migrate_tables() { async fn test_migrate_tables() {
// Test that we add the column creation_date to groups and uuid to users and groups. // Test that we add the column creation_date to groups and uuid to users and groups.
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); let sql_pool = get_in_memory_db().await;
sqlx::query(r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#) sql_pool
.execute(&sql_pool) .execute(raw_statement(
r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#,
))
.await .await
.unwrap(); .unwrap();
sqlx::query( sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, creation_date) r#"INSERT INTO users (user_id, creation_date)
VALUES ("bôb", "1970-01-01 00:00:00")"#, VALUES ("bôb", "1970-01-01 00:00:00")"#,
) ))
.execute(&sql_pool)
.await .await
.unwrap(); .unwrap();
sqlx::query(r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#) sql_pool
.execute(&sql_pool) .execute(raw_statement(
r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#,
))
.await .await
.unwrap(); .unwrap();
sqlx::query( sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name) r#"INSERT INTO groups (display_name)
VALUES ("lldap_admin"), ("lldap_readonly")"#, VALUES ("lldap_admin"), ("lldap_readonly")"#,
) ))
.execute(&sql_pool)
.await .await
.unwrap(); .unwrap();
init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap();
sqlx::query( sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name, creation_date, uuid) r#"INSERT INTO groups (display_name, creation_date, uuid)
VALUES ("test", "1970-01-01 00:00:00", "abc")"#, VALUES ("test", "1970-01-01 00:00:00", "abc")"#,
) ))
.execute(&sql_pool)
.await .await
.unwrap(); .unwrap();
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct JustUuid {
uuid: Uuid,
}
assert_eq!( assert_eq!(
sqlx::query(r#"SELECT uuid FROM users"#) JustUuid::find_by_statement(raw_statement(r#"SELECT uuid FROM users"#))
.fetch_all(&sql_pool) .all(&sql_pool)
.await .await
.unwrap() .unwrap(),
.into_iter() vec![JustUuid {
.map(|row| row.get::<Uuid, _>("uuid")) uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")
.collect::<Vec<_>>(), }]
vec![crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")]
); );
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortGroupDetails {
group_id: GroupId,
display_name: String,
}
assert_eq!( assert_eq!(
sqlx::query(r#"SELECT group_id, display_name FROM groups"#) ShortGroupDetails::find_by_statement(raw_statement(
.fetch_all(&sql_pool) r#"SELECT group_id, display_name, creation_date FROM groups"#
.await
.unwrap()
.into_iter()
.map(|row| (
row.get::<GroupId, _>("group_id"),
row.get::<String, _>("display_name")
)) ))
.collect::<Vec<_>>(), .all(&sql_pool)
.await
.unwrap(),
vec![ vec![
(GroupId(1), "lldap_admin".to_string()), ShortGroupDetails {
(GroupId(2), "lldap_password_manager".to_string()), group_id: GroupId(1),
(GroupId(3), "lldap_strict_readonly".to_string()), display_name: "lldap_admin".to_string()
(GroupId(4), "test".to_string()) },
ShortGroupDetails {
group_id: GroupId(2),
display_name: "lldap_password_manager".to_string()
},
ShortGroupDetails {
group_id: GroupId(3),
display_name: "test".to_string()
}
] ]
); );
assert_eq!( assert_eq!(
sqlx::query(r#"SELECT version FROM metadata"#) sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
.map(|row: DbRow| row.get::<SchemaVersion, _>("version")) r#"SELECT version FROM metadata"#
.fetch_one(&sql_pool) ))
.one(&sql_pool)
.await .await
.unwrap()
.unwrap(), .unwrap(),
SchemaVersion(1) sql_migrations::JustSchemaVersion {
version: SchemaVersion(1)
}
); );
} }
#[tokio::test] #[tokio::test]
async fn test_too_high_version() { async fn test_too_high_version() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); let sql_pool = get_in_memory_db().await;
sqlx::query(r#"CREATE TABLE metadata ( version INTEGER);"#) sql_pool
.execute(&sql_pool) .execute(raw_statement(
r#"CREATE TABLE metadata ( version INTEGER);"#,
))
.await .await
.unwrap(); .unwrap();
sqlx::query( sql_pool
.execute(raw_statement(
r#"INSERT INTO metadata (version) r#"INSERT INTO metadata (version)
VALUES (127)"#, VALUES (127)"#,
) ))
.execute(&sql_pool)
.await .await
.unwrap(); .unwrap();
assert!(init_table(&sql_pool).await.is_err()); assert!(init_table(&sql_pool).await.is_err());
View File
@ -1,136 +1,68 @@
use super::{ use super::{
error::Result, error::{DomainError, Result},
handler::{ handler::{
CreateUserRequest, GroupDetails, GroupId, UpdateUserRequest, User, UserAndGroups, CreateUserRequest, GroupDetails, GroupId, UpdateUserRequest, User, UserAndGroups,
UserBackendHandler, UserId, UserRequestFilter, Uuid, UserBackendHandler, UserId, UserRequestFilter, Uuid,
}, },
model::{self, GroupColumn, UserColumn},
sql_backend_handler::SqlBackendHandler, sql_backend_handler::SqlBackendHandler,
sql_tables::{DbQueryBuilder, Groups, Memberships, Users},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr}; use sea_orm::{
use sea_query_binder::{SqlxBinder, SqlxValues}; entity::IntoActiveValue,
use sqlx::{query_as_with, query_with, FromRow, Row}; sea_query::{Cond, Expr, IntoCondition, SimpleExpr},
ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, QueryOrder,
QuerySelect, QueryTrait, Set,
};
use sea_query::{Alias, IntoColumnRef};
use std::collections::HashSet; use std::collections::HashSet;
use tracing::{debug, instrument}; use tracing::{debug, instrument};
struct RequiresGroup(bool); fn get_user_filter_expr(filter: UserRequestFilter) -> Cond {
// Returns the condition for the SQL query, and whether it requires joining with the groups table.
fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) {
use sea_query::IntoCondition;
use UserRequestFilter::*; use UserRequestFilter::*;
let group_table = Alias::new("r1");
fn get_repeated_filter( fn get_repeated_filter(
fs: Vec<UserRequestFilter>, fs: Vec<UserRequestFilter>,
condition: Cond, condition: Cond,
default_value: bool, default_value: bool,
) -> (RequiresGroup, Cond) { ) -> Cond {
if fs.is_empty() { if fs.is_empty() {
return ( SimpleExpr::Value(default_value.into()).into_condition()
RequiresGroup(false), } else {
SimpleExpr::Value(default_value.into()).into_condition(), fs.into_iter()
); .map(get_user_filter_expr)
.fold(condition, Cond::add)
} }
let mut requires_group = false;
let filter = fs.into_iter().fold(condition, |c, f| {
let (group, filters) = get_user_filter_expr(f);
requires_group |= group.0;
c.add(filters)
});
(RequiresGroup(requires_group), filter)
} }
match filter { match filter {
And(fs) => get_repeated_filter(fs, Cond::all(), true), And(fs) => get_repeated_filter(fs, Cond::all(), true),
Or(fs) => get_repeated_filter(fs, Cond::any(), false), Or(fs) => get_repeated_filter(fs, Cond::any(), false),
Not(f) => { Not(f) => get_user_filter_expr(*f).not(),
let (requires_group, filters) = get_user_filter_expr(*f); UserId(user_id) => ColumnTrait::eq(&UserColumn::UserId, user_id).into_condition(),
(requires_group, filters.not()) Equality(s1, s2) => {
} if s1 == UserColumn::UserId {
UserId(user_id) => (
RequiresGroup(false),
Expr::col((Users::Table, Users::UserId))
.eq(user_id)
.into_condition(),
),
Equality(s1, s2) => (
RequiresGroup(false),
if s1 == Users::UserId {
panic!("User id should be wrapped") panic!("User id should be wrapped")
} else { } else {
Expr::col((Users::Table, s1)).eq(s2).into_condition() ColumnTrait::eq(&s1, s2).into_condition()
}, }
), }
MemberOf(group) => ( MemberOf(group) => Expr::col((group_table, GroupColumn::DisplayName))
RequiresGroup(true),
Expr::col((Groups::Table, Groups::DisplayName))
.eq(group) .eq(group)
.into_condition(), .into_condition(),
), MemberOfId(group_id) => Expr::col((group_table, GroupColumn::GroupId))
MemberOfId(group_id) => (
RequiresGroup(true),
Expr::col((Groups::Table, Groups::GroupId))
.eq(group_id) .eq(group_id)
.into_condition(), .into_condition(),
),
} }
} }
fn to_value(opt_name: &Option<String>) -> ActiveValue<Option<String>> {
fn get_list_users_query( match opt_name {
filters: Option<UserRequestFilter>, None => ActiveValue::NotSet,
get_groups: bool, Some(name) => ActiveValue::Set(if name.is_empty() {
) -> (String, SqlxValues) { None
let mut query_builder = Query::select() } else {
.column((Users::Table, Users::UserId)) Some(name.to_owned())
.column(Users::Email) }),
.column((Users::Table, Users::DisplayName))
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column((Users::Table, Users::CreationDate))
.column((Users::Table, Users::Uuid))
.from(Users::Table)
.order_by((Users::Table, Users::UserId), Order::Asc)
.to_owned();
let add_join_group_tables = |builder: &mut sea_query::SelectStatement| {
builder
.left_join(
Memberships::Table,
Expr::tbl(Users::Table, Users::UserId)
.equals(Memberships::Table, Memberships::UserId),
)
.left_join(
Groups::Table,
Expr::tbl(Memberships::Table, Memberships::GroupId)
.equals(Groups::Table, Groups::GroupId),
);
};
if get_groups {
add_join_group_tables(&mut query_builder);
query_builder
.column((Groups::Table, Groups::GroupId))
.expr_as(
Expr::col((Groups::Table, Groups::DisplayName)),
Alias::new("group_display_name"),
)
.expr_as(
Expr::col((Groups::Table, Groups::CreationDate)),
sea_query::Alias::new("group_creation_date"),
)
.expr_as(
Expr::col((Groups::Table, Groups::Uuid)),
sea_query::Alias::new("group_uuid"),
)
.order_by(Alias::new("group_display_name"), Order::Asc);
} }
if let Some(filter) = filters {
let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter);
query_builder.cond_where(condition);
if requires_group && !get_groups {
add_join_group_tables(&mut query_builder);
}
}
query_builder.build_sqlx(DbQueryBuilder {})
} }
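The filter helpers above are consumed by list_users in the trait implementation that follows. As a usage sketch only (the async wrapper and its literals are placeholders, not taken from this commit), a caller combines the filter variants like this:

use crate::domain::{
    error::Result,
    handler::{UserBackendHandler, UserColumn, UserId, UserRequestFilter},
};

// Hedged example: list users matching either a user id or an email, without
// also fetching each user's group list (get_groups = false).
async fn list_example_users(handler: &impl UserBackendHandler) -> Result<()> {
    let users = handler
        .list_users(
            Some(UserRequestFilter::Or(vec![
                UserRequestFilter::UserId(UserId::new("bob")),
                UserRequestFilter::Equality(UserColumn::Email, "bob@example.com".to_owned()),
            ])),
            false,
        )
        .await?;
    for user_and_groups in users {
        println!("{}", user_and_groups.user.user_id.as_str());
    }
    Ok(())
}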
#[async_trait] #[async_trait]
@ -141,95 +73,86 @@ impl UserBackendHandler for SqlBackendHandler {
filters: Option<UserRequestFilter>, filters: Option<UserRequestFilter>,
get_groups: bool, get_groups: bool,
) -> Result<Vec<UserAndGroups>> { ) -> Result<Vec<UserAndGroups>> {
debug!(?filters, get_groups); debug!(?filters);
let (query, values) = get_list_users_query(filters, get_groups); let query = model::User::find()
.filter(
debug!(%query); filters
.map(|f| {
// For group_by. UserColumn::UserId
use itertools::Itertools; .in_subquery(
let mut users = Vec::new(); model::User::find()
// The rows are returned sorted by user_id. We group them by .find_also_linked(model::memberships::UserToGroup)
// this key which gives us one element (`rows`) per group. .select_only()
for (_, rows) in &query_with(&query, values) .column(UserColumn::UserId)
.fetch_all(&self.sql_pool) .filter(get_user_filter_expr(f))
.into_query(),
)
.into_condition()
})
.unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()),
)
.order_by_asc(UserColumn::UserId);
if !get_groups {
Ok(query
.into_model::<User>()
.all(&self.sql_pool)
.await? .await?
.into_iter() .into_iter()
.group_by(|row| row.get::<UserId, _>(&*Users::UserId.to_string())) .map(|u| UserAndGroups {
{ user: u,
let mut rows = rows.peekable(); groups: None,
users.push(UserAndGroups {
user: User::from_row(rows.peek().unwrap()).unwrap(),
groups: if get_groups {
Some(
rows.filter_map(|row| {
let display_name = row.get::<String, _>("group_display_name");
if display_name.is_empty() {
None
} else {
Some(GroupDetails {
group_id: row.get::<GroupId, _>(&*Groups::GroupId.to_string()),
display_name,
creation_date: row.get::<chrono::DateTime<chrono::Utc>, _>(
"group_creation_date",
),
uuid: row.get::<Uuid, _>("group_uuid"),
}) })
.collect())
} else {
let results = query
//find_with_linked?
.find_also_linked(model::memberships::UserToGroup)
.order_by_asc(SimpleExpr::Column(
(Alias::new("r1"), GroupColumn::GroupId).into_column_ref(),
))
.all(&self.sql_pool)
.await?;
use itertools::Itertools;
Ok(results
.iter()
.group_by(|(u, _)| u)
.into_iter()
.map(|(user, groups)| {
let groups: Vec<_> = groups
.into_iter()
.flat_map(|(_, g)| g)
.map(|g| GroupDetails::from(g.clone()))
.collect();
UserAndGroups {
user: user.clone().into(),
groups: Some(groups),
} }
}) })
.collect(), .collect())
)
} else {
None
},
});
} }
Ok(users)
} }
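find_also_linked(model::memberships::UserToGroup) relies on a Linked definition in the new memberships entity, which is outside this hunk. Purely as a hedged sketch of what such a link usually looks like in sea-orm (the entity paths and relation variant names here are assumptions, not the project's actual code):

use sea_orm::{Linked, RelationDef, RelationTrait};

// Hypothetical link walking users -> memberships -> groups.
pub struct UserToGroup;

impl Linked for UserToGroup {
    type FromEntity = super::users::Entity;
    type ToEntity = super::groups::Entity;

    fn link(&self) -> Vec<RelationDef> {
        vec![
            // memberships.user_id -> users.user_id, reversed so the path starts at users.
            super::memberships::Relation::Users.def().rev(),
            // memberships.group_id -> groups.group_id.
            super::memberships::Relation::Groups.def(),
        ]
    }
}

Along such a chain, sea-orm aliases the joined tables r0, r1, ... in order, which is what the Alias::new("r1") above refers to: the groups table at the end of the link.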
#[instrument(skip_all, level = "debug", ret)] #[instrument(skip_all, level = "debug", ret)]
async fn get_user_details(&self, user_id: &UserId) -> Result<User> { async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
debug!(?user_id); debug!(?user_id);
let (query, values) = Query::select() model::User::find_by_id(user_id.to_owned())
.column(Users::UserId) .into_model::<User>()
.column(Users::Email) .one(&self.sql_pool)
.column(Users::DisplayName) .await?
.column(Users::FirstName) .ok_or_else(|| DomainError::EntityNotFound(user_id.to_string()))
.column(Users::LastName)
.column(Users::Avatar)
.column(Users::CreationDate)
.column(Users::Uuid)
.from(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_as_with::<_, User, _>(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?)
} }
#[instrument(skip_all, level = "debug", ret, err)] #[instrument(skip_all, level = "debug", ret, err)]
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupDetails>> { async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupDetails>> {
debug!(?user_id); debug!(?user_id);
let (query, values) = Query::select() let user = model::User::find_by_id(user_id.to_owned())
.column((Groups::Table, Groups::GroupId)) .one(&self.sql_pool)
.column(Groups::DisplayName) .await?
.column(Groups::CreationDate) .ok_or_else(|| DomainError::EntityNotFound(user_id.to_string()))?;
.column(Groups::Uuid)
.from(Groups::Table)
.inner_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.cond_where(Expr::col(Memberships::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(HashSet::from_iter( Ok(HashSet::from_iter(
query_as_with::<_, GroupDetails, _>(&query, values) user.find_linked(model::memberships::UserToGroup)
.fetch_all(&self.sql_pool) .into_model::<GroupDetails>()
.all(&self.sql_pool)
.await?, .await?,
)) ))
} }
@ -237,70 +160,41 @@ impl UserBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn create_user(&self, request: CreateUserRequest) -> Result<()> { async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
debug!(user_id = ?request.user_id); debug!(user_id = ?request.user_id);
let columns = vec![
Users::UserId,
Users::Email,
Users::DisplayName,
Users::FirstName,
Users::LastName,
Users::Avatar,
Users::CreationDate,
Users::Uuid,
];
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now); let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now);
let values = vec![ let new_user = model::users::ActiveModel {
request.user_id.into(), user_id: Set(request.user_id),
request.email.into(), email: Set(request.email),
request.display_name.unwrap_or_default().into(), display_name: to_value(&request.display_name),
request.first_name.unwrap_or_default().into(), first_name: to_value(&request.first_name),
request.last_name.unwrap_or_default().into(), last_name: to_value(&request.last_name),
request.avatar.unwrap_or_default().into(), avatar: request.avatar.into_active_value(),
now.naive_utc().into(), creation_date: ActiveValue::Set(now),
uuid.into(), uuid: ActiveValue::Set(uuid),
]; ..Default::default()
let (query, values) = Query::insert() };
.into_table(Users::Table) new_user.insert(&self.sql_pool).await?;
.columns(columns)
.values_panic(values)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(()) Ok(())
} }
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn update_user(&self, request: UpdateUserRequest) -> Result<()> { async fn update_user(&self, request: UpdateUserRequest) -> Result<()> {
debug!(user_id = ?request.user_id); debug!(user_id = ?request.user_id);
let mut values = Vec::new(); let update_user = model::users::ActiveModel {
if let Some(email) = request.email { email: request.email.map(ActiveValue::Set).unwrap_or_default(),
values.push((Users::Email, email.into())); display_name: to_value(&request.display_name),
} first_name: to_value(&request.first_name),
if let Some(display_name) = request.display_name { last_name: to_value(&request.last_name),
values.push((Users::DisplayName, display_name.into())); avatar: request.avatar.into_active_value(),
} ..Default::default()
if let Some(first_name) = request.first_name { };
values.push((Users::FirstName, first_name.into())); model::User::update_many()
} .set(update_user)
if let Some(last_name) = request.last_name { .filter(sea_orm::ColumnTrait::eq(
values.push((Users::LastName, last_name.into())); &UserColumn::UserId,
} request.user_id,
if let Some(avatar) = request.avatar { ))
values.push((Users::Avatar, avatar.into())); .exec(&self.sql_pool)
}
if values.is_empty() {
return Ok(());
}
let (query, values) = Query::update()
.table(Users::Table)
.values(values)
.cond_where(Expr::col(Users::UserId).eq(request.user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?; .await?;
Ok(()) Ok(())
} }
@ -308,47 +202,41 @@ impl UserBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn delete_user(&self, user_id: &UserId) -> Result<()> { async fn delete_user(&self, user_id: &UserId) -> Result<()> {
debug!(?user_id); debug!(?user_id);
let (query, values) = Query::delete() let res = model::User::delete_by_id(user_id.clone())
.from_table(Users::Table) .exec(&self.sql_pool)
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?; .await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such user: '{}'",
user_id
)));
}
Ok(()) Ok(())
} }
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
debug!(?user_id, ?group_id); debug!(?user_id, ?group_id);
let (query, values) = Query::insert() let new_membership = model::memberships::ActiveModel {
.into_table(Memberships::Table) user_id: ActiveValue::Set(user_id.clone()),
.columns(vec![Memberships::UserId, Memberships::GroupId]) group_id: ActiveValue::Set(group_id),
.values_panic(vec![user_id.into(), group_id.into()]) };
.build_sqlx(DbQueryBuilder {}); new_membership.insert(&self.sql_pool).await?;
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(()) Ok(())
} }
#[instrument(skip_all, level = "debug", err)] #[instrument(skip_all, level = "debug", err)]
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
debug!(?user_id, ?group_id); debug!(?user_id, ?group_id);
let (query, values) = Query::delete() let res = model::Membership::delete_by_id((user_id.clone(), group_id))
.from_table(Memberships::Table) .exec(&self.sql_pool)
.cond_where(
Cond::all()
.add(Expr::col(Memberships::GroupId).eq(group_id))
.add(Expr::col(Memberships::UserId).eq(user_id)),
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?; .await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such membership: '{}' -> {:?}",
user_id, group_id
)));
}
Ok(()) Ok(())
} }
} }
@ -357,7 +245,8 @@ impl UserBackendHandler for SqlBackendHandler {
mod tests { mod tests {
use super::*; use super::*;
use crate::domain::{ use crate::domain::{
handler::JpegPhoto, sql_backend_handler::tests::*, sql_tables::UserColumn, handler::{JpegPhoto, UserColumn},
sql_backend_handler::tests::*,
}; };
#[tokio::test] #[tokio::test]
@ -526,9 +415,13 @@ mod tests {
.map(|u| { .map(|u| {
( (
u.user.user_id.to_string(), u.user.user_id.to_string(),
u.user.display_name.to_string(), u.user
.display_name
.as_deref()
.unwrap_or("<unknown>")
.to_owned(),
u.groups u.groups
.unwrap() .unwrap_or_default()
.into_iter() .into_iter()
.map(|g| g.group_id) .map(|g| g.group_id)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
@ -571,7 +464,7 @@ mod tests {
( (
u.user.creation_date, u.user.creation_date,
u.groups u.groups
.unwrap() .unwrap_or_default()
.into_iter() .into_iter()
.map(|g| g.creation_date) .map(|g| g.creation_date)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
@ -685,7 +578,7 @@ mod tests {
display_name: Some("display_name".to_string()), display_name: Some("display_name".to_string()),
first_name: Some("first_name".to_string()), first_name: Some("first_name".to_string()),
last_name: Some("last_name".to_string()), last_name: Some("last_name".to_string()),
avatar: Some(JpegPhoto::default()), avatar: Some(JpegPhoto::for_tests()),
}) })
.await .await
.unwrap(); .unwrap();
@ -696,10 +589,10 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(user.email, "email"); assert_eq!(user.email, "email");
assert_eq!(user.display_name, "display_name"); assert_eq!(user.display_name.unwrap(), "display_name");
assert_eq!(user.first_name, "first_name"); assert_eq!(user.first_name.unwrap(), "first_name");
assert_eq!(user.last_name, "last_name"); assert_eq!(user.last_name.unwrap(), "last_name");
assert_eq!(user.avatar, JpegPhoto::default()); assert_eq!(user.avatar, Some(JpegPhoto::for_tests()));
} }
#[tokio::test] #[tokio::test]
@ -722,9 +615,10 @@ mod tests {
.get_user_details(&UserId::new("bob")) .get_user_details(&UserId::new("bob"))
.await .await
.unwrap(); .unwrap();
assert_eq!(user.display_name, "display bob"); assert_eq!(user.display_name.unwrap(), "display bob");
assert_eq!(user.first_name, "first_name"); assert_eq!(user.first_name.unwrap(), "first_name");
assert_eq!(user.last_name, ""); assert_eq!(user.last_name, None);
assert_eq!(user.avatar, None);
} }
#[tokio::test] #[tokio::test]
View File
@ -26,7 +26,7 @@ use crate::domain::handler::UserRequestFilter;
use crate::{ use crate::{
domain::{ domain::{
error::DomainError, error::DomainError,
handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserId}, handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserColumn, UserId},
opaque_handler::OpaqueHandler, opaque_handler::OpaqueHandler,
}, },
infra::{ infra::{
@ -149,10 +149,7 @@ where
.list_users( .list_users(
Some(UserRequestFilter::Or(vec![ Some(UserRequestFilter::Or(vec![
UserRequestFilter::UserId(UserId::new(user_string)), UserRequestFilter::UserId(UserId::new(user_string)),
UserRequestFilter::Equality( UserRequestFilter::Equality(UserColumn::Email, user_string.to_owned()),
crate::domain::sql_tables::UserColumn::Email,
user_string.to_owned(),
),
])), ])),
false, false,
) )
@ -174,7 +171,9 @@ where
Some(token) => token, Some(token) => token,
}; };
if let Err(e) = super::mail::send_password_reset_email( if let Err(e) = super::mail::send_password_reset_email(
&user.display_name, user.display_name
.as_deref()
.unwrap_or_else(|| user.user_id.as_str()),
&user.email, &user.email,
&token, &token,
&data.server_url, &data.server_url,
View File
@ -1,18 +1,17 @@
use crate::{ use crate::domain::{
domain::sql_tables::{DbQueryBuilder, Pool}, model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn},
infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage}, sql_tables::DbConnection,
}; };
use actix::prelude::*; use actix::prelude::{Actor, AsyncContext, Context};
use chrono::Local;
use cron::Schedule; use cron::Schedule;
use sea_query::{Expr, Query}; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use std::{str::FromStr, time::Duration}; use std::{str::FromStr, time::Duration};
use tracing::{debug, error, info, instrument}; use tracing::{error, info, instrument};
// Define actor // Define actor
pub struct Scheduler { pub struct Scheduler {
schedule: Schedule, schedule: Schedule,
sql_pool: Pool, sql_pool: DbConnection,
} }
// Provide Actor implementation for our actor // Provide Actor implementation for our actor
@ -33,7 +32,7 @@ impl Actor for Scheduler {
} }
impl Scheduler { impl Scheduler {
pub fn new(cron_expression: &str, sql_pool: Pool) -> Self { pub fn new(cron_expression: &str, sql_pool: DbConnection) -> Self {
let schedule = Schedule::from_str(cron_expression).unwrap(); let schedule = Schedule::from_str(cron_expression).unwrap();
Self { schedule, sql_pool } Self { schedule, sql_pool }
} }
@ -48,33 +47,35 @@ impl Scheduler {
} }
#[instrument(skip_all)] #[instrument(skip_all)]
async fn cleanup_db(sql_pool: Pool) { async fn cleanup_db(sql_pool: DbConnection) {
info!("Cleaning DB"); info!("Cleaning DB");
let query = Query::delete() if let Err(e) = model::JwtRefreshStorage::delete_many()
.from_table(JwtRefreshStorage::Table) .filter(JwtRefreshStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc())) .exec(&sql_pool)
.to_string(DbQueryBuilder {}); .await
debug!(%query); {
if let Err(e) = sqlx::query(&query).execute(&sql_pool).await {
error!("DB error while cleaning up JWT refresh tokens: {}", e); error!("DB error while cleaning up JWT refresh tokens: {}", e);
}; }
if let Err(e) = sqlx::query( if let Err(e) = model::JwtStorage::delete_many()
&Query::delete() .filter(JwtStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.from_table(JwtStorage::Table) .exec(&sql_pool)
.and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await .await
{ {
error!("DB error while cleaning up JWT storage: {}", e); error!("DB error while cleaning up JWT storage: {}", e);
}; };
if let Err(e) = model::PasswordResetTokens::delete_many()
.filter(PasswordResetTokensColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up password reset tokens: {}", e);
};
info!("DB cleaned!"); info!("DB cleaned!");
} }
fn duration_until_next(&self) -> Duration { fn duration_until_next(&self) -> Duration {
let now = Local::now(); let now = chrono::Utc::now();
let next = self.schedule.upcoming(Local).next().unwrap(); let next = self.schedule.upcoming(chrono::Utc).next().unwrap();
let duration_until = next.signed_duration_since(now); let duration_until = next.signed_duration_since(now);
duration_until.to_std().unwrap() duration_until.to_std().unwrap()
} }
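For completeness, a hedged sketch of how this actor might be wired up at startup; the helper function, call site, and cron string are illustrative and not taken from this commit (the cron crate expects a leading seconds field, so the example below fires daily at 03:00):

use actix::Actor;

use crate::domain::sql_tables::DbConnection;

// Hypothetical startup helper; assumes an actix System is already running and
// that sql_pool is a connected sea-orm DbConnection.
fn start_cleanup_scheduler(sql_pool: DbConnection) {
    Scheduler::new("0 0 3 * * * *", sql_pool).start();
}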
View File
@ -1,7 +1,6 @@
use crate::domain::{ use crate::domain::{
handler::{BackendHandler, GroupDetails, GroupId, UserId}, handler::{BackendHandler, GroupDetails, GroupId, UserColumn, UserId},
ldap::utils::map_user_field, ldap::utils::map_user_field,
sql_tables::UserColumn,
}; };
use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -214,19 +213,19 @@ impl<Handler: BackendHandler + Sync> User<Handler> {
} }
fn display_name(&self) -> &str { fn display_name(&self) -> &str {
&self.user.display_name self.user.display_name.as_deref().unwrap_or("")
} }
fn first_name(&self) -> &str { fn first_name(&self) -> &str {
&self.user.first_name self.user.first_name.as_deref().unwrap_or("")
} }
fn last_name(&self) -> &str { fn last_name(&self) -> &str {
&self.user.last_name self.user.last_name.as_deref().unwrap_or("")
} }
fn avatar(&self) -> String { fn avatar(&self) -> Option<String> {
(&self.user.avatar).into() self.user.avatar.as_ref().map(String::from)
} }
fn creation_date(&self) -> chrono::DateTime<chrono::Utc> { fn creation_date(&self) -> chrono::DateTime<chrono::Utc> {
@ -392,7 +391,7 @@ mod tests {
Ok(DomainUser { Ok(DomainUser {
user_id: UserId::new("bob"), user_id: UserId::new("bob"),
email: "bob@bobbers.on".to_string(), email: "bob@bobbers.on".to_string(),
creation_date: chrono::Utc.timestamp_millis(42), creation_date: chrono::Utc.timestamp_millis_opt(42).unwrap(),
uuid: crate::uuid!("b1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"), uuid: crate::uuid!("b1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
..Default::default() ..Default::default()
}) })
View File
@ -1,6 +1,7 @@
use sea_query::*; use sea_orm::ConnectionTrait;
use sea_query::{ColumnDef, ForeignKey, ForeignKeyAction, Iden, Table};
pub use crate::domain::sql_tables::*; pub use crate::domain::{sql_migrations::Users, sql_tables::DbConnection};
/// Contains the refresh tokens for a given user. /// Contains the refresh tokens for a given user.
#[derive(Iden)] #[derive(Iden)]
@ -31,9 +32,12 @@ pub enum PasswordResetTokens {
} }
/// This needs to be initialized after the domain tables are. /// This needs to be initialized after the domain tables are.
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { pub async fn init_table(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> {
sqlx::query( let builder = pool.get_database_backend();
&Table::create()
pool.execute(
builder.build(
Table::create()
.table(JwtRefreshStorage::Table) .table(JwtRefreshStorage::Table)
.if_not_exists() .if_not_exists()
.col( .col(
@ -59,14 +63,14 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.to(Users::Table, Users::UserId) .to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?; .await?;
sqlx::query( pool.execute(
&Table::create() builder.build(
Table::create()
.table(JwtStorage::Table) .table(JwtStorage::Table)
.if_not_exists() .if_not_exists()
.col( .col(
@ -98,14 +102,14 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.to(Users::Table, Users::UserId) .to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?; .await?;
sqlx::query( pool.execute(
&Table::create() builder.build(
Table::create()
.table(PasswordResetTokens::Table) .table(PasswordResetTokens::Table)
.if_not_exists() .if_not_exists()
.col( .col(
@ -131,10 +135,9 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.to(Users::Table, Users::UserId) .to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
),
),
) )
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?; .await?;
Ok(()) Ok(())
View File
@ -569,7 +569,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::{
domain::{error::Result, handler::*, opaque_handler::*, sql_tables::UserColumn}, domain::{error::Result, handler::*, opaque_handler::*},
uuid, uuid,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -669,7 +669,7 @@ mod tests {
set.insert(GroupDetails { set.insert(GroupDetails {
group_id: GroupId(42), group_id: GroupId(42),
display_name: group, display_name: group,
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"), uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
}); });
Ok(set) Ok(set)
@ -756,7 +756,7 @@ mod tests {
set.insert(GroupDetails { set.insert(GroupDetails {
group_id: GroupId(42), group_id: GroupId(42),
display_name: "lldap_admin".to_string(), display_name: "lldap_admin".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"), uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
}); });
Ok(set) Ok(set)
@ -843,7 +843,7 @@ mod tests {
groups: Some(vec![GroupDetails { groups: Some(vec![GroupDetails {
group_id: GroupId(42), group_id: GroupId(42),
display_name: "rockstars".to_string(), display_name: "rockstars".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"), uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
}]), }]),
}]) }])
@ -991,9 +991,9 @@ mod tests {
user: User { user: User {
user_id: UserId::new("bob_1"), user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(), email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(), display_name: Some("Bôb Böbberson".to_string()),
first_name: "Bôb".to_string(), first_name: Some("Bôb".to_string()),
last_name: "Böbberson".to_string(), last_name: Some("Böbberson".to_string()),
uuid: uuid!("698e1d5f-7a40-3151-8745-b9b8a37839da"), uuid: uuid!("698e1d5f-7a40-3151-8745-b9b8a37839da"),
..Default::default() ..Default::default()
}, },
@ -1003,12 +1003,12 @@ mod tests {
user: User { user: User {
user_id: UserId::new("jim"), user_id: UserId::new("jim"),
email: "jim@cricket.jim".to_string(), email: "jim@cricket.jim".to_string(),
display_name: "Jimminy Cricket".to_string(), display_name: Some("Jimminy Cricket".to_string()),
first_name: "Jim".to_string(), first_name: Some("Jim".to_string()),
last_name: "Cricket".to_string(), last_name: Some("Cricket".to_string()),
avatar: JpegPhoto::for_tests(), avatar: Some(JpegPhoto::for_tests()),
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
creation_date: Utc.ymd(2014, 7, 8).and_hms(9, 10, 11), creation_date: Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11).unwrap(),
}, },
groups: None, groups: None,
}, },
@ -1137,14 +1137,14 @@ mod tests {
Group { Group {
id: GroupId(1), id: GroupId(1),
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")], users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}, },
Group { Group {
id: GroupId(3), id: GroupId(3),
display_name: "BestGroup".to_string(), display_name: "BestGroup".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("john")], users: vec![UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}, },
@ -1230,7 +1230,7 @@ mod tests {
Ok(vec![Group { Ok(vec![Group {
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
id: GroupId(1), id: GroupId(1),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![], users: vec![],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}]) }])
@ -1281,7 +1281,7 @@ mod tests {
Ok(vec![Group { Ok(vec![Group {
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
id: GroupId(1), id: GroupId(1),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![], users: vec![],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}]) }])
@ -1542,9 +1542,9 @@ mod tests {
user: User { user: User {
user_id: UserId::new("bob_1"), user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(), email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(), display_name: Some("Bôb Böbberson".to_string()),
first_name: "Bôb".to_string(), first_name: Some("Bôb".to_string()),
last_name: "Böbberson".to_string(), last_name: Some("Böbberson".to_string()),
..Default::default() ..Default::default()
}, },
groups: None, groups: None,
@ -1557,7 +1557,7 @@ mod tests {
Ok(vec![Group { Ok(vec![Group {
id: GroupId(1), id: GroupId(1),
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")], users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}]) }])
@ -1616,9 +1616,9 @@ mod tests {
user: User { user: User {
user_id: UserId::new("bob_1"), user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(), email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(), display_name: Some("Bôb Böbberson".to_string()),
last_name: "Böbberson".to_string(), last_name: Some("Böbberson".to_string()),
avatar: JpegPhoto::for_tests(), avatar: Some(JpegPhoto::for_tests()),
uuid: uuid!("b4ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("b4ac75e0-2900-3e21-926c-2f732c26b3fc"),
..Default::default() ..Default::default()
}, },
@ -1631,7 +1631,7 @@ mod tests {
Ok(vec![Group { Ok(vec![Group {
id: GroupId(1), id: GroupId(1),
display_name: "group_1".to_string(), display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")], users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"), uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}]) }])
@ -1680,7 +1680,11 @@ mod tests {
}, },
LdapPartialAttribute { LdapPartialAttribute {
atype: "createtimestamp".to_string(), atype: "createtimestamp".to_string(),
vals: vec![chrono::Utc.timestamp(0, 0).to_rfc3339().into_bytes()], vals: vec![chrono::Utc
.timestamp_opt(0, 0)
.unwrap()
.to_rfc3339()
.into_bytes()],
}, },
LdapPartialAttribute { LdapPartialAttribute {
atype: "entryuuid".to_string(), atype: "entryuuid".to_string(),
@ -1960,7 +1964,7 @@ mod tests {
groups.insert(GroupDetails { groups.insert(GroupDetails {
group_id: GroupId(0), group_id: GroupId(0),
display_name: "lldap_admin".to_string(), display_name: "lldap_admin".to_string(),
creation_date: chrono::Utc.timestamp(42, 42), creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"), uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
}); });
mock.expect_get_user_groups() mock.expect_get_user_groups()
View File
@ -48,3 +48,14 @@ pub fn init(config: &Configuration) -> anyhow::Result<()> {
.init(); .init();
Ok(()) Ok(())
} }
#[cfg(test)]
pub fn init_for_tests() {
if let Err(e) = tracing_subscriber::FmtSubscriber::builder()
.with_max_level(tracing::Level::DEBUG)
.with_test_writer()
.try_init()
{
log::warn!("Could not set up test logging: {:#}", e);
}
}
View File
@ -1,10 +1,16 @@
use super::{jwt_sql_tables::*, tcp_backend_handler::*}; use super::tcp_backend_handler::TcpBackendHandler;
use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler}; use crate::domain::{
error::*,
handler::UserId,
model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn},
sql_backend_handler::SqlBackendHandler,
};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::StreamExt; use sea_orm::{
use sea_query::{Expr, Iden, Query, SimpleExpr}; sea_query::Cond, ActiveModelTrait, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel,
use sea_query_binder::SqlxBinder; QueryFilter, QuerySelect,
use sqlx::{query_as_with, query_with, Row}; };
use sea_query::Expr;
use std::collections::HashSet; use std::collections::HashSet;
use tracing::{debug, instrument}; use tracing::{debug, instrument};
@ -18,126 +24,102 @@ fn gen_random_string(len: usize) -> String {
.collect() .collect()
} }
#[derive(FromQueryResult)]
struct OnlyJwtHash {
jwt_hash: i64,
}
#[async_trait] #[async_trait]
impl TcpBackendHandler for SqlBackendHandler { impl TcpBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> { async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
let (query, values) = Query::select() Ok(model::JwtStorage::find()
.column(JwtStorage::JwtHash) .select_only()
.from(JwtStorage::Table) .column(JwtStorageColumn::JwtHash)
.build_sqlx(DbQueryBuilder {}); .filter(JwtStorageColumn::Blacklisted.eq(true))
.into_model::<OnlyJwtHash>()
debug!(%query); .all(&self.sql_pool)
query_with(&query, values) .await?
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
.into_iter() .into_iter()
.collect::<sqlx::Result<HashSet<u64>>>() .map(|m| m.jwt_hash as u64)
.map_err(|e| anyhow::anyhow!(e)) .collect::<HashSet<u64>>())
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> { async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
debug!(?user); debug!(?user);
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>? // TODO: Initialize the rng only once. Maybe Arc<Cell>?
let refresh_token = gen_random_string(100); let refresh_token = gen_random_string(100);
let refresh_token_hash = { let refresh_token_hash = {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new(); let mut s = DefaultHasher::new();
refresh_token.hash(&mut s); refresh_token.hash(&mut s);
s.finish() s.finish()
}; };
let duration = chrono::Duration::days(30); let duration = chrono::Duration::days(30);
let (query, values) = Query::insert() let new_token = model::jwt_refresh_storage::Model {
.into_table(JwtRefreshStorage::Table) refresh_token_hash: refresh_token_hash as i64,
.columns(vec![ user_id: user.clone(),
JwtRefreshStorage::RefreshTokenHash, expiry_date: chrono::Utc::now() + duration,
JwtRefreshStorage::UserId, }
JwtRefreshStorage::ExpiryDate, .into_active_model();
]) new_token.insert(&self.sql_pool).await?;
.values_panic(vec![
(refresh_token_hash as i64).into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok((refresh_token, duration)) Ok((refresh_token, duration))
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> { async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
debug!(?user); debug!(?user);
let (query, values) = Query::select() Ok(
.expr(SimpleExpr::Value(1.into())) model::JwtRefreshStorage::find_by_id(refresh_token_hash as i64)
.from(JwtRefreshStorage::Table) .filter(JwtRefreshStorageColumn::UserId.eq(user))
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) .one(&self.sql_pool)
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_with(&query, values)
.fetch_optional(&self.sql_pool)
.await? .await?
.is_some()) .is_some(),
)
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> { async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
debug!(?user); debug!(?user);
use sqlx::Result; let valid_tokens = model::JwtStorage::find()
let (query, values) = Query::select() .select_only()
.column(JwtStorage::JwtHash) .column(JwtStorageColumn::JwtHash)
.from(JwtStorage::Table) .filter(
.and_where(Expr::col(JwtStorage::UserId).eq(user)) Cond::all()
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true)) .add(JwtStorageColumn::UserId.eq(user))
.build_sqlx(DbQueryBuilder {}); .add(JwtStorageColumn::Blacklisted.eq(false)),
let result = query_with(&query, values) )
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64) .into_model::<OnlyJwtHash>()
.fetch(&self.sql_pool) .all(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>() .await?
.await
.into_iter() .into_iter()
.collect::<Result<HashSet<u64>>>(); .map(|t| t.jwt_hash as u64)
let (query, values) = Query::update() .collect::<HashSet<u64>>();
.table(JwtStorage::Table) model::JwtStorage::update_many()
.values(vec![(JwtStorage::Blacklisted, true.into())]) .col_expr(JwtStorageColumn::Blacklisted, Expr::value(true))
.and_where(Expr::col(JwtStorage::UserId).eq(user)) .filter(JwtStorageColumn::UserId.eq(user))
.build_sqlx(DbQueryBuilder {}); .exec(&self.sql_pool)
debug!(%query); .await?;
query_with(&query, values).execute(&self.sql_pool).await?; Ok(valid_tokens)
Ok(result?)
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> { async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
let (query, values) = Query::delete() model::JwtRefreshStorage::delete_by_id(refresh_token_hash as i64)
.from_table(JwtRefreshStorage::Table) .exec(&self.sql_pool)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) .await?;
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(()) Ok(())
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> { async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
debug!(?user); debug!(?user);
let (query, values) = Query::select() if model::User::find_by_id(user.clone())
.column(Users::UserId) .one(&self.sql_pool)
.from(Users::Table) .await?
.and_where(Expr::col(Users::UserId).eq(user)) .is_none()
.build_sqlx(DbQueryBuilder {});
debug!(%query);
// Check that the user exists.
if query_with(&query, values)
.fetch_one(&self.sql_pool)
.await
.is_err()
{ {
debug!("User not found"); debug!("User not found");
return Ok(None); return Ok(None);
@ -146,50 +128,37 @@ impl TcpBackendHandler for SqlBackendHandler {
let token = gen_random_string(100); let token = gen_random_string(100);
let duration = chrono::Duration::minutes(10); let duration = chrono::Duration::minutes(10);
let (query, values) = Query::insert() let new_token = model::password_reset_tokens::Model {
.into_table(PasswordResetTokens::Table) token: token.clone(),
.columns(vec![ user_id: user.clone(),
PasswordResetTokens::Token, expiry_date: chrono::Utc::now() + duration,
PasswordResetTokens::UserId, }
PasswordResetTokens::ExpiryDate, .into_active_model();
]) new_token.insert(&self.sql_pool).await?;
.values_panic(vec![
token.clone().into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(Some(token)) Ok(Some(token))
} }
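Storing the reset token illustrates the plain-Model route into an INSERT: build a Model, convert it with into_active_model, then call insert. A short sketch of that flow follows; the password_reset_tokens entity and the store_reset_token helper are hypothetical (field names taken from the diff), and the chrono crate is assumed as a dependency.

// Sketch only: hypothetical entity for the password_reset_tokens table.
mod password_reset_tokens {
    use sea_orm::entity::prelude::*;

    #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
    #[sea_orm(table_name = "password_reset_tokens")]
    pub struct Model {
        #[sea_orm(primary_key, auto_increment = false)]
        pub token: String,
        pub user_id: String,
        pub expiry_date: DateTimeUtc,
    }

    #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}
}

use sea_orm::{ActiveModelTrait, DatabaseConnection, DbErr, IntoActiveModel};

async fn store_reset_token(
    db: &DatabaseConnection,
    user_id: &str,
    token: &str,
) -> Result<(), DbErr> {
    // Build a plain Model, promote it to an ActiveModel, and insert it.
    let row = password_reset_tokens::Model {
        token: token.to_owned(),
        user_id: user_id.to_owned(),
        expiry_date: chrono::Utc::now() + chrono::Duration::minutes(10),
    }
    .into_active_model();
    row.insert(db).await?;
    Ok(())
}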
#[instrument(skip_all, level = "debug", ret)] #[instrument(skip_all, level = "debug", ret)]
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> { async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
let (query, values) = Query::select() Ok(model::PasswordResetTokens::find_by_id(token.to_owned())
.column(PasswordResetTokens::UserId) .filter(PasswordResetTokensColumn::ExpiryDate.gt(chrono::Utc::now().naive_utc()))
.from(PasswordResetTokens::Table) .one(&self.sql_pool)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token)) .await?
.and_where( .ok_or_else(|| DomainError::EntityNotFound("Invalid reset token".to_owned()))?
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()), .user_id)
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let (user_id,) = query_as_with(&query, values)
.fetch_one(&self.sql_pool)
.await?;
Ok(user_id)
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
async fn delete_password_reset_token(&self, token: &str) -> Result<()> { async fn delete_password_reset_token(&self, token: &str) -> Result<()> {
let (query, values) = Query::delete() let result = model::PasswordResetTokens::delete_by_id(token.to_owned())
.from_table(PasswordResetTokens::Table) .exec(&self.sql_pool)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token)) .await?;
.build_sqlx(DbQueryBuilder {}); if result.rows_affected == 0 {
debug!(%query); return Err(DomainError::EntityNotFound(format!(
query_with(&query, values).execute(&self.sql_pool).await?; "No such password reset token: '{}'",
token
)));
}
Ok(()) Ok(())
} }
} }
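Two smaller patterns from the same file are worth calling out: a primary-key lookup further restricted by an expiry filter, and a primary-key delete whose DeleteResult exposes rows_affected (used above to turn a no-op delete into an EntityNotFound error). A sketch of both, reusing the hypothetical password_reset_tokens entity from the previous snippet; the helper names are illustrative, not the project's.

use sea_orm::{ColumnTrait, DatabaseConnection, DbErr, EntityTrait, QueryFilter};

// Look up the user owning a still-valid reset token (None if expired or unknown).
async fn user_for_valid_token(
    db: &DatabaseConnection,
    token: &str,
) -> Result<Option<String>, DbErr> {
    Ok(password_reset_tokens::Entity::find_by_id(token.to_owned())
        .filter(password_reset_tokens::Column::ExpiryDate.gt(chrono::Utc::now()))
        .one(db)
        .await?
        .map(|row| row.user_id))
}

// Delete a token and report whether a row was actually removed.
async fn delete_token(db: &DatabaseConnection, token: &str) -> Result<bool, DbErr> {
    let result = password_reset_tokens::Entity::delete_by_id(token.to_owned())
        .exec(db)
        .await?;
    Ok(result.rows_affected > 0)
}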

View File

@ -52,9 +52,9 @@ pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse {
DomainError::DatabaseError(_) DomainError::DatabaseError(_)
| DomainError::InternalError(_) | DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(), | DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => { DomainError::Base64DecodeError(_)
HttpResponse::BadRequest() | DomainError::BinarySerializationError(_)
} | DomainError::EntityNotFound(_) => HttpResponse::BadRequest(),
}, },
TcpError::BadRequest(_) => HttpResponse::BadRequest(), TcpError::BadRequest(_) => HttpResponse::BadRequest(),
TcpError::InternalServerError(_) => HttpResponse::InternalServerError(), TcpError::InternalServerError(_) => HttpResponse::InternalServerError(),

View File

@ -9,7 +9,6 @@ use crate::{
handler::{CreateUserRequest, GroupBackendHandler, GroupRequestFilter, UserBackendHandler}, handler::{CreateUserRequest, GroupBackendHandler, GroupRequestFilter, UserBackendHandler},
sql_backend_handler::SqlBackendHandler, sql_backend_handler::SqlBackendHandler,
sql_opaque_handler::register_password, sql_opaque_handler::register_password,
sql_tables::PoolOptions,
}, },
infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, healthcheck, mail}, infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, healthcheck, mail},
}; };
@ -17,6 +16,7 @@ use actix::Actor;
use actix_server::ServerBuilder; use actix_server::ServerBuilder;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use futures_util::TryFutureExt; use futures_util::TryFutureExt;
use sea_orm::Database;
use tracing::*; use tracing::*;
mod domain; mod domain;
@ -39,29 +39,52 @@ async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration)
.and_then(|_| register_password(handler, &config.ldap_user_dn, &config.ldap_user_pass)) .and_then(|_| register_password(handler, &config.ldap_user_dn, &config.ldap_user_pass))
.await .await
.context("Error creating admin user")?; .context("Error creating admin user")?;
let admin_group_id = handler let groups = handler
.create_group("lldap_admin") .list_groups(Some(GroupRequestFilter::DisplayName(
.await "lldap_admin".to_owned(),
.context("Error creating admin group")?; )))
.await?;
assert_eq!(groups.len(), 1);
handler handler
.add_user_to_group(&config.ldap_user_dn, admin_group_id) .add_user_to_group(&config.ldap_user_dn, groups[0].id)
.await .await
.context("Error adding admin user to group") .context("Error adding admin user to group")
} }
async fn ensure_group_exists(handler: &SqlBackendHandler, group_name: &str) -> Result<()> {
if handler
.list_groups(Some(GroupRequestFilter::DisplayName(group_name.to_owned())))
.await?
.is_empty()
{
warn!("Could not find {} group, trying to create it", group_name);
handler
.create_group(group_name)
.await
.context(format!("while creating {} group", group_name))?;
}
Ok(())
}
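In the set_up_server hunk below, the sqlx PoolOptions connection is replaced by sea_orm::ConnectOptions plus Database::connect. A minimal sketch of that connection setup follows, assuming only the sea-orm and log crates; connect_db is an illustrative helper name, not a function from the codebase.

use sea_orm::{ConnectOptions, Database, DatabaseConnection, DbErr};

// Open a connection pool with a small pool size and SQL statement logging at debug level.
async fn connect_db(database_url: &str) -> Result<DatabaseConnection, DbErr> {
    let mut options = ConnectOptions::new(database_url.to_owned());
    options
        .max_connections(5)
        .sqlx_logging(true)
        .sqlx_logging_level(log::LevelFilter::Debug);
    Database::connect(options).await
}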
#[instrument(skip_all)] #[instrument(skip_all)]
async fn set_up_server(config: Configuration) -> Result<ServerBuilder> { async fn set_up_server(config: Configuration) -> Result<ServerBuilder> {
info!("Starting LLDAP version {}", env!("CARGO_PKG_VERSION")); info!("Starting LLDAP version {}", env!("CARGO_PKG_VERSION"));
let sql_pool = PoolOptions::new() let sql_pool = {
let mut sql_opt = sea_orm::ConnectOptions::new(config.database_url.clone());
sql_opt
.max_connections(5) .max_connections(5)
.connect(&config.database_url) .sqlx_logging(true)
.await .sqlx_logging_level(log::LevelFilter::Debug);
.context("while connecting to the DB")?; Database::connect(sql_opt).await?
};
domain::sql_tables::init_table(&sql_pool) domain::sql_tables::init_table(&sql_pool)
.await .await
.context("while creating the tables")?; .context("while creating the tables")?;
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone()); let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
ensure_group_exists(&backend_handler, "lldap_admin").await?;
ensure_group_exists(&backend_handler, "lldap_password_manager").await?;
ensure_group_exists(&backend_handler, "lldap_strict_readonly").await?;
if let Err(e) = backend_handler.get_user_details(&config.ldap_user_dn).await { if let Err(e) = backend_handler.get_user_details(&config.ldap_user_dn).await {
warn!("Could not get admin user, trying to create it: {:#}", e); warn!("Could not get admin user, trying to create it: {:#}", e);
create_admin_user(&backend_handler, &config) create_admin_user(&backend_handler, &config)
@ -69,23 +92,6 @@ async fn set_up_server(config: Configuration) -> Result<ServerBuilder> {
.map_err(|e| anyhow!("Error setting up admin login/account: {:#}", e)) .map_err(|e| anyhow!("Error setting up admin login/account: {:#}", e))
.context("while creating the admin user")?; .context("while creating the admin user")?;
} }
if backend_handler
.list_groups(Some(GroupRequestFilter::DisplayName(
"lldap_password_manager".to_string(),
)))
.await?
.is_empty()
{
warn!("Could not find password_manager group, trying to create it");
backend_handler
.create_group("lldap_password_manager")
.await
.context("while creating password_manager group")?;
backend_handler
.create_group("lldap_strict_readonly")
.await
.context("while creating readonly group")?;
}
let server_builder = infra::ldap_server::build_ldap_server( let server_builder = infra::ldap_server::build_ldap_server(
&config, &config,
backend_handler.clone(), backend_handler.clone(),