server,app: migrate to sea-orm

This commit is contained in:
Valentin Tolmer 2022-11-21 09:13:25 +01:00 committed by nitnelave
parent a3a27f0049
commit e89b1538af
40 changed files with 2125 additions and 1390 deletions

586
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -12,3 +12,7 @@ default-members = ["server"]
[patch.crates-io.ldap3_proto]
git = 'https://github.com/nitnelave/ldap3_server/'
rev = '7b50b2b82c383f5f70e02e11072bb916629ed2bc'
[patch.crates-io.opaque-ke]
git = 'https://github.com/nitnelave/opaque-ke/'
branch = 'zeroize_1.5'

View File

@ -3,6 +3,7 @@ name = "lldap_app"
version = "0.4.2-alpha"
authors = ["Valentin Tolmer <valentin@tolmer.fr>"]
edition = "2021"
include = ["src/**/*", "queries/**/*", "Cargo.toml", "../schema.graphql"]
[dependencies]
anyhow = "1"

View File

@ -38,7 +38,6 @@ pub struct CreateUserModel {
username: String,
#[validate(email(message = "A valid email is required"))]
email: String,
#[validate(length(min = 1, message = "Display name is required"))]
display_name: String,
first_name: String,
last_name: String,
@ -244,9 +243,7 @@ impl Component for CreateUserForm {
<div class="form-group row mb-3">
<label for="display-name"
class="form-label col-4 col-form-label">
{"Display name"}
<span class="text-danger">{"*"}</span>
{":"}
{"Display name:"}
</label>
<div class="col-8">
<Field

View File

@ -89,7 +89,7 @@ impl GroupDetails {
{"Creation date: "}
</label>
<div class="col-8">
<span id="creationDate" class="form-constrol-static">{g.creation_date.date().naive_local()}</span>
<span id="creationDate" class="form-constrol-static">{g.creation_date.naive_local().date()}</span>
</div>
</div>
<div class="form-group row mb-3">

View File

@ -124,7 +124,7 @@ impl GroupTable {
</Link>
</td>
<td>
{&group.creation_date.date().naive_local()}
{&group.creation_date.naive_local().date()}
</td>
<td>
<DeleteGroup

View File

@ -43,7 +43,6 @@ impl FromStr for JsFile {
pub struct UserModel {
#[validate(email)]
email: String,
#[validate(length(min = 1, message = "Display name is required"))]
display_name: String,
first_name: String,
last_name: String,
@ -176,7 +175,10 @@ impl Component for UserDetailsForm {
type Field = yew_form::Field<UserModel>;
let avatar_base64 = maybe_to_base64(&self.avatar).unwrap_or_default();
let avatar_string = avatar_base64.as_ref().unwrap_or(&self.common.user.avatar);
let avatar_string = avatar_base64
.as_deref()
.or(self.common.user.avatar.as_deref())
.unwrap_or("");
html! {
<div class="py-3">
<form class="form">
@ -195,7 +197,7 @@ impl Component for UserDetailsForm {
{"Creation date: "}
</label>
<div class="col-8">
<span id="creationDate" class="form-control-static">{&self.common.user.creation_date.date().naive_local()}</span>
<span id="creationDate" class="form-control-static">{&self.common.user.creation_date.naive_local().date()}</span>
</div>
</div>
<div class="form-group row mb-3">
@ -231,9 +233,7 @@ impl Component for UserDetailsForm {
<div class="form-group row mb-3">
<label for="display_name"
class="form-label col-4 col-form-label">
{"Display Name"}
<span class="text-danger">{"*"}</span>
{":"}
{"Display Name: "}
</label>
<div class="col-8">
<Field
@ -402,7 +402,7 @@ impl UserDetailsForm {
self.common.user.first_name = model.first_name;
self.common.user.last_name = model.last_name;
if let Some(avatar) = maybe_to_base64(&self.avatar)? {
self.common.user.avatar = avatar;
self.common.user.avatar = Some(avatar);
}
self.just_updated = true;
}

View File

@ -133,7 +133,7 @@ impl UserTable {
<td>{&user.display_name}</td>
<td>{&user.first_name}</td>
<td>{&user.last_name}</td>
<td>{&user.creation_date.date().naive_local()}</td>
<td>{&user.creation_date.naive_local().date()}</td>
<td>
<DeleteUser
username=user.id.clone()

View File

@ -53,7 +53,11 @@ pub fn get_cookie(cookie_name: &str) -> Result<Option<String>> {
pub fn delete_cookie(cookie_name: &str) -> Result<()> {
if get_cookie(cookie_name)?.is_some() {
set_cookie(cookie_name, "", &Utc.ymd(1970, 1, 1).and_hms(0, 0, 0))
set_cookie(
cookie_name,
"",
&Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(),
)
} else {
Ok(())
}

View File

@ -69,7 +69,7 @@ type User {
displayName: String!
firstName: String!
lastName: String!
avatar: String!
avatar: String
creationDate: DateTimeUtc!
uuid: String!
"The groups to which this user belongs."

View File

@ -35,7 +35,6 @@ rustls = "0.20"
serde = "*"
serde_json = "1"
sha2 = "0.9"
sqlx-core = "0.5.11"
thiserror = "*"
time = "0.2"
tokio-rustls = "0.23"
@ -70,28 +69,12 @@ features = ["builder", "serde", "smtp-transport", "tokio1-rustls-tls"]
default-features = false
version = "0.10.0-rc.3"
[dependencies.sqlx]
version = "0.5.11"
features = [
"any",
"chrono",
"macros",
"mysql",
"postgres",
"runtime-actix-rustls",
"sqlite",
]
[dependencies.lldap_auth]
path = "../auth"
[dependencies.sea-query]
version = "^0.25"
features = ["with-chrono", "sqlx-sqlite"]
[dependencies.sea-query-binder]
version = "0.1"
features = ["with-chrono", "sqlx-sqlite", "sqlx-any"]
version = "*"
features = ["with-chrono"]
[dependencies.opaque-ke]
version = "0.6"
@ -125,6 +108,11 @@ features = ["jpeg"]
default-features = false
version = "0.24"
[dependencies.sea-orm]
version= "0.10.3"
default-features = false
features = ["macros", "with-chrono", "with-uuid", "sqlx-all", "runtime-actix-rustls"]
[dependencies.reqwest]
version = "0.11"
default-features = false

View File

@ -6,7 +6,7 @@ pub enum DomainError {
#[error("Authentication error: `{0}`")]
AuthenticationError(String),
#[error("Database error: `{0}`")]
DatabaseError(#[from] sqlx::Error),
DatabaseError(#[from] sea_orm::DbErr),
#[error("Authentication protocol error for `{0}`")]
AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError),
#[error("Unknown crypto error: `{0}`")]
@ -15,6 +15,8 @@ pub enum DomainError {
BinarySerializationError(#[from] bincode::Error),
#[error("Invalid base64: `{0}`")]
Base64DecodeError(#[from] base64::DecodeError),
#[error("Entity not found: `{0}`")]
EntityNotFound(String),
#[error("Internal error: `{0}`")]
InternalError(String),
}

View File

@ -1,13 +1,12 @@
use super::{error::*, sql_tables::UserColumn};
use super::error::*;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
#[derive(
PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::FromRow, sqlx::Type,
)]
pub use super::model::{GroupColumn, UserColumn};
#[derive(PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(try_from = "&str")]
#[sqlx(transparent)]
pub struct Uuid(String);
impl Uuid {
@ -43,17 +42,26 @@ impl std::string::ToString for Uuid {
}
}
impl sea_orm::TryGetable for Uuid {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> std::result::Result<Self, sea_orm::TryGetError> {
Ok(Uuid(String::try_get(res, pre, col)?))
}
}
#[cfg(test)]
#[macro_export]
macro_rules! uuid {
($s:literal) => {
$crate::domain::handler::Uuid::try_from($s).unwrap()
<$crate::domain::handler::Uuid as std::convert::TryFrom<_>>::try_from($s).unwrap()
};
}
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::Type)]
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(from = "String")]
#[sqlx(transparent)]
pub struct UserId(String);
impl UserId {
@ -82,17 +90,22 @@ impl From<String> for UserId {
}
}
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::Type)]
#[sqlx(transparent)]
#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)]
pub struct JpegPhoto(#[serde(with = "serde_bytes")] Vec<u8>);
impl From<JpegPhoto> for sea_query::Value {
impl JpegPhoto {
pub fn null() -> Self {
Self(vec![])
}
}
impl From<JpegPhoto> for sea_orm::Value {
fn from(photo: JpegPhoto) -> Self {
photo.0.into()
}
}
impl From<&JpegPhoto> for sea_query::Value {
impl From<&JpegPhoto> for sea_orm::Value {
fn from(photo: &JpegPhoto) -> Self {
photo.0.as_slice().into()
}
@ -101,6 +114,9 @@ impl From<&JpegPhoto> for sea_query::Value {
impl TryFrom<&[u8]> for JpegPhoto {
type Error = anyhow::Error;
fn try_from(bytes: &[u8]) -> anyhow::Result<Self> {
if bytes.is_empty() {
return Ok(JpegPhoto::null());
}
// Confirm that it's a valid Jpeg, then store only the bytes.
image::io::Reader::with_format(std::io::Cursor::new(bytes), image::ImageFormat::Jpeg)
.decode()?;
@ -111,6 +127,9 @@ impl TryFrom<&[u8]> for JpegPhoto {
impl TryFrom<Vec<u8>> for JpegPhoto {
type Error = anyhow::Error;
fn try_from(bytes: Vec<u8>) -> anyhow::Result<Self> {
if bytes.is_empty() {
return Ok(JpegPhoto::null());
}
// Confirm that it's a valid Jpeg, then store only the bytes.
image::io::Reader::with_format(
std::io::Cursor::new(bytes.as_slice()),
@ -160,14 +179,14 @@ impl JpegPhoto {
}
}
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, sea_orm::FromQueryResult)]
pub struct User {
pub user_id: UserId,
pub email: String,
pub display_name: String,
pub first_name: String,
pub last_name: String,
pub avatar: JpegPhoto,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
pub avatar: Option<JpegPhoto>,
pub creation_date: chrono::DateTime<chrono::Utc>,
pub uuid: Uuid,
}
@ -176,14 +195,14 @@ pub struct User {
impl Default for User {
fn default() -> Self {
use chrono::TimeZone;
let epoch = chrono::Utc.timestamp(0, 0);
let epoch = chrono::Utc.timestamp_opt(0, 0).unwrap();
User {
user_id: UserId::default(),
email: String::new(),
display_name: String::new(),
first_name: String::new(),
last_name: String::new(),
avatar: JpegPhoto::default(),
display_name: None,
first_name: None,
last_name: None,
avatar: None,
creation_date: epoch,
uuid: Uuid::from_name_and_date("", &epoch),
}
@ -263,11 +282,10 @@ pub trait LoginHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
#[sqlx(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GroupId(pub i32);
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::FromRow)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sea_orm::FromQueryResult)]
pub struct GroupDetails {
pub group_id: GroupId,
pub display_name: String,
@ -349,8 +367,8 @@ mod tests {
fn test_uuid_time() {
use chrono::prelude::*;
let user_id = "bob";
let date1 = Utc.ymd(2014, 7, 8).and_hms(9, 10, 11);
let date2 = Utc.ymd(2014, 7, 8).and_hms(9, 10, 12);
let date1 = Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11).unwrap();
let date2 = Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 12).unwrap();
assert_ne!(
Uuid::from_name_and_date(user_id, &date1),
Uuid::from_name_and_date(user_id, &date2)

View File

@ -4,9 +4,8 @@ use ldap3_proto::{
use tracing::{debug, info, instrument, warn};
use crate::domain::{
handler::{BackendHandler, Group, GroupRequestFilter, UserId, Uuid},
handler::{BackendHandler, Group, GroupColumn, GroupRequestFilter, UserId, Uuid},
ldap::error::LdapError,
sql_tables::GroupColumn,
};
use super::{

View File

@ -4,9 +4,8 @@ use ldap3_proto::{
use tracing::{debug, info, instrument, warn};
use crate::domain::{
handler::{BackendHandler, GroupDetails, User, UserId, UserRequestFilter},
handler::{BackendHandler, GroupDetails, User, UserColumn, UserId, UserRequestFilter},
ldap::{error::LdapError, utils::expand_attribute_wildcards},
sql_tables::UserColumn,
};
use super::{
@ -34,9 +33,9 @@ fn get_user_attribute(
"uid" => vec![user.user_id.to_string().into_bytes()],
"entryuuid" => vec![user.uuid.to_string().into_bytes()],
"mail" => vec![user.email.clone().into_bytes()],
"givenname" => vec![user.first_name.clone().into_bytes()],
"sn" => vec![user.last_name.clone().into_bytes()],
"jpegphoto" => vec![user.avatar.clone().into_bytes()],
"givenname" => vec![user.first_name.clone()?.into_bytes()],
"sn" => vec![user.last_name.clone()?.into_bytes()],
"jpegphoto" => vec![user.avatar.clone()?.into_bytes()],
"memberof" => groups
.into_iter()
.flatten()
@ -48,7 +47,7 @@ fn get_user_attribute(
.into_bytes()
})
.collect(),
"cn" | "displayname" => vec![user.display_name.clone().into_bytes()],
"cn" | "displayname" => vec![user.display_name.clone()?.into_bytes()],
"createtimestamp" | "modifytimestamp" => vec![user.creation_date.to_rfc3339().into_bytes()],
"1.1" => return None,
// We ignore the operational attribute wildcard.

View File

@ -2,10 +2,7 @@ use itertools::Itertools;
use ldap3_proto::LdapResultCode;
use tracing::{debug, instrument, warn};
use crate::domain::{
handler::UserId,
sql_tables::{GroupColumn, UserColumn},
};
use crate::domain::handler::{GroupColumn, UserColumn, UserId};
use super::error::{LdapError, LdapResult};

View File

@ -1,6 +1,7 @@
pub mod error;
pub mod handler;
pub mod ldap;
pub mod model;
pub mod opaque_handler;
pub mod sql_backend_handler;
pub mod sql_group_backend_handler;

View File

@ -0,0 +1,53 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::{GroupId, Uuid};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "groups")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub group_id: GroupId,
pub display_name: String,
pub creation_date: chrono::DateTime<chrono::Utc>,
pub uuid: Uuid,
}
impl From<Model> for crate::domain::handler::Group {
fn from(group: Model) -> Self {
Self {
id: group.group_id,
display_name: group.display_name,
creation_date: group.creation_date,
uuid: group.uuid,
users: vec![],
}
}
}
impl From<Model> for crate::domain::handler::GroupDetails {
fn from(group: Model) -> Self {
Self {
group_id: group.group_id,
display_name: group.display_name,
creation_date: group.creation_date,
uuid: group.uuid,
}
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::memberships::Entity")]
Memberships,
}
impl Related<super::memberships::Entity> for Entity {
fn to() -> RelationDef {
Relation::Memberships.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,35 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::UserId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "jwt_refresh_storage")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub refresh_token_hash: i64,
pub user_id: UserId,
pub expiry_date: chrono::DateTime<chrono::Utc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,36 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::UserId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "jwt_storage")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub jwt_hash: i64,
pub user_id: UserId,
pub expiry_date: chrono::DateTime<chrono::Utc>,
pub blacklisted: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,73 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::{GroupId, UserId};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "memberships")]
pub struct Model {
#[sea_orm(primary_key)]
pub user_id: UserId,
#[sea_orm(primary_key)]
pub group_id: GroupId,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::groups::Entity",
from = "Column::GroupId",
to = "super::groups::Column::GroupId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Groups,
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::groups::Entity> for Entity {
fn to() -> RelationDef {
Relation::Groups.def()
}
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
#[derive(Debug)]
pub struct UserToGroup;
impl Linked for UserToGroup {
type FromEntity = super::User;
type ToEntity = super::Group;
fn link(&self) -> Vec<RelationDef> {
vec![Relation::Users.def().rev(), Relation::Groups.def()]
}
}
#[derive(Debug)]
pub struct GroupToUser;
impl Linked for GroupToUser {
type FromEntity = super::Group;
type ToEntity = super::User;
fn link(&self) -> Vec<RelationDef> {
vec![Relation::Groups.def().rev(), Relation::Users.def()]
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,12 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
pub mod prelude;
pub mod groups;
pub mod jwt_refresh_storage;
pub mod jwt_storage;
pub mod memberships;
pub mod password_reset_tokens;
pub mod users;
pub use prelude::*;

View File

@ -0,0 +1,35 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::UserId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "password_reset_tokens")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub token: String,
pub user_id: UserId,
pub expiry_date: chrono::DateTime<chrono::Utc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::UserId",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,14 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
pub use super::groups::Column as GroupColumn;
pub use super::groups::Entity as Group;
pub use super::jwt_refresh_storage::Column as JwtRefreshStorageColumn;
pub use super::jwt_refresh_storage::Entity as JwtRefreshStorage;
pub use super::jwt_storage::Column as JwtStorageColumn;
pub use super::jwt_storage::Entity as JwtStorage;
pub use super::memberships::Column as MembershipColumn;
pub use super::memberships::Entity as Membership;
pub use super::password_reset_tokens::Column as PasswordResetTokensColumn;
pub use super::password_reset_tokens::Entity as PasswordResetTokens;
pub use super::users::Column as UserColumn;
pub use super::users::Entity as User;

View File

@ -0,0 +1,134 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::handler::{JpegPhoto, UserId, Uuid};
#[derive(Copy, Clone, Default, Debug, DeriveEntity)]
pub struct Entity;
#[derive(Clone, Debug, PartialEq, DeriveModel, Eq, Serialize, Deserialize, DeriveActiveModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: UserId,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
pub avatar: Option<JpegPhoto>,
pub creation_date: chrono::DateTime<chrono::Utc>,
pub password_hash: Option<Vec<u8>>,
pub totp_secret: Option<String>,
pub mfa_type: Option<String>,
pub uuid: Uuid,
}
impl EntityName for Entity {
fn table_name(&self) -> &str {
"users"
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn, PartialEq, Eq, Serialize, Deserialize)]
pub enum Column {
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
Uuid,
}
impl ColumnTrait for Column {
type EntityName = Entity;
fn def(&self) -> ColumnDef {
match self {
Column::UserId => ColumnType::String(Some(255)),
Column::Email => ColumnType::String(Some(255)),
Column::DisplayName => ColumnType::String(Some(255)),
Column::FirstName => ColumnType::String(Some(255)),
Column::LastName => ColumnType::String(Some(255)),
Column::Avatar => ColumnType::Binary,
Column::CreationDate => ColumnType::DateTime,
Column::PasswordHash => ColumnType::Binary,
Column::TotpSecret => ColumnType::String(Some(64)),
Column::MfaType => ColumnType::String(Some(64)),
Column::Uuid => ColumnType::String(Some(36)),
}
.def()
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::memberships::Entity")]
Memberships,
#[sea_orm(has_many = "super::jwt_refresh_storage::Entity")]
JwtRefreshStorage,
#[sea_orm(has_many = "super::jwt_storage::Entity")]
JwtStorage,
#[sea_orm(has_many = "super::password_reset_tokens::Entity")]
PasswordResetTokens,
}
#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)]
pub enum PrimaryKey {
UserId,
}
impl PrimaryKeyTrait for PrimaryKey {
type ValueType = UserId;
fn auto_increment() -> bool {
false
}
}
impl Related<super::memberships::Entity> for Entity {
fn to() -> RelationDef {
Relation::Memberships.def()
}
}
impl Related<super::jwt_refresh_storage::Entity> for Entity {
fn to() -> RelationDef {
Relation::JwtRefreshStorage.def()
}
}
impl Related<super::jwt_storage::Entity> for Entity {
fn to() -> RelationDef {
Relation::JwtStorage.def()
}
}
impl Related<super::password_reset_tokens::Entity> for Entity {
fn to() -> RelationDef {
Relation::PasswordResetTokens.def()
}
}
impl ActiveModelBehavior for ActiveModel {}
impl From<Model> for crate::domain::handler::User {
fn from(user: Model) -> Self {
Self {
user_id: user.user_id,
email: user.email,
display_name: user.display_name,
first_name: user.first_name,
last_name: user.last_name,
creation_date: user.creation_date,
uuid: user.uuid,
avatar: user.avatar,
}
}
}

View File

@ -5,11 +5,11 @@ use async_trait::async_trait;
#[derive(Clone)]
pub struct SqlBackendHandler {
pub(crate) config: Configuration,
pub(crate) sql_pool: Pool,
pub(crate) sql_pool: DbConnection,
}
impl SqlBackendHandler {
pub fn new(config: Configuration, sql_pool: Pool) -> Self {
pub fn new(config: Configuration, sql_pool: DbConnection) -> Self {
SqlBackendHandler { config, sql_pool }
}
}
@ -23,16 +23,23 @@ pub mod tests {
use crate::domain::sql_tables::init_table;
use crate::infra::configuration::ConfigurationBuilder;
use lldap_auth::{opaque, registration};
use sea_orm::Database;
pub fn get_default_config() -> Configuration {
ConfigurationBuilder::for_tests()
}
pub async fn get_in_memory_db() -> Pool {
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
pub async fn get_in_memory_db() -> DbConnection {
crate::infra::logging::init_for_tests();
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
sql_opt
.max_connections(1)
.sqlx_logging(true)
.sqlx_logging_level(log::LevelFilter::Debug);
Database::connect(sql_opt).await.unwrap()
}
pub async fn get_initialized_db() -> Pool {
pub async fn get_initialized_db() -> DbConnection {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool

View File

@ -1,21 +1,22 @@
use crate::domain::handler::Uuid;
use super::{
error::Result,
error::{DomainError, Result},
handler::{
Group, GroupBackendHandler, GroupDetails, GroupId, GroupRequestFilter, UpdateGroupRequest,
UserId,
},
model::{self, GroupColumn, MembershipColumn},
sql_backend_handler::SqlBackendHandler,
sql_tables::{DbQueryBuilder, Groups, Memberships},
};
use async_trait::async_trait;
use sea_query::{Cond, Expr, Iden, Order, Query, SimpleExpr};
use sea_query_binder::SqlxBinder;
use sqlx::{query_as_with, query_with, FromRow, Row};
use sea_orm::{
ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect,
QueryTrait,
};
use sea_query::{Cond, IntoCondition, SimpleExpr};
use tracing::{debug, instrument};
// Returns the condition for the SQL query, and whether it requires joining with the groups table.
fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
use sea_query::IntoCondition;
use GroupRequestFilter::*;
match filter {
And(fs) => {
@ -35,23 +36,17 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
}
}
Not(f) => get_group_filter_expr(*f).not(),
DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName))
.eq(name)
.into_condition(),
GroupId(id) => Expr::col((Groups::Table, Groups::GroupId))
.eq(id.0)
.into_condition(),
Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid))
.eq(uuid.to_string())
.into_condition(),
DisplayName(name) => GroupColumn::DisplayName.eq(name).into_condition(),
GroupId(id) => GroupColumn::GroupId.eq(id.0).into_condition(),
Uuid(uuid) => GroupColumn::Uuid.eq(uuid.to_string()).into_condition(),
// WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user))
Member(user) => Expr::col((Memberships::Table, Memberships::GroupId))
Member(user) => GroupColumn::GroupId
.in_subquery(
Query::select()
.column(Memberships::GroupId)
.from(Memberships::Table)
.cond_where(Expr::col(Memberships::UserId).eq(user))
.take(),
model::Membership::find()
.select_only()
.column(MembershipColumn::GroupId)
.filter(MembershipColumn::UserId.eq(user))
.into_query(),
)
.into_condition(),
}
@ -62,94 +57,67 @@ impl GroupBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", ret, err)]
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
debug!(?filters);
let (query, values) = {
let mut query_builder = Query::select()
.column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
.column(Groups::CreationDate)
.column(Groups::Uuid)
.column(Memberships::UserId)
.from(Groups::Table)
.left_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
let results = model::Group::find()
// The order_by must be before find_with_related otherwise the primary order is by group_id.
.order_by_asc(GroupColumn::DisplayName)
.find_with_related(model::Membership)
.filter(
filters
.map(|f| {
GroupColumn::GroupId
.in_subquery(
model::Group::find()
.find_also_linked(model::memberships::GroupToUser)
.select_only()
.column(GroupColumn::GroupId)
.filter(get_group_filter_expr(f))
.into_query(),
)
.order_by(Groups::DisplayName, Order::Asc)
.order_by(Memberships::UserId, Order::Asc)
.to_owned();
if let Some(filter) = filters {
query_builder.cond_where(get_group_filter_expr(filter));
}
query_builder.build_sqlx(DbQueryBuilder {})
};
debug!(%query);
// For group_by.
use itertools::Itertools;
let mut groups = Vec::new();
// The rows are returned sorted by display_name, equivalent to group_id. We group them by
// this key which gives us one element (`rows`) per group.
for (group_details, rows) in &query_with(&query, values)
.fetch_all(&self.sql_pool)
.await?
.into_condition()
})
.unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()),
)
.all(&self.sql_pool)
.await?;
Ok(results
.into_iter()
.group_by(|row| GroupDetails::from_row(row).unwrap())
{
groups.push(Group {
id: group_details.group_id,
display_name: group_details.display_name,
creation_date: group_details.creation_date,
uuid: group_details.uuid,
users: rows
.map(|row| row.get::<UserId, _>(&*Memberships::UserId.to_string()))
// If a group has no users, an empty string is returned because of the left
// join.
.filter(|s| !s.as_str().is_empty())
.collect(),
});
.map(|(group, users)| {
let users: Vec<_> = users.into_iter().map(|u| u.user_id).collect();
Group {
users,
..group.into()
}
Ok(groups)
})
.collect())
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupDetails> {
debug!(?group_id);
let (query, values) = Query::select()
.column(Groups::GroupId)
.column(Groups::DisplayName)
.column(Groups::CreationDate)
.column(Groups::Uuid)
.from(Groups::Table)
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_as_with::<_, GroupDetails, _>(&query, values)
.fetch_one(&self.sql_pool)
.await?)
model::Group::find_by_id(group_id)
.into_model::<GroupDetails>()
.one(&self.sql_pool)
.await?
.ok_or_else(|| DomainError::EntityNotFound(format!("{:?}", group_id)))
}
#[instrument(skip_all, level = "debug", err)]
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> {
debug!(?request.group_id);
let mut values = Vec::new();
if let Some(display_name) = request.display_name {
values.push((Groups::DisplayName, display_name.into()));
}
if values.is_empty() {
return Ok(());
}
let (query, values) = Query::update()
.table(Groups::Table)
.values(values)
.cond_where(Expr::col(Groups::GroupId).eq(request.group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
let update_group = model::groups::ActiveModel {
display_name: request
.display_name
.map(ActiveValue::Set)
.unwrap_or_default(),
..Default::default()
};
model::Group::update_many()
.set(update_group)
.filter(sea_orm::ColumnTrait::eq(
&GroupColumn::GroupId,
request.group_id,
))
.exec(&self.sql_pool)
.await?;
Ok(())
}
@ -157,30 +125,29 @@ impl GroupBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", ret, err)]
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
debug!(?group_name);
crate::domain::sql_tables::create_group(group_name, &self.sql_pool).await?;
let (query, values) = Query::select()
.column(Groups::GroupId)
.from(Groups::Table)
.cond_where(Expr::col(Groups::DisplayName).eq(group_name))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let row = query_with(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?;
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
let now = chrono::Utc::now();
let uuid = Uuid::from_name_and_date(group_name, &now);
let new_group = model::groups::ActiveModel {
display_name: ActiveValue::Set(group_name.to_owned()),
creation_date: ActiveValue::Set(now),
uuid: ActiveValue::Set(uuid),
..Default::default()
};
Ok(new_group.insert(&self.sql_pool).await?.group_id)
}
#[instrument(skip_all, level = "debug", err)]
async fn delete_group(&self, group_id: GroupId) -> Result<()> {
debug!(?group_id);
let (query, values) = Query::delete()
.from_table(Groups::Table)
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
let res = model::Group::delete_by_id(group_id)
.exec(&self.sql_pool)
.await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such group: '{:?}'",
group_id
)));
}
Ok(())
}
}
@ -188,7 +155,7 @@ impl GroupBackendHandler for SqlBackendHandler {
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::sql_backend_handler::tests::*;
use crate::domain::{handler::UserId, sql_backend_handler::tests::*};
async fn get_group_ids(
handler: &SqlBackendHandler,
@ -203,12 +170,29 @@ mod tests {
.collect::<Vec<_>>()
}
async fn get_group_names(
handler: &SqlBackendHandler,
filters: Option<GroupRequestFilter>,
) -> Vec<String> {
handler
.list_groups(filters)
.await
.unwrap()
.into_iter()
.map(|g| g.display_name)
.collect::<Vec<_>>()
}
#[tokio::test]
async fn test_list_groups_no_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(&fixture.handler, None).await,
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
get_group_names(&fixture.handler, None).await,
vec![
"Best Group".to_owned(),
"Empty Group".to_owned(),
"Worst Group".to_owned()
]
);
}
@ -216,15 +200,15 @@ mod tests {
async fn test_list_groups_simple_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(
get_group_names(
&fixture.handler,
Some(GroupRequestFilter::Or(vec![
GroupRequestFilter::DisplayName("Empty Group".to_string()),
GroupRequestFilter::DisplayName("Empty Group".to_owned()),
GroupRequestFilter::Member(UserId::new("bob")),
]))
)
.await,
vec![fixture.groups[0], fixture.groups[2]]
vec!["Best Group".to_owned(), "Empty Group".to_owned()]
);
}
@ -236,7 +220,7 @@ mod tests {
&fixture.handler,
Some(GroupRequestFilter::And(vec![
GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName(
"value".to_string()
"value".to_owned()
))),
GroupRequestFilter::GroupId(fixture.groups[0]),
]))
@ -273,7 +257,7 @@ mod tests {
.handler
.update_group(UpdateGroupRequest {
group_id: fixture.groups[0],
display_name: Some("Awesomest Group".to_string()),
display_name: Some("Awesomest Group".to_owned()),
})
.await
.unwrap();
@ -288,6 +272,10 @@ mod tests {
#[tokio::test]
async fn test_delete_group() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(&fixture.handler, None).await,
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
);
fixture
.handler
.delete_group(fixture.groups[0])

View File

@ -1,55 +1,87 @@
use super::{
handler::{GroupId, UserId, Uuid},
sql_tables::{
DbQueryBuilder, DbRow, Groups, Memberships, Metadata, Pool, SchemaVersion, Users,
},
sql_tables::{DbConnection, SchemaVersion},
};
use sea_query::*;
use sea_query_binder::SqlxBinder;
use sqlx::Row;
use tracing::{debug, warn};
use sea_orm::{ConnectionTrait, FromQueryResult, Statement};
use sea_query::{ColumnDef, Expr, ForeignKey, ForeignKeyAction, Iden, Query, Table, Value};
use serde::{Deserialize, Serialize};
use tracing::{instrument, warn};
pub async fn create_group(group_name: &str, pool: &Pool) -> sqlx::Result<()> {
let now = chrono::Utc::now();
let (query, values) = Query::insert()
.into_table(Groups::Table)
.columns(vec![
Groups::DisplayName,
Groups::CreationDate,
Groups::Uuid,
])
.values_panic(vec![
group_name.into(),
now.naive_utc().into(),
Uuid::from_name_and_date(group_name, &now).into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
sqlx::query_with(query.as_str(), values)
.execute(pool)
.await
.map(|_| ())
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Users {
Table,
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
Uuid,
}
pub async fn get_schema_version(pool: &Pool) -> Option<SchemaVersion> {
sqlx::query(
&Query::select()
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Groups {
Table,
GroupId,
DisplayName,
CreationDate,
Uuid,
}
#[derive(Iden)]
pub enum Memberships {
Table,
UserId,
GroupId,
}
// Metadata about the SQL DB.
#[derive(Iden)]
pub enum Metadata {
Table,
// Which version of the schema we're at.
Version,
}
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
pub struct JustSchemaVersion {
pub version: SchemaVersion,
}
#[instrument(skip_all, level = "debug", ret)]
pub async fn get_schema_version(pool: &DbConnection) -> Option<SchemaVersion> {
JustSchemaVersion::find_by_statement(
pool.get_database_backend().build(
Query::select()
.from(Metadata::Table)
.column(Metadata::Version)
.to_string(DbQueryBuilder {}),
.column(Metadata::Version),
),
)
.map(|row: DbRow| row.get::<SchemaVersion, _>(&*Metadata::Version.to_string()))
.fetch_one(pool)
.one(pool)
.await
.ok()
.flatten()
.map(|j| j.version)
}
pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
pub async fn upgrade_to_v1(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> {
let builder = pool.get_database_backend();
// SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the
// error.
let _ = sqlx::query("PRAGMA foreign_keys = ON").execute(pool).await;
sqlx::query(
&Table::create()
let _ = pool
.execute(Statement::from_string(
builder,
"PRAGMA foreign_keys = ON".to_owned(),
))
.await;
pool.execute(
builder.build(
Table::create()
.table(Users::Table)
.if_not_exists()
.col(
@ -64,21 +96,21 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
.string_len(255)
.not_null(),
)
.col(ColumnDef::new(Users::FirstName).string_len(255).not_null())
.col(ColumnDef::new(Users::LastName).string_len(255).not_null())
.col(ColumnDef::new(Users::FirstName).string_len(255))
.col(ColumnDef::new(Users::LastName).string_len(255))
.col(ColumnDef::new(Users::Avatar).binary())
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::PasswordHash).binary())
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64))
.col(ColumnDef::new(Users::Uuid).string_len(36).not_null())
.to_string(DbQueryBuilder {}),
.col(ColumnDef::new(Users::Uuid).string_len(36).not_null()),
),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
pool.execute(
builder.build(
Table::create()
.table(Groups::Table)
.if_not_exists()
.col(
@ -94,25 +126,23 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
.not_null(),
)
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::Uuid).string_len(36).not_null())
.to_string(DbQueryBuilder {}),
.col(ColumnDef::new(Users::Uuid).string_len(36).not_null()),
),
)
.execute(pool)
.await?;
// If the creation_date column doesn't exist, add it.
if sqlx::query(
&Table::alter()
.table(Groups::Table)
.add_column(
if pool
.execute(
builder.build(
Table::alter().table(Groups::Table).add_column(
ColumnDef::new(Groups::CreationDate)
.date_time()
.not_null()
.default(chrono::Utc::now().naive_utc()),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await
.is_ok()
{
@ -120,107 +150,109 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
}
// If the uuid column doesn't exist, add it.
if sqlx::query(
&Table::alter()
.table(Groups::Table)
.add_column(
if pool
.execute(
builder.build(
Table::alter().table(Groups::Table).add_column(
ColumnDef::new(Groups::Uuid)
.string_len(36)
.not_null()
.default(""),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await
.is_ok()
{
warn!("`uuid` column not found in `groups`, creating it");
for row in sqlx::query(
&Query::select()
#[derive(FromQueryResult)]
struct ShortGroupDetails {
group_id: GroupId,
display_name: String,
creation_date: chrono::DateTime<chrono::Utc>,
}
for result in ShortGroupDetails::find_by_statement(
builder.build(
Query::select()
.from(Groups::Table)
.column(Groups::GroupId)
.column(Groups::DisplayName)
.column(Groups::CreationDate)
.to_string(DbQueryBuilder {}),
.column(Groups::CreationDate),
),
)
.fetch_all(pool)
.all(pool)
.await?
{
sqlx::query(
&Query::update()
pool.execute(
builder.build(
Query::update()
.table(Groups::Table)
.value(
Groups::Uuid,
Uuid::from_name_and_date(
&row.get::<String, _>(&*Groups::DisplayName.to_string()),
&row.get::<chrono::DateTime<chrono::Utc>, _>(
&*Groups::CreationDate.to_string(),
Value::from(Uuid::from_name_and_date(
&result.display_name,
&result.creation_date,
)),
)
.and_where(Expr::col(Groups::GroupId).eq(result.group_id)),
),
)
.into(),
)
.and_where(
Expr::col(Groups::GroupId)
.eq(row.get::<GroupId, _>(&*Groups::GroupId.to_string())),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
}
}
if sqlx::query(
&Table::alter()
.table(Users::Table)
.add_column(
if pool
.execute(
builder.build(
Table::alter().table(Users::Table).add_column(
ColumnDef::new(Users::Uuid)
.string_len(36)
.not_null()
.default(""),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await
.is_ok()
{
warn!("`uuid` column not found in `users`, creating it");
for row in sqlx::query(
&Query::select()
#[derive(FromQueryResult)]
struct ShortUserDetails {
user_id: UserId,
creation_date: chrono::DateTime<chrono::Utc>,
}
for result in ShortUserDetails::find_by_statement(
builder.build(
Query::select()
.from(Users::Table)
.column(Users::UserId)
.column(Users::CreationDate)
.to_string(DbQueryBuilder {}),
.column(Users::CreationDate),
),
)
.fetch_all(pool)
.all(pool)
.await?
{
let user_id = row.get::<UserId, _>(&*Users::UserId.to_string());
sqlx::query(
&Query::update()
pool.execute(
builder.build(
Query::update()
.table(Users::Table)
.value(
Users::Uuid,
Uuid::from_name_and_date(
user_id.as_str(),
&row.get::<chrono::DateTime<chrono::Utc>, _>(
&*Users::CreationDate.to_string(),
Value::from(Uuid::from_name_and_date(
result.user_id.as_str(),
&result.creation_date,
)),
)
.and_where(Expr::col(Users::UserId).eq(result.user_id)),
),
)
.into(),
)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
}
}
sqlx::query(
&Table::create()
pool.execute(
builder.build(
Table::create()
.table(Memberships::Table)
.if_not_exists()
.col(
@ -244,59 +276,63 @@ pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> {
.to(Groups::Table, Groups::GroupId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
if sqlx::query(
&Query::select()
if pool
.query_one(
builder.build(
Query::select()
.from(Groups::Table)
.column(Groups::DisplayName)
.cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly"))
.to_string(DbQueryBuilder {}),
.cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")),
),
)
.fetch_one(pool)
.await
.is_ok()
{
sqlx::query(
&Query::update()
pool.execute(
builder.build(
Query::update()
.table(Groups::Table)
.values(vec![(Groups::DisplayName, "lldap_password_manager".into())])
.cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly"))
.to_string(DbQueryBuilder {}),
.cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")),
),
)
.execute(pool)
.await?;
create_group("lldap_strict_readonly", pool).await?
}
sqlx::query(
&Table::create()
pool.execute(
builder.build(
Table::create()
.table(Metadata::Table)
.if_not_exists()
.col(ColumnDef::new(Metadata::Version).tiny_integer().not_null())
.to_string(DbQueryBuilder {}),
.col(ColumnDef::new(Metadata::Version).tiny_integer()),
),
)
.execute(pool)
.await?;
sqlx::query(
&Query::insert()
pool.execute(
builder.build(
Query::insert()
.into_table(Metadata::Table)
.columns(vec![Metadata::Version])
.values_panic(vec![SchemaVersion(1).into()])
.to_string(DbQueryBuilder {}),
.values_panic(vec![SchemaVersion(1).into()]),
),
)
.execute(pool)
.await?;
assert_eq!(get_schema_version(pool).await.unwrap().0, 1);
Ok(())
}
pub async fn migrate_from_version(_pool: &Pool, version: SchemaVersion) -> anyhow::Result<()> {
pub async fn migrate_from_version(
_pool: &DbConnection,
version: SchemaVersion,
) -> anyhow::Result<()> {
if version.0 > 1 {
anyhow::bail!("DB version downgrading is not supported");
}

View File

@ -1,16 +1,14 @@
use super::{
error::{DomainError, Result},
handler::{BindRequest, LoginHandler, UserId},
model::{self, UserColumn},
opaque_handler::{login, registration, OpaqueHandler},
sql_backend_handler::SqlBackendHandler,
sql_tables::{DbQueryBuilder, Users},
};
use async_trait::async_trait;
use lldap_auth::opaque;
use sea_query::{Expr, Iden, Query};
use sea_query_binder::SqlxBinder;
use sea_orm::{ActiveValue, EntityTrait, FromQueryResult, QuerySelect};
use secstr::SecUtf8;
use sqlx::Row;
use tracing::{debug, instrument};
type SqlOpaqueHandler = SqlBackendHandler;
@ -50,39 +48,19 @@ impl SqlBackendHandler {
}
#[instrument(skip_all, level = "debug", err)]
async fn get_password_file_for_user(
&self,
username: &str,
) -> Result<Option<opaque::server::ServerRegistration>> {
async fn get_password_file_for_user(&self, user_id: UserId) -> Result<Option<Vec<u8>>> {
#[derive(FromQueryResult)]
struct OnlyPasswordHash {
password_hash: Option<Vec<u8>>,
}
// Fetch the previously registered password file from the DB.
let password_file_bytes = {
let (query, values) = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(username))
.build_sqlx(DbQueryBuilder {});
if let Some(row) = sqlx::query_with(query.as_str(), values)
.fetch_optional(&self.sql_pool)
Ok(model::User::find_by_id(user_id)
.select_only()
.column(UserColumn::PasswordHash)
.into_model::<OnlyPasswordHash>()
.one(&self.sql_pool)
.await?
{
if let Some(bytes) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
bytes
} else {
// No password set.
return Ok(None);
}
} else {
// No such user.
return Ok(None);
}
};
opaque::server::ServerRegistration::deserialize(&password_file_bytes)
.map(Option::Some)
.map_err(|_| {
DomainError::InternalError(format!("Corrupted password file for {}", username))
})
.and_then(|u| u.password_hash))
}
}
@ -90,17 +68,9 @@ impl SqlBackendHandler {
impl LoginHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)]
async fn bind(&self, request: BindRequest) -> Result<()> {
let (query, values) = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(&request.name))
.build_sqlx(DbQueryBuilder {});
if let Ok(row) = sqlx::query_with(&query, values)
.fetch_one(&self.sql_pool)
.await
{
if let Some(password_hash) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
if let Some(password_hash) = self
.get_password_file_for_user(request.name.clone())
.await?
{
if let Err(e) = passwords_match(
&password_hash,
@ -113,10 +83,10 @@ impl LoginHandler for SqlBackendHandler {
return Ok(());
}
} else {
debug!(r#"User "{}" has no password"#, &request.name);
}
} else {
debug!(r#"No user found for "{}""#, &request.name);
debug!(
r#"User "{}" doesn't exist or has no password"#,
&request.name
);
}
Err(DomainError::AuthenticationError(format!(
" for user '{}'",
@ -132,7 +102,18 @@ impl OpaqueHandler for SqlOpaqueHandler {
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> {
let maybe_password_file = self.get_password_file_for_user(&request.username).await?;
let maybe_password_file = self
.get_password_file_for_user(UserId::new(&request.username))
.await?
.map(|bytes| {
opaque::server::ServerRegistration::deserialize(&bytes).map_err(|_| {
DomainError::InternalError(format!(
"Corrupted password file for {}",
&request.username
))
})
})
.transpose()?;
let mut rng = rand::rngs::OsRng;
// Get the CredentialResponse for the user, or a dummy one if no user/no password.
@ -210,17 +191,16 @@ impl OpaqueHandler for SqlOpaqueHandler {
let password_file =
opaque::server::registration::get_password_file(request.registration_upload);
{
// Set the user password to the new password.
let (update_query, values) = Query::update()
.table(Users::Table)
.value(Users::PasswordHash, password_file.serialize().into())
.cond_where(Expr::col(Users::UserId).eq(username))
.build_sqlx(DbQueryBuilder {});
sqlx::query_with(update_query.as_str(), values)
.execute(&self.sql_pool)
let user_update = model::users::ActiveModel {
user_id: ActiveValue::Set(UserId::new(&username)),
password_hash: ActiveValue::Set(Some(password_file.serialize())),
..Default::default()
};
model::User::update_many()
.set(user_update)
.exec(&self.sql_pool)
.await?;
}
Ok(())
}
}

View File

@ -1,39 +1,116 @@
use super::{
handler::{GroupId, UserId, Uuid},
handler::{GroupId, JpegPhoto, UserId, Uuid},
sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1},
};
use sea_query::*;
use serde::{Deserialize, Serialize};
use sea_orm::{DbErr, Value};
pub use super::sql_migrations::create_group;
pub type DbConnection = sea_orm::DatabaseConnection;
pub type Pool = sqlx::sqlite::SqlitePool;
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
pub type DbRow = sqlx::sqlite::SqliteRow;
pub type DbQueryBuilder = SqliteQueryBuilder;
#[derive(Copy, PartialEq, Eq, Debug, Clone, sqlx::FromRow, sqlx::Type)]
#[sqlx(transparent)]
#[derive(Copy, PartialEq, Eq, Debug, Clone)]
pub struct SchemaVersion(pub u8);
impl From<GroupId> for Value {
impl sea_orm::TryGetable for SchemaVersion {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(SchemaVersion(u8::try_get(res, pre, col)?))
}
}
impl From<GroupId> for sea_orm::Value {
fn from(group_id: GroupId) -> Self {
group_id.0.into()
}
}
impl From<UserId> for sea_query::Value {
impl sea_orm::TryGetable for GroupId {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(GroupId(i32::try_get(res, pre, col)?))
}
}
impl sea_orm::sea_query::value::ValueType for GroupId {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
Ok(GroupId(<i32 as sea_orm::sea_query::ValueType>::try_from(
v,
)?))
}
fn type_name() -> String {
"GroupId".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::Int
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::Integer(None)
}
}
impl sea_orm::TryFromU64 for GroupId {
fn try_from_u64(n: u64) -> Result<Self, sea_orm::DbErr> {
Ok(GroupId(i32::try_from_u64(n)?))
}
}
impl From<UserId> for sea_orm::Value {
fn from(user_id: UserId) -> Self {
user_id.into_string().into()
}
}
impl From<&UserId> for sea_query::Value {
impl From<&UserId> for sea_orm::Value {
fn from(user_id: &UserId) -> Self {
user_id.as_str().into()
}
}
impl sea_orm::TryGetable for UserId {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(UserId::new(&String::try_get(res, pre, col)?))
}
}
impl sea_orm::TryFromU64 for UserId {
fn try_from_u64(_n: u64) -> Result<Self, sea_orm::DbErr> {
Err(sea_orm::DbErr::ConvertFromU64(
"UserId cannot be constructed from u64",
))
}
}
impl sea_orm::sea_query::value::ValueType for UserId {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
Ok(UserId::new(
<String as sea_orm::sea_query::ValueType>::try_from(v)?.as_str(),
))
}
fn type_name() -> String {
"UserId".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::String
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::String(Some(255))
}
}
impl From<Uuid> for sea_query::Value {
fn from(uuid: Uuid) -> Self {
uuid.as_str().into()
@ -46,57 +123,84 @@ impl From<&Uuid> for sea_query::Value {
}
}
impl sea_orm::TryGetable for JpegPhoto {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
<JpegPhoto as std::convert::TryFrom<Vec<_>>>::try_from(Vec::<u8>::try_get(res, pre, col)?)
.map_err(|e| {
sea_orm::TryGetError::DbErr(DbErr::TryIntoErr {
from: "[u8]",
into: "JpegPhoto",
source: e.into(),
})
})
}
}
impl sea_orm::sea_query::value::ValueType for JpegPhoto {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
<JpegPhoto as std::convert::TryFrom<_>>::try_from(
<Vec<u8> as sea_orm::sea_query::ValueType>::try_from(v)?.as_slice(),
)
.map_err(|_| sea_orm::sea_query::ValueTypeErr {})
}
fn type_name() -> String {
"JpegPhoto".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::Bytes
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::Binary(sea_orm::sea_query::BlobSize::Long)
}
}
impl sea_orm::sea_query::Nullable for JpegPhoto {
fn null() -> sea_orm::Value {
JpegPhoto::null().into()
}
}
impl sea_orm::entity::IntoActiveValue<JpegPhoto> for JpegPhoto {
fn into_active_value(self) -> sea_orm::ActiveValue<JpegPhoto> {
sea_orm::ActiveValue::Set(self)
}
}
impl sea_orm::sea_query::value::ValueType for Uuid {
fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
<super::handler::Uuid as std::convert::TryFrom<_>>::try_from(
<std::string::String as sea_orm::sea_query::ValueType>::try_from(v)?.as_str(),
)
.map_err(|_| sea_orm::sea_query::ValueTypeErr {})
}
fn type_name() -> String {
"Uuid".to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
sea_orm::sea_query::ArrayType::String
}
fn column_type() -> sea_orm::sea_query::ColumnType {
sea_orm::sea_query::ColumnType::String(Some(36))
}
}
impl From<SchemaVersion> for Value {
fn from(version: SchemaVersion) -> Self {
version.0.into()
}
}
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Users {
Table,
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
Uuid,
}
pub type UserColumn = Users;
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum Groups {
Table,
GroupId,
DisplayName,
CreationDate,
Uuid,
}
pub type GroupColumn = Groups;
#[derive(Iden)]
pub enum Memberships {
Table,
UserId,
GroupId,
}
// Metadata about the SQL DB.
#[derive(Iden)]
pub enum Metadata {
Table,
// Which version of the schema we're at.
Version,
}
pub async fn init_table(pool: &Pool) -> anyhow::Result<()> {
pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> {
let version = {
if let Some(version) = get_schema_version(pool).await {
version
@ -111,33 +215,55 @@ pub async fn init_table(pool: &Pool) -> anyhow::Result<()> {
#[cfg(test)]
mod tests {
use crate::domain::sql_migrations;
use super::*;
use chrono::prelude::*;
use sqlx::{Column, Row};
use sea_orm::{ConnectionTrait, Database, DbBackend, FromQueryResult};
async fn get_in_memory_db() -> DbConnection {
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
sql_opt.max_connections(1).sqlx_logging(false);
Database::connect(sql_opt).await.unwrap()
}
fn raw_statement(sql: &str) -> sea_orm::Statement {
sea_orm::Statement::from_string(DbBackend::Sqlite, sql.to_owned())
}
#[tokio::test]
async fn test_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sqlx::query(r#"INSERT INTO users
sql_pool.execute(raw_statement(
r#"INSERT INTO users
(user_id, email, display_name, first_name, last_name, creation_date, password_hash, uuid)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#).execute(&sql_pool).await.unwrap();
let row =
sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#)
.fetch_one(&sql_pool)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#)).await.unwrap();
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortUserDetails {
display_name: String,
creation_date: chrono::DateTime<chrono::Utc>,
}
let result = ShortUserDetails::find_by_statement(raw_statement(
r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#,
))
.one(&sql_pool)
.await
.unwrap()
.unwrap();
assert_eq!(row.column(0).name(), "display_name");
assert_eq!(row.get::<String, _>("display_name"), "Bob Bobbersön");
assert_eq!(
row.get::<DateTime<Utc>, _>("creation_date"),
Utc.timestamp(0, 0),
result,
ShortUserDetails {
display_name: "Bob Bobbersön".to_owned(),
creation_date: Utc.timestamp_opt(0, 0).unwrap()
}
);
}
#[tokio::test]
async fn test_already_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
crate::infra::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
init_table(&sql_pool).await.unwrap();
}
@ -145,87 +271,109 @@ mod tests {
#[tokio::test]
async fn test_migrate_tables() {
// Test that we add the column creation_date to groups and uuid to users and groups.
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
sqlx::query(r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#)
.execute(&sql_pool)
let sql_pool = get_in_memory_db().await;
sql_pool
.execute(raw_statement(
r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#,
))
.await
.unwrap();
sqlx::query(
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, creation_date)
VALUES ("bôb", "1970-01-01 00:00:00")"#,
)
.execute(&sql_pool)
))
.await
.unwrap();
sqlx::query(r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#)
.execute(&sql_pool)
sql_pool
.execute(raw_statement(
r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#,
))
.await
.unwrap();
sqlx::query(
sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name)
VALUES ("lldap_admin"), ("lldap_readonly")"#,
)
.execute(&sql_pool)
))
.await
.unwrap();
init_table(&sql_pool).await.unwrap();
sqlx::query(
sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name, creation_date, uuid)
VALUES ("test", "1970-01-01 00:00:00", "abc")"#,
)
.execute(&sql_pool)
))
.await
.unwrap();
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct JustUuid {
uuid: Uuid,
}
assert_eq!(
sqlx::query(r#"SELECT uuid FROM users"#)
.fetch_all(&sql_pool)
JustUuid::find_by_statement(raw_statement(r#"SELECT uuid FROM users"#))
.all(&sql_pool)
.await
.unwrap()
.into_iter()
.map(|row| row.get::<Uuid, _>("uuid"))
.collect::<Vec<_>>(),
vec![crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")]
.unwrap(),
vec![JustUuid {
uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")
}]
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortGroupDetails {
group_id: GroupId,
display_name: String,
}
assert_eq!(
sqlx::query(r#"SELECT group_id, display_name FROM groups"#)
.fetch_all(&sql_pool)
.await
.unwrap()
.into_iter()
.map(|row| (
row.get::<GroupId, _>("group_id"),
row.get::<String, _>("display_name")
ShortGroupDetails::find_by_statement(raw_statement(
r#"SELECT group_id, display_name, creation_date FROM groups"#
))
.collect::<Vec<_>>(),
.all(&sql_pool)
.await
.unwrap(),
vec![
(GroupId(1), "lldap_admin".to_string()),
(GroupId(2), "lldap_password_manager".to_string()),
(GroupId(3), "lldap_strict_readonly".to_string()),
(GroupId(4), "test".to_string())
ShortGroupDetails {
group_id: GroupId(1),
display_name: "lldap_admin".to_string()
},
ShortGroupDetails {
group_id: GroupId(2),
display_name: "lldap_password_manager".to_string()
},
ShortGroupDetails {
group_id: GroupId(3),
display_name: "test".to_string()
}
]
);
assert_eq!(
sqlx::query(r#"SELECT version FROM metadata"#)
.map(|row: DbRow| row.get::<SchemaVersion, _>("version"))
.fetch_one(&sql_pool)
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
SchemaVersion(1)
sql_migrations::JustSchemaVersion {
version: SchemaVersion(1)
}
);
}
#[tokio::test]
async fn test_too_high_version() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
sqlx::query(r#"CREATE TABLE metadata ( version INTEGER);"#)
.execute(&sql_pool)
let sql_pool = get_in_memory_db().await;
sql_pool
.execute(raw_statement(
r#"CREATE TABLE metadata ( version INTEGER);"#,
))
.await
.unwrap();
sqlx::query(
sql_pool
.execute(raw_statement(
r#"INSERT INTO metadata (version)
VALUES (127)"#,
)
.execute(&sql_pool)
))
.await
.unwrap();
assert!(init_table(&sql_pool).await.is_err());

View File

@ -1,136 +1,68 @@
use super::{
error::Result,
error::{DomainError, Result},
handler::{
CreateUserRequest, GroupDetails, GroupId, UpdateUserRequest, User, UserAndGroups,
UserBackendHandler, UserId, UserRequestFilter, Uuid,
},
model::{self, GroupColumn, UserColumn},
sql_backend_handler::SqlBackendHandler,
sql_tables::{DbQueryBuilder, Groups, Memberships, Users},
};
use async_trait::async_trait;
use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr};
use sea_query_binder::{SqlxBinder, SqlxValues};
use sqlx::{query_as_with, query_with, FromRow, Row};
use sea_orm::{
entity::IntoActiveValue,
sea_query::{Cond, Expr, IntoCondition, SimpleExpr},
ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, QueryOrder,
QuerySelect, QueryTrait, Set,
};
use sea_query::{Alias, IntoColumnRef};
use std::collections::HashSet;
use tracing::{debug, instrument};
struct RequiresGroup(bool);
// Returns the condition for the SQL query, and whether it requires joining with the groups table.
fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) {
use sea_query::IntoCondition;
fn get_user_filter_expr(filter: UserRequestFilter) -> Cond {
use UserRequestFilter::*;
let group_table = Alias::new("r1");
fn get_repeated_filter(
fs: Vec<UserRequestFilter>,
condition: Cond,
default_value: bool,
) -> (RequiresGroup, Cond) {
) -> Cond {
if fs.is_empty() {
return (
RequiresGroup(false),
SimpleExpr::Value(default_value.into()).into_condition(),
);
SimpleExpr::Value(default_value.into()).into_condition()
} else {
fs.into_iter()
.map(get_user_filter_expr)
.fold(condition, Cond::add)
}
let mut requires_group = false;
let filter = fs.into_iter().fold(condition, |c, f| {
let (group, filters) = get_user_filter_expr(f);
requires_group |= group.0;
c.add(filters)
});
(RequiresGroup(requires_group), filter)
}
match filter {
And(fs) => get_repeated_filter(fs, Cond::all(), true),
Or(fs) => get_repeated_filter(fs, Cond::any(), false),
Not(f) => {
let (requires_group, filters) = get_user_filter_expr(*f);
(requires_group, filters.not())
}
UserId(user_id) => (
RequiresGroup(false),
Expr::col((Users::Table, Users::UserId))
.eq(user_id)
.into_condition(),
),
Equality(s1, s2) => (
RequiresGroup(false),
if s1 == Users::UserId {
Not(f) => get_user_filter_expr(*f).not(),
UserId(user_id) => ColumnTrait::eq(&UserColumn::UserId, user_id).into_condition(),
Equality(s1, s2) => {
if s1 == UserColumn::UserId {
panic!("User id should be wrapped")
} else {
Expr::col((Users::Table, s1)).eq(s2).into_condition()
},
),
MemberOf(group) => (
RequiresGroup(true),
Expr::col((Groups::Table, Groups::DisplayName))
ColumnTrait::eq(&s1, s2).into_condition()
}
}
MemberOf(group) => Expr::col((group_table, GroupColumn::DisplayName))
.eq(group)
.into_condition(),
),
MemberOfId(group_id) => (
RequiresGroup(true),
Expr::col((Groups::Table, Groups::GroupId))
MemberOfId(group_id) => Expr::col((group_table, GroupColumn::GroupId))
.eq(group_id)
.into_condition(),
),
}
}
fn get_list_users_query(
filters: Option<UserRequestFilter>,
get_groups: bool,
) -> (String, SqlxValues) {
let mut query_builder = Query::select()
.column((Users::Table, Users::UserId))
.column(Users::Email)
.column((Users::Table, Users::DisplayName))
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column((Users::Table, Users::CreationDate))
.column((Users::Table, Users::Uuid))
.from(Users::Table)
.order_by((Users::Table, Users::UserId), Order::Asc)
.to_owned();
let add_join_group_tables = |builder: &mut sea_query::SelectStatement| {
builder
.left_join(
Memberships::Table,
Expr::tbl(Users::Table, Users::UserId)
.equals(Memberships::Table, Memberships::UserId),
)
.left_join(
Groups::Table,
Expr::tbl(Memberships::Table, Memberships::GroupId)
.equals(Groups::Table, Groups::GroupId),
);
};
if get_groups {
add_join_group_tables(&mut query_builder);
query_builder
.column((Groups::Table, Groups::GroupId))
.expr_as(
Expr::col((Groups::Table, Groups::DisplayName)),
Alias::new("group_display_name"),
)
.expr_as(
Expr::col((Groups::Table, Groups::CreationDate)),
sea_query::Alias::new("group_creation_date"),
)
.expr_as(
Expr::col((Groups::Table, Groups::Uuid)),
sea_query::Alias::new("group_uuid"),
)
.order_by(Alias::new("group_display_name"), Order::Asc);
fn to_value(opt_name: &Option<String>) -> ActiveValue<Option<String>> {
match opt_name {
None => ActiveValue::NotSet,
Some(name) => ActiveValue::Set(if name.is_empty() {
None
} else {
Some(name.to_owned())
}),
}
if let Some(filter) = filters {
let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter);
query_builder.cond_where(condition);
if requires_group && !get_groups {
add_join_group_tables(&mut query_builder);
}
}
query_builder.build_sqlx(DbQueryBuilder {})
}
#[async_trait]
@ -141,95 +73,86 @@ impl UserBackendHandler for SqlBackendHandler {
filters: Option<UserRequestFilter>,
get_groups: bool,
) -> Result<Vec<UserAndGroups>> {
debug!(?filters, get_groups);
let (query, values) = get_list_users_query(filters, get_groups);
debug!(%query);
// For group_by.
use itertools::Itertools;
let mut users = Vec::new();
// The rows are returned sorted by user_id. We group them by
// this key which gives us one element (`rows`) per group.
for (_, rows) in &query_with(&query, values)
.fetch_all(&self.sql_pool)
debug!(?filters);
let query = model::User::find()
.filter(
filters
.map(|f| {
UserColumn::UserId
.in_subquery(
model::User::find()
.find_also_linked(model::memberships::UserToGroup)
.select_only()
.column(UserColumn::UserId)
.filter(get_user_filter_expr(f))
.into_query(),
)
.into_condition()
})
.unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()),
)
.order_by_asc(UserColumn::UserId);
if !get_groups {
Ok(query
.into_model::<User>()
.all(&self.sql_pool)
.await?
.into_iter()
.group_by(|row| row.get::<UserId, _>(&*Users::UserId.to_string()))
{
let mut rows = rows.peekable();
users.push(UserAndGroups {
user: User::from_row(rows.peek().unwrap()).unwrap(),
groups: if get_groups {
Some(
rows.filter_map(|row| {
let display_name = row.get::<String, _>("group_display_name");
if display_name.is_empty() {
None
} else {
Some(GroupDetails {
group_id: row.get::<GroupId, _>(&*Groups::GroupId.to_string()),
display_name,
creation_date: row.get::<chrono::DateTime<chrono::Utc>, _>(
"group_creation_date",
),
uuid: row.get::<Uuid, _>("group_uuid"),
.map(|u| UserAndGroups {
user: u,
groups: None,
})
.collect())
} else {
let results = query
//find_with_linked?
.find_also_linked(model::memberships::UserToGroup)
.order_by_asc(SimpleExpr::Column(
(Alias::new("r1"), GroupColumn::GroupId).into_column_ref(),
))
.all(&self.sql_pool)
.await?;
use itertools::Itertools;
Ok(results
.iter()
.group_by(|(u, _)| u)
.into_iter()
.map(|(user, groups)| {
let groups: Vec<_> = groups
.into_iter()
.flat_map(|(_, g)| g)
.map(|g| GroupDetails::from(g.clone()))
.collect();
UserAndGroups {
user: user.clone().into(),
groups: Some(groups),
}
})
.collect(),
)
} else {
None
},
});
.collect())
}
Ok(users)
}
#[instrument(skip_all, level = "debug", ret)]
async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
debug!(?user_id);
let (query, values) = Query::select()
.column(Users::UserId)
.column(Users::Email)
.column(Users::DisplayName)
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column(Users::CreationDate)
.column(Users::Uuid)
.from(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_as_with::<_, User, _>(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?)
model::User::find_by_id(user_id.to_owned())
.into_model::<User>()
.one(&self.sql_pool)
.await?
.ok_or_else(|| DomainError::EntityNotFound(user_id.to_string()))
}
#[instrument(skip_all, level = "debug", ret, err)]
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupDetails>> {
debug!(?user_id);
let (query, values) = Query::select()
.column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
.column(Groups::CreationDate)
.column(Groups::Uuid)
.from(Groups::Table)
.inner_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.cond_where(Expr::col(Memberships::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let user = model::User::find_by_id(user_id.to_owned())
.one(&self.sql_pool)
.await?
.ok_or_else(|| DomainError::EntityNotFound(user_id.to_string()))?;
Ok(HashSet::from_iter(
query_as_with::<_, GroupDetails, _>(&query, values)
.fetch_all(&self.sql_pool)
user.find_linked(model::memberships::UserToGroup)
.into_model::<GroupDetails>()
.all(&self.sql_pool)
.await?,
))
}
@ -237,70 +160,41 @@ impl UserBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)]
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
debug!(user_id = ?request.user_id);
let columns = vec![
Users::UserId,
Users::Email,
Users::DisplayName,
Users::FirstName,
Users::LastName,
Users::Avatar,
Users::CreationDate,
Users::Uuid,
];
let now = chrono::Utc::now();
let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now);
let values = vec![
request.user_id.into(),
request.email.into(),
request.display_name.unwrap_or_default().into(),
request.first_name.unwrap_or_default().into(),
request.last_name.unwrap_or_default().into(),
request.avatar.unwrap_or_default().into(),
now.naive_utc().into(),
uuid.into(),
];
let (query, values) = Query::insert()
.into_table(Users::Table)
.columns(columns)
.values_panic(values)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
let new_user = model::users::ActiveModel {
user_id: Set(request.user_id),
email: Set(request.email),
display_name: to_value(&request.display_name),
first_name: to_value(&request.first_name),
last_name: to_value(&request.last_name),
avatar: request.avatar.into_active_value(),
creation_date: ActiveValue::Set(now),
uuid: ActiveValue::Set(uuid),
..Default::default()
};
new_user.insert(&self.sql_pool).await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn update_user(&self, request: UpdateUserRequest) -> Result<()> {
debug!(user_id = ?request.user_id);
let mut values = Vec::new();
if let Some(email) = request.email {
values.push((Users::Email, email.into()));
}
if let Some(display_name) = request.display_name {
values.push((Users::DisplayName, display_name.into()));
}
if let Some(first_name) = request.first_name {
values.push((Users::FirstName, first_name.into()));
}
if let Some(last_name) = request.last_name {
values.push((Users::LastName, last_name.into()));
}
if let Some(avatar) = request.avatar {
values.push((Users::Avatar, avatar.into()));
}
if values.is_empty() {
return Ok(());
}
let (query, values) = Query::update()
.table(Users::Table)
.values(values)
.cond_where(Expr::col(Users::UserId).eq(request.user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
let update_user = model::users::ActiveModel {
email: request.email.map(ActiveValue::Set).unwrap_or_default(),
display_name: to_value(&request.display_name),
first_name: to_value(&request.first_name),
last_name: to_value(&request.last_name),
avatar: request.avatar.into_active_value(),
..Default::default()
};
model::User::update_many()
.set(update_user)
.filter(sea_orm::ColumnTrait::eq(
&UserColumn::UserId,
request.user_id,
))
.exec(&self.sql_pool)
.await?;
Ok(())
}
@ -308,47 +202,41 @@ impl UserBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)]
async fn delete_user(&self, user_id: &UserId) -> Result<()> {
debug!(?user_id);
let (query, values) = Query::delete()
.from_table(Users::Table)
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
let res = model::User::delete_by_id(user_id.clone())
.exec(&self.sql_pool)
.await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such user: '{}'",
user_id
)));
}
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
debug!(?user_id, ?group_id);
let (query, values) = Query::insert()
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.into()])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
let new_membership = model::memberships::ActiveModel {
user_id: ActiveValue::Set(user_id.clone()),
group_id: ActiveValue::Set(group_id),
};
new_membership.insert(&self.sql_pool).await?;
Ok(())
}
#[instrument(skip_all, level = "debug", err)]
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
debug!(?user_id, ?group_id);
let (query, values) = Query::delete()
.from_table(Memberships::Table)
.cond_where(
Cond::all()
.add(Expr::col(Memberships::GroupId).eq(group_id))
.add(Expr::col(Memberships::UserId).eq(user_id)),
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(query.as_str(), values)
.execute(&self.sql_pool)
let res = model::Membership::delete_by_id((user_id.clone(), group_id))
.exec(&self.sql_pool)
.await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such membership: '{}' -> {:?}",
user_id, group_id
)));
}
Ok(())
}
}
@ -357,7 +245,8 @@ impl UserBackendHandler for SqlBackendHandler {
mod tests {
use super::*;
use crate::domain::{
handler::JpegPhoto, sql_backend_handler::tests::*, sql_tables::UserColumn,
handler::{JpegPhoto, UserColumn},
sql_backend_handler::tests::*,
};
#[tokio::test]
@ -526,9 +415,13 @@ mod tests {
.map(|u| {
(
u.user.user_id.to_string(),
u.user.display_name.to_string(),
u.user
.display_name
.as_deref()
.unwrap_or("<unknown>")
.to_owned(),
u.groups
.unwrap()
.unwrap_or_default()
.into_iter()
.map(|g| g.group_id)
.collect::<Vec<_>>(),
@ -571,7 +464,7 @@ mod tests {
(
u.user.creation_date,
u.groups
.unwrap()
.unwrap_or_default()
.into_iter()
.map(|g| g.creation_date)
.collect::<Vec<_>>(),
@ -685,7 +578,7 @@ mod tests {
display_name: Some("display_name".to_string()),
first_name: Some("first_name".to_string()),
last_name: Some("last_name".to_string()),
avatar: Some(JpegPhoto::default()),
avatar: Some(JpegPhoto::for_tests()),
})
.await
.unwrap();
@ -696,10 +589,10 @@ mod tests {
.await
.unwrap();
assert_eq!(user.email, "email");
assert_eq!(user.display_name, "display_name");
assert_eq!(user.first_name, "first_name");
assert_eq!(user.last_name, "last_name");
assert_eq!(user.avatar, JpegPhoto::default());
assert_eq!(user.display_name.unwrap(), "display_name");
assert_eq!(user.first_name.unwrap(), "first_name");
assert_eq!(user.last_name.unwrap(), "last_name");
assert_eq!(user.avatar, Some(JpegPhoto::for_tests()));
}
#[tokio::test]
@ -722,9 +615,10 @@ mod tests {
.get_user_details(&UserId::new("bob"))
.await
.unwrap();
assert_eq!(user.display_name, "display bob");
assert_eq!(user.first_name, "first_name");
assert_eq!(user.last_name, "");
assert_eq!(user.display_name.unwrap(), "display bob");
assert_eq!(user.first_name.unwrap(), "first_name");
assert_eq!(user.last_name, None);
assert_eq!(user.avatar, None);
}
#[tokio::test]

View File

@ -26,7 +26,7 @@ use crate::domain::handler::UserRequestFilter;
use crate::{
domain::{
error::DomainError,
handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserId},
handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserColumn, UserId},
opaque_handler::OpaqueHandler,
},
infra::{
@ -149,10 +149,7 @@ where
.list_users(
Some(UserRequestFilter::Or(vec![
UserRequestFilter::UserId(UserId::new(user_string)),
UserRequestFilter::Equality(
crate::domain::sql_tables::UserColumn::Email,
user_string.to_owned(),
),
UserRequestFilter::Equality(UserColumn::Email, user_string.to_owned()),
])),
false,
)
@ -174,7 +171,9 @@ where
Some(token) => token,
};
if let Err(e) = super::mail::send_password_reset_email(
&user.display_name,
user.display_name
.as_deref()
.unwrap_or_else(|| user.user_id.as_str()),
&user.email,
&token,
&data.server_url,

View File

@ -1,18 +1,17 @@
use crate::{
domain::sql_tables::{DbQueryBuilder, Pool},
infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage},
use crate::domain::{
model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn},
sql_tables::DbConnection,
};
use actix::prelude::*;
use chrono::Local;
use actix::prelude::{Actor, AsyncContext, Context};
use cron::Schedule;
use sea_query::{Expr, Query};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use std::{str::FromStr, time::Duration};
use tracing::{debug, error, info, instrument};
use tracing::{error, info, instrument};
// Define actor
pub struct Scheduler {
schedule: Schedule,
sql_pool: Pool,
sql_pool: DbConnection,
}
// Provide Actor implementation for our actor
@ -33,7 +32,7 @@ impl Actor for Scheduler {
}
impl Scheduler {
pub fn new(cron_expression: &str, sql_pool: Pool) -> Self {
pub fn new(cron_expression: &str, sql_pool: DbConnection) -> Self {
let schedule = Schedule::from_str(cron_expression).unwrap();
Self { schedule, sql_pool }
}
@ -48,33 +47,35 @@ impl Scheduler {
}
#[instrument(skip_all)]
async fn cleanup_db(sql_pool: Pool) {
async fn cleanup_db(sql_pool: DbConnection) {
info!("Cleaning DB");
let query = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {});
debug!(%query);
if let Err(e) = sqlx::query(&query).execute(&sql_pool).await {
if let Err(e) = model::JwtRefreshStorage::delete_many()
.filter(JwtRefreshStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up JWT refresh tokens: {}", e);
};
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
}
if let Err(e) = model::JwtStorage::delete_many()
.filter(JwtStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up JWT storage: {}", e);
};
if let Err(e) = model::PasswordResetTokens::delete_many()
.filter(PasswordResetTokensColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up password reset tokens: {}", e);
};
info!("DB cleaned!");
}
fn duration_until_next(&self) -> Duration {
let now = Local::now();
let next = self.schedule.upcoming(Local).next().unwrap();
let now = chrono::Utc::now();
let next = self.schedule.upcoming(chrono::Utc).next().unwrap();
let duration_until = next.signed_duration_since(now);
duration_until.to_std().unwrap()
}

View File

@ -1,7 +1,6 @@
use crate::domain::{
handler::{BackendHandler, GroupDetails, GroupId, UserId},
handler::{BackendHandler, GroupDetails, GroupId, UserColumn, UserId},
ldap::utils::map_user_field,
sql_tables::UserColumn,
};
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize};
@ -214,19 +213,19 @@ impl<Handler: BackendHandler + Sync> User<Handler> {
}
fn display_name(&self) -> &str {
&self.user.display_name
self.user.display_name.as_deref().unwrap_or("")
}
fn first_name(&self) -> &str {
&self.user.first_name
self.user.first_name.as_deref().unwrap_or("")
}
fn last_name(&self) -> &str {
&self.user.last_name
self.user.last_name.as_deref().unwrap_or("")
}
fn avatar(&self) -> String {
(&self.user.avatar).into()
fn avatar(&self) -> Option<String> {
self.user.avatar.as_ref().map(String::from)
}
fn creation_date(&self) -> chrono::DateTime<chrono::Utc> {
@ -392,7 +391,7 @@ mod tests {
Ok(DomainUser {
user_id: UserId::new("bob"),
email: "bob@bobbers.on".to_string(),
creation_date: chrono::Utc.timestamp_millis(42),
creation_date: chrono::Utc.timestamp_millis_opt(42).unwrap(),
uuid: crate::uuid!("b1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
..Default::default()
})

View File

@ -1,6 +1,7 @@
use sea_query::*;
use sea_orm::ConnectionTrait;
use sea_query::{ColumnDef, ForeignKey, ForeignKeyAction, Iden, Table};
pub use crate::domain::sql_tables::*;
pub use crate::domain::{sql_migrations::Users, sql_tables::DbConnection};
/// Contains the refresh tokens for a given user.
#[derive(Iden)]
@ -31,9 +32,12 @@ pub enum PasswordResetTokens {
}
/// This needs to be initialized after the domain tables are.
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
sqlx::query(
&Table::create()
pub async fn init_table(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> {
let builder = pool.get_database_backend();
pool.execute(
builder.build(
Table::create()
.table(JwtRefreshStorage::Table)
.if_not_exists()
.col(
@ -59,14 +63,14 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
pool.execute(
builder.build(
Table::create()
.table(JwtStorage::Table)
.if_not_exists()
.col(
@ -98,14 +102,14 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
pool.execute(
builder.build(
Table::create()
.table(PasswordResetTokens::Table)
.if_not_exists()
.col(
@ -131,10 +135,9 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
Ok(())

View File

@ -569,7 +569,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
mod tests {
use super::*;
use crate::{
domain::{error::Result, handler::*, opaque_handler::*, sql_tables::UserColumn},
domain::{error::Result, handler::*, opaque_handler::*},
uuid,
};
use async_trait::async_trait;
@ -669,7 +669,7 @@ mod tests {
set.insert(GroupDetails {
group_id: GroupId(42),
display_name: group,
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
});
Ok(set)
@ -756,7 +756,7 @@ mod tests {
set.insert(GroupDetails {
group_id: GroupId(42),
display_name: "lldap_admin".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
});
Ok(set)
@ -843,7 +843,7 @@ mod tests {
groups: Some(vec![GroupDetails {
group_id: GroupId(42),
display_name: "rockstars".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
}]),
}])
@ -991,9 +991,9 @@ mod tests {
user: User {
user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(),
last_name: "Böbberson".to_string(),
display_name: Some("Bôb Böbberson".to_string()),
first_name: Some("Bôb".to_string()),
last_name: Some("Böbberson".to_string()),
uuid: uuid!("698e1d5f-7a40-3151-8745-b9b8a37839da"),
..Default::default()
},
@ -1003,12 +1003,12 @@ mod tests {
user: User {
user_id: UserId::new("jim"),
email: "jim@cricket.jim".to_string(),
display_name: "Jimminy Cricket".to_string(),
first_name: "Jim".to_string(),
last_name: "Cricket".to_string(),
avatar: JpegPhoto::for_tests(),
display_name: Some("Jimminy Cricket".to_string()),
first_name: Some("Jim".to_string()),
last_name: Some("Cricket".to_string()),
avatar: Some(JpegPhoto::for_tests()),
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
creation_date: Utc.ymd(2014, 7, 8).and_hms(9, 10, 11),
creation_date: Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11).unwrap(),
},
groups: None,
},
@ -1137,14 +1137,14 @@ mod tests {
Group {
id: GroupId(1),
display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
},
Group {
id: GroupId(3),
display_name: "BestGroup".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
},
@ -1230,7 +1230,7 @@ mod tests {
Ok(vec![Group {
display_name: "group_1".to_string(),
id: GroupId(1),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@ -1281,7 +1281,7 @@ mod tests {
Ok(vec![Group {
display_name: "group_1".to_string(),
id: GroupId(1),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@ -1542,9 +1542,9 @@ mod tests {
user: User {
user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(),
last_name: "Böbberson".to_string(),
display_name: Some("Bôb Böbberson".to_string()),
first_name: Some("Bôb".to_string()),
last_name: Some("Böbberson".to_string()),
..Default::default()
},
groups: None,
@ -1557,7 +1557,7 @@ mod tests {
Ok(vec![Group {
id: GroupId(1),
display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@ -1616,9 +1616,9 @@ mod tests {
user: User {
user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(),
last_name: "Böbberson".to_string(),
avatar: JpegPhoto::for_tests(),
display_name: Some("Bôb Böbberson".to_string()),
last_name: Some("Böbberson".to_string()),
avatar: Some(JpegPhoto::for_tests()),
uuid: uuid!("b4ac75e0-2900-3e21-926c-2f732c26b3fc"),
..Default::default()
},
@ -1631,7 +1631,7 @@ mod tests {
Ok(vec![Group {
id: GroupId(1),
display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@ -1680,7 +1680,11 @@ mod tests {
},
LdapPartialAttribute {
atype: "createtimestamp".to_string(),
vals: vec![chrono::Utc.timestamp(0, 0).to_rfc3339().into_bytes()],
vals: vec![chrono::Utc
.timestamp_opt(0, 0)
.unwrap()
.to_rfc3339()
.into_bytes()],
},
LdapPartialAttribute {
atype: "entryuuid".to_string(),
@ -1960,7 +1964,7 @@ mod tests {
groups.insert(GroupDetails {
group_id: GroupId(0),
display_name: "lldap_admin".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
});
mock.expect_get_user_groups()

View File

@ -48,3 +48,14 @@ pub fn init(config: &Configuration) -> anyhow::Result<()> {
.init();
Ok(())
}
#[cfg(test)]
pub fn init_for_tests() {
if let Err(e) = tracing_subscriber::FmtSubscriber::builder()
.with_max_level(tracing::Level::DEBUG)
.with_test_writer()
.try_init()
{
log::warn!("Could not set up test logging: {:#}", e);
}
}

View File

@ -1,10 +1,16 @@
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler};
use super::tcp_backend_handler::TcpBackendHandler;
use crate::domain::{
error::*,
handler::UserId,
model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn},
sql_backend_handler::SqlBackendHandler,
};
use async_trait::async_trait;
use futures_util::StreamExt;
use sea_query::{Expr, Iden, Query, SimpleExpr};
use sea_query_binder::SqlxBinder;
use sqlx::{query_as_with, query_with, Row};
use sea_orm::{
sea_query::Cond, ActiveModelTrait, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel,
QueryFilter, QuerySelect,
};
use sea_query::Expr;
use std::collections::HashSet;
use tracing::{debug, instrument};
@ -18,126 +24,102 @@ fn gen_random_string(len: usize) -> String {
.collect()
}
#[derive(FromQueryResult)]
struct OnlyJwtHash {
jwt_hash: i64,
}
#[async_trait]
impl TcpBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug")]
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
let (query, values) = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
Ok(model::JwtStorage::find()
.select_only()
.column(JwtStorageColumn::JwtHash)
.filter(JwtStorageColumn::Blacklisted.eq(true))
.into_model::<OnlyJwtHash>()
.all(&self.sql_pool)
.await?
.into_iter()
.collect::<sqlx::Result<HashSet<u64>>>()
.map_err(|e| anyhow::anyhow!(e))
.map(|m| m.jwt_hash as u64)
.collect::<HashSet<u64>>())
}
#[instrument(skip_all, level = "debug")]
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
debug!(?user);
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
let refresh_token = gen_random_string(100);
let refresh_token_hash = {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new();
refresh_token.hash(&mut s);
s.finish()
};
let duration = chrono::Duration::days(30);
let (query, values) = Query::insert()
.into_table(JwtRefreshStorage::Table)
.columns(vec![
JwtRefreshStorage::RefreshTokenHash,
JwtRefreshStorage::UserId,
JwtRefreshStorage::ExpiryDate,
])
.values_panic(vec![
(refresh_token_hash as i64).into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
let new_token = model::jwt_refresh_storage::Model {
refresh_token_hash: refresh_token_hash as i64,
user_id: user.clone(),
expiry_date: chrono::Utc::now() + duration,
}
.into_active_model();
new_token.insert(&self.sql_pool).await?;
Ok((refresh_token, duration))
}
#[instrument(skip_all, level = "debug")]
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
debug!(?user);
let (query, values) = Query::select()
.expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_with(&query, values)
.fetch_optional(&self.sql_pool)
Ok(
model::JwtRefreshStorage::find_by_id(refresh_token_hash as i64)
.filter(JwtRefreshStorageColumn::UserId.eq(user))
.one(&self.sql_pool)
.await?
.is_some())
.is_some(),
)
}
#[instrument(skip_all, level = "debug")]
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
debug!(?user);
use sqlx::Result;
let (query, values) = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
.build_sqlx(DbQueryBuilder {});
let result = query_with(&query, values)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
let valid_tokens = model::JwtStorage::find()
.select_only()
.column(JwtStorageColumn::JwtHash)
.filter(
Cond::all()
.add(JwtStorageColumn::UserId.eq(user))
.add(JwtStorageColumn::Blacklisted.eq(false)),
)
.into_model::<OnlyJwtHash>()
.all(&self.sql_pool)
.await?
.into_iter()
.collect::<Result<HashSet<u64>>>();
let (query, values) = Query::update()
.table(JwtStorage::Table)
.values(vec![(JwtStorage::Blacklisted, true.into())])
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(result?)
.map(|t| t.jwt_hash as u64)
.collect::<HashSet<u64>>();
model::JwtStorage::update_many()
.col_expr(JwtStorageColumn::Blacklisted, Expr::value(true))
.filter(JwtStorageColumn::UserId.eq(user))
.exec(&self.sql_pool)
.await?;
Ok(valid_tokens)
}
#[instrument(skip_all, level = "debug")]
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
let (query, values) = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
model::JwtRefreshStorage::delete_by_id(refresh_token_hash as i64)
.exec(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug")]
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
debug!(?user);
let (query, values) = Query::select()
.column(Users::UserId)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
// Check that the user exists.
if query_with(&query, values)
.fetch_one(&self.sql_pool)
.await
.is_err()
if model::User::find_by_id(user.clone())
.one(&self.sql_pool)
.await?
.is_none()
{
debug!("User not found");
return Ok(None);
@ -146,50 +128,37 @@ impl TcpBackendHandler for SqlBackendHandler {
let token = gen_random_string(100);
let duration = chrono::Duration::minutes(10);
let (query, values) = Query::insert()
.into_table(PasswordResetTokens::Table)
.columns(vec![
PasswordResetTokens::Token,
PasswordResetTokens::UserId,
PasswordResetTokens::ExpiryDate,
])
.values_panic(vec![
token.clone().into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
let new_token = model::password_reset_tokens::Model {
token: token.clone(),
user_id: user.clone(),
expiry_date: chrono::Utc::now() + duration,
}
.into_active_model();
new_token.insert(&self.sql_pool).await?;
Ok(Some(token))
}
#[instrument(skip_all, level = "debug", ret)]
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
let (query, values) = Query::select()
.column(PasswordResetTokens::UserId)
.from(PasswordResetTokens::Table)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
.and_where(
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let (user_id,) = query_as_with(&query, values)
.fetch_one(&self.sql_pool)
.await?;
Ok(user_id)
Ok(model::PasswordResetTokens::find_by_id(token.to_owned())
.filter(PasswordResetTokensColumn::ExpiryDate.gt(chrono::Utc::now().naive_utc()))
.one(&self.sql_pool)
.await?
.ok_or_else(|| DomainError::EntityNotFound("Invalid reset token".to_owned()))?
.user_id)
}
#[instrument(skip_all, level = "debug")]
async fn delete_password_reset_token(&self, token: &str) -> Result<()> {
let (query, values) = Query::delete()
.from_table(PasswordResetTokens::Table)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
let result = model::PasswordResetTokens::delete_by_id(token.to_owned())
.exec(&self.sql_pool)
.await?;
if result.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such password reset token: '{}'",
token
)));
}
Ok(())
}
}

View File

@ -52,9 +52,9 @@ pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse {
DomainError::DatabaseError(_)
| DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
HttpResponse::BadRequest()
}
DomainError::Base64DecodeError(_)
| DomainError::BinarySerializationError(_)
| DomainError::EntityNotFound(_) => HttpResponse::BadRequest(),
},
TcpError::BadRequest(_) => HttpResponse::BadRequest(),
TcpError::InternalServerError(_) => HttpResponse::InternalServerError(),

View File

@ -9,7 +9,6 @@ use crate::{
handler::{CreateUserRequest, GroupBackendHandler, GroupRequestFilter, UserBackendHandler},
sql_backend_handler::SqlBackendHandler,
sql_opaque_handler::register_password,
sql_tables::PoolOptions,
},
infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, healthcheck, mail},
};
@ -17,6 +16,7 @@ use actix::Actor;
use actix_server::ServerBuilder;
use anyhow::{anyhow, Context, Result};
use futures_util::TryFutureExt;
use sea_orm::Database;
use tracing::*;
mod domain;
@ -39,29 +39,52 @@ async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration)
.and_then(|_| register_password(handler, &config.ldap_user_dn, &config.ldap_user_pass))
.await
.context("Error creating admin user")?;
let admin_group_id = handler
.create_group("lldap_admin")
.await
.context("Error creating admin group")?;
let groups = handler
.list_groups(Some(GroupRequestFilter::DisplayName(
"lldap_admin".to_owned(),
)))
.await?;
assert_eq!(groups.len(), 1);
handler
.add_user_to_group(&config.ldap_user_dn, admin_group_id)
.add_user_to_group(&config.ldap_user_dn, groups[0].id)
.await
.context("Error adding admin user to group")
}
async fn ensure_group_exists(handler: &SqlBackendHandler, group_name: &str) -> Result<()> {
if handler
.list_groups(Some(GroupRequestFilter::DisplayName(group_name.to_owned())))
.await?
.is_empty()
{
warn!("Could not find {} group, trying to create it", group_name);
handler
.create_group(group_name)
.await
.context(format!("while creating {} group", group_name))?;
}
Ok(())
}
#[instrument(skip_all)]
async fn set_up_server(config: Configuration) -> Result<ServerBuilder> {
info!("Starting LLDAP version {}", env!("CARGO_PKG_VERSION"));
let sql_pool = PoolOptions::new()
let sql_pool = {
let mut sql_opt = sea_orm::ConnectOptions::new(config.database_url.clone());
sql_opt
.max_connections(5)
.connect(&config.database_url)
.await
.context("while connecting to the DB")?;
.sqlx_logging(true)
.sqlx_logging_level(log::LevelFilter::Debug);
Database::connect(sql_opt).await?
};
domain::sql_tables::init_table(&sql_pool)
.await
.context("while creating the tables")?;
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
ensure_group_exists(&backend_handler, "lldap_admin").await?;
ensure_group_exists(&backend_handler, "lldap_password_manager").await?;
ensure_group_exists(&backend_handler, "lldap_strict_readonly").await?;
if let Err(e) = backend_handler.get_user_details(&config.ldap_user_dn).await {
warn!("Could not get admin user, trying to create it: {:#}", e);
create_admin_user(&backend_handler, &config)
@ -69,23 +92,6 @@ async fn set_up_server(config: Configuration) -> Result<ServerBuilder> {
.map_err(|e| anyhow!("Error setting up admin login/account: {:#}", e))
.context("while creating the admin user")?;
}
if backend_handler
.list_groups(Some(GroupRequestFilter::DisplayName(
"lldap_password_manager".to_string(),
)))
.await?
.is_empty()
{
warn!("Could not find password_manager group, trying to create it");
backend_handler
.create_group("lldap_password_manager")
.await
.context("while creating password_manager group")?;
backend_handler
.create_group("lldap_strict_readonly")
.await
.context("while creating readonly group")?;
}
let server_builder = infra::ldap_server::build_ldap_server(
&config,
backend_handler.clone(),