From e89b1538af51001624f62dede34d15112c1b18f3 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Mon, 21 Nov 2022 09:13:25 +0100 Subject: [PATCH] server,app: migrate to sea-orm --- Cargo.lock | 586 +++++++++++++----- Cargo.toml | 4 + app/Cargo.toml | 1 + app/src/components/create_user.rs | 5 +- app/src/components/group_details.rs | 2 +- app/src/components/group_table.rs | 2 +- app/src/components/user_details_form.rs | 14 +- app/src/components/user_table.rs | 2 +- app/src/infra/cookies.rs | 6 +- schema.graphql | 2 +- server/Cargo.toml | 26 +- server/src/domain/error.rs | 4 +- server/src/domain/handler.rs | 72 ++- server/src/domain/ldap/group.rs | 3 +- server/src/domain/ldap/user.rs | 11 +- server/src/domain/ldap/utils.rs | 5 +- server/src/domain/mod.rs | 1 + server/src/domain/model/groups.rs | 53 ++ .../src/domain/model/jwt_refresh_storage.rs | 35 ++ server/src/domain/model/jwt_storage.rs | 36 ++ server/src/domain/model/memberships.rs | 73 +++ server/src/domain/model/mod.rs | 12 + .../src/domain/model/password_reset_tokens.rs | 35 ++ server/src/domain/model/prelude.rs | 14 + server/src/domain/model/users.rs | 134 ++++ server/src/domain/sql_backend_handler.rs | 17 +- .../src/domain/sql_group_backend_handler.rs | 240 ++++--- server/src/domain/sql_migrations.rs | 508 ++++++++------- server/src/domain/sql_opaque_handler.rs | 122 ++-- server/src/domain/sql_tables.rs | 412 ++++++++---- server/src/domain/sql_user_backend_handler.rs | 450 +++++--------- server/src/infra/auth_service.rs | 11 +- server/src/infra/db_cleaner.rs | 55 +- server/src/infra/graphql/query.rs | 15 +- server/src/infra/jwt_sql_tables.rs | 201 +++--- server/src/infra/ldap_handler.rs | 56 +- server/src/infra/logging.rs | 11 + server/src/infra/sql_backend_handler.rs | 211 +++---- server/src/infra/tcp_server.rs | 6 +- server/src/main.rs | 62 +- 40 files changed, 2125 insertions(+), 1390 deletions(-) create mode 100644 server/src/domain/model/groups.rs create mode 100644 server/src/domain/model/jwt_refresh_storage.rs create mode 100644 server/src/domain/model/jwt_storage.rs create mode 100644 server/src/domain/model/memberships.rs create mode 100644 server/src/domain/model/mod.rs create mode 100644 server/src/domain/model/password_reset_tokens.rs create mode 100644 server/src/domain/model/prelude.rs create mode 100644 server/src/domain/model/users.rs diff --git a/Cargo.lock b/Cargo.lock index f8e8d6c..e8ace93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "Inflector" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" + [[package]] name = "actix" version = "0.12.0" @@ -342,6 +348,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "ansi_term" version = "0.12.1" @@ -426,6 +447,27 @@ dependencies = [ "syn", ] +[[package]] +name = "async-stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.56" @@ -439,9 +481,9 @@ dependencies = [ [[package]] name = "atoi" -version = "0.4.0" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616896e05fc0e2649463a93a15183c6a16bf03413a7af88ef1285ddedfa9cda5" +checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" dependencies = [ "num-traits", ] @@ -452,7 +494,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b88d82667eca772c4aa12f0f1348b3ae643424c8876448f3f7bd5787032e234c" dependencies = [ - "autocfg 1.1.0", + "autocfg", ] [[package]] @@ -466,15 +508,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "autocfg" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dde43e75fd43e8a1bf86103336bc699aa8d17ad1be60c76c0bdfd4828e19b78" -dependencies = [ - "autocfg 1.1.0", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -496,6 +529,19 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "bae" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b8de67cc41132507eeece2584804efcb15f85ba516e34c944b7667f480397a" +dependencies = [ + "heck 0.3.3", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "base-x" version = "0.2.11" @@ -549,6 +595,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" +dependencies = [ + "generic-array", +] + [[package]] name = "boolinator" version = "2.4.0" @@ -589,7 +644,7 @@ dependencies = [ "rand 0.7.3", "serde", "serde_json", - "uuid", + "uuid 0.8.2", ] [[package]] @@ -648,12 +703,12 @@ checksum = "8100e46ff92eb85bf6dc2930c73f2a4f7176393c84a9446b3d501e1b354e7b34" [[package]] name = "chrono" -version = "0.4.19" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" +checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" dependencies = [ + "iana-time-zone", "js-sys", - "libc", "num-integer", "num-traits", "serde", @@ -701,6 +756,16 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -732,9 +797,9 @@ dependencies = [ [[package]] name = "const-oid" -version = "0.6.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6f2aa4d0537bcc1c74df8755072bd31c1ef1a3a1b85a68e8404a8c353b7b8b" +checksum = "e4c78c047431fee22c1a7bb92e00ad095a02a983affe4d8a72e2a2c62c1b94f3" [[package]] name = "const_fn" @@ -792,18 +857,18 @@ dependencies = [ [[package]] name = "crc" -version = "2.1.0" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49fc9a695bca7f35f5f4c15cddc84415f66a74ea78eef08e90c5024f2b540e23" +checksum = "53757d12b596c16c78b83458d732a5d1a17ab3f53f2f7412f6fb57cc8a140ab3" dependencies = [ "crc-catalog", ] [[package]] name = "crc-catalog" -version = "1.1.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" +checksum = "2d0165d2900ae6778e36e80bbc4da3b5eefccee9ba939761f9c2882a5d9af3ff" [[package]] name = "crc32fast" @@ -882,15 +947,24 @@ dependencies = [ [[package]] name = "crypto-bigint" -version = "0.2.11" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83bd3bb4314701c568e340cd8cf78c975aa0ca79e03d3f6d1677d5b0c9c0c03" +checksum = "03c6a1d5fa1de37e071642dfa44ec552ca5b299adb128fab16138e24b548fd21" dependencies = [ "generic-array", - "rand_core 0.6.3", "subtle", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "crypto-mac" version = "0.10.1" @@ -919,18 +993,62 @@ checksum = "f3b7eb4404b8195a9abb6356f4ac07d8ba267045c8d6d220ac4dc992e6cc75df" [[package]] name = "curve25519-dalek" -version = "3.2.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90f9d052967f590a76e62eb387bd0bbb1b000182c3cefe5364db6b7211651bc0" +checksum = "0b9fdf9972b2bd6af2d913799d9ebc165ea4d2e65878e329d9c6b372c4491b61" dependencies = [ "byteorder", - "digest", + "digest 0.9.0", "rand_core 0.5.1", "serde", "subtle", "zeroize", ] +[[package]] +name = "cxx" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4a41a86530d0fe7f5d9ea779916b7cadd2d4f9add748b99c2c029cbbdfaf453" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06416d667ff3e3ad2df1cd8cd8afae5da26cf9cec4d0825040f88b5ca659a2f0" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "820a9a2af1669deeef27cb271f476ffd196a2c4b6731336011e0ba63e2c7cf71" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a08a6e2fcc370a089ad3b4aaf54db3b1b4cee38ddabce5896b33eb693275f470" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "darling" version = "0.12.4" @@ -974,12 +1092,13 @@ checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" [[package]] name = "der" -version = "0.4.5" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79b71cca7d95d7681a4b3b9cdf63c8dbc3730d0584c2c74e31416d64a90493f4" +checksum = "6919815d73839e7ad218de758883aae3a257ba6759ce7a9992501efbb53d705c" dependencies = [ "const-oid", "crypto-bigint", + "pem-rfc7468", ] [[package]] @@ -991,7 +1110,7 @@ dependencies = [ "asn1-rs", "displaydoc 0.2.3", "nom 7.1.1", - "num-bigint 0.4.3", + "num-bigint", "num-traits", "rusticata-macros", ] @@ -1066,6 +1185,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "digest" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" +dependencies = [ + "block-buffer 0.10.3", + "crypto-common", + "subtle", +] + [[package]] name = "dirs" version = "4.0.0" @@ -1115,10 +1245,10 @@ dependencies = [ ] [[package]] -name = "dotenv" -version = "0.15.0" +name = "dotenvy" +version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +checksum = "03d8c417d7a8cb362e0c37e5d815f5eb7c37f79ff93707329d5a194e42e54ca0" [[package]] name = "downcast" @@ -1157,6 +1287,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "event-listener" +version = "2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + [[package]] name = "failure" version = "0.1.8" @@ -1629,20 +1765,20 @@ checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" [[package]] name = "hashbrown" -version = "0.11.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" dependencies = [ "ahash", ] [[package]] name = "hashlink" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +checksum = "69fe1fcf8b4278d860ad0548329f892a3631fb63f82574df68275f34cdbe0ffa" dependencies = [ - "hashbrown 0.11.2", + "hashbrown 0.12.3", ] [[package]] @@ -1659,6 +1795,9 @@ name = "heck" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" +dependencies = [ + "unicode-segmentation", +] [[package]] name = "hermit-abi" @@ -1681,10 +1820,19 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01706d578d5c281058480e673ae4086a9f4710d8df1ad80a5b03e39ece5f886b" dependencies = [ - "digest", + "digest 0.9.0", "hmac 0.11.0", ] +[[package]] +name = "hkdf" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437" +dependencies = [ + "hmac 0.12.1", +] + [[package]] name = "hmac" version = "0.10.1" @@ -1692,7 +1840,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15" dependencies = [ "crypto-mac 0.10.1", - "digest", + "digest 0.9.0", ] [[package]] @@ -1702,7 +1850,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" dependencies = [ "crypto-mac 0.11.1", - "digest", + "digest 0.9.0", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest 0.10.6", ] [[package]] @@ -1782,6 +1939,30 @@ dependencies = [ "tokio-rustls 0.23.4", ] +[[package]] +name = "iana-time-zone" +version = "0.1.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "winapi", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" +dependencies = [ + "cxx", + "cxx-build", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -1825,7 +2006,7 @@ version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" dependencies = [ - "autocfg 1.1.0", + "autocfg", "hashbrown 0.9.1", "serde", ] @@ -1915,7 +2096,7 @@ dependencies = [ "smartstring", "static_assertions", "url", - "uuid", + "uuid 0.8.2", ] [[package]] @@ -1957,11 +2138,11 @@ checksum = "86e46349d67dc03bdbdb28da0337a355a53ca1d5156452722c36fe21d0e6389b" dependencies = [ "base64", "crypto-mac 0.10.1", - "digest", + "digest 0.9.0", "hmac 0.10.1", "serde", "serde_json", - "sha2", + "sha2 0.9.9", ] [[package]] @@ -2083,15 +2264,24 @@ checksum = "33a33a362ce288760ec6a508b94caaec573ae7d3bbbd91b87aa0bad4456839db" [[package]] name = "libsqlite3-sys" -version = "0.23.2" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2cafc7c74096c336d9d27145f7ebd4f4b6f95ba16aa5a282387267e6925cb58" +checksum = "898745e570c7d0453cc1fbc4a701eb6c662ed54e8fec8b7d14be137ebeeb9d14" dependencies = [ "cc", "pkg-config", "vcpkg", ] +[[package]] +name = "link-cplusplus" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369" +dependencies = [ + "cc", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2141,15 +2331,13 @@ dependencies = [ "reqwest", "rustls 0.20.6", "rustls-pemfile", + "sea-orm", "sea-query", - "sea-query-binder", "secstr", "serde", "serde_bytes", "serde_json", - "sha2", - "sqlx", - "sqlx-core", + "sha2 0.9.9", "thiserror", "time 0.2.27", "tokio", @@ -2162,7 +2350,7 @@ dependencies = [ "tracing-forest", "tracing-log", "tracing-subscriber", - "uuid", + "uuid 1.2.2", "webpki-roots 0.22.4", ] @@ -2200,14 +2388,14 @@ version = "0.3.0-alpha.1" dependencies = [ "chrono", "curve25519-dalek", - "digest", + "digest 0.9.0", "generic-array", "getrandom 0.2.7", "opaque-ke", "rand 0.8.5", "rust-argon2", "serde", - "sha2", + "sha2 0.9.9", "thiserror", ] @@ -2235,7 +2423,7 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" dependencies = [ - "autocfg 1.1.0", + "autocfg", "scopeguard", ] @@ -2265,21 +2453,13 @@ checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" [[package]] name = "md-5" -version = "0.9.1" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" +checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" dependencies = [ - "block-buffer", - "digest", - "opaque-debug", + "digest 0.10.6", ] -[[package]] -name = "md5" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" - [[package]] name = "memchr" version = "2.5.0" @@ -2437,35 +2617,23 @@ dependencies = [ "winapi", ] -[[package]] -name = "num-bigint" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6f7833f2cbf2360a6cfd58cd41a53aa7a90bd4c202f5b1c7dd2ed73c57b2c3" -dependencies = [ - "autocfg 1.1.0", - "num-integer", - "num-traits", -] - [[package]] name = "num-bigint" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" dependencies = [ - "autocfg 1.1.0", + "autocfg", "num-integer", "num-traits", ] [[package]] name = "num-bigint-dig" -version = "0.7.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4547ee5541c18742396ae2c895d0717d0f886d8823b8399cdaf7b07d63ad0480" +checksum = "2399c9463abc5f909349d8aa9ba080e0b88b3ce2885389b60b993f39b1a56905" dependencies = [ - "autocfg 0.1.8", "byteorder", "lazy_static", "libm", @@ -2483,7 +2651,7 @@ version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" dependencies = [ - "autocfg 1.1.0", + "autocfg", "num-traits", ] @@ -2493,7 +2661,7 @@ version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" dependencies = [ - "autocfg 1.1.0", + "autocfg", "num-integer", "num-traits", ] @@ -2504,7 +2672,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" dependencies = [ - "autocfg 1.1.0", + "autocfg", "num-integer", "num-traits", ] @@ -2515,7 +2683,7 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ - "autocfg 1.1.0", + "autocfg", "libm", ] @@ -2571,16 +2739,15 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "opaque-ke" version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507fdf0b89eabfde58445f88807f57f253f72236e44960ebf690d897803cd18d" +source = "git+https://github.com/nitnelave/opaque-ke/?branch=zeroize_1.5#308a8dfee7eb855923187d2b63d64a0aaf516304" dependencies = [ "base64", "curve25519-dalek", - "digest", + "digest 0.9.0", "displaydoc 0.1.7", "generic-array", "generic-bytes", - "hkdf", + "hkdf 0.11.0", "hmac 0.11.0", "rand 0.8.5", "serde", @@ -2613,6 +2780,29 @@ version = "6.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21326818e99cfe6ce1e524c2a805c189a99b5ae555a35d19f9a284b427d86afa" +[[package]] +name = "ouroboros" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbb50b356159620db6ac971c6d5c9ab788c9cc38a6f49619fca2a27acb062ca" +dependencies = [ + "aliasable", + "ouroboros_macro", +] + +[[package]] +name = "ouroboros_macro" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0d9d1a6191c4f391f87219d1ea42b23f09ee84d64763cd05ee6ea88d9f384d" +dependencies = [ + "Inflector", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "parking_lot" version = "0.11.2" @@ -2692,9 +2882,9 @@ dependencies = [ [[package]] name = "pem-rfc7468" -version = "0.2.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84e93a3b1cc0510b03020f33f21e62acdde3dcaef432edc95bea377fbd4c2cd4" +checksum = "01de5d978f34aa4b2296576379fcc416034702fd94117c56ffd8a1a767cefb30" dependencies = [ "base64ct", ] @@ -2739,24 +2929,22 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkcs1" -version = "0.2.4" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "116bee8279d783c0cf370efa1a94632f2108e5ef0bb32df31f051647810a4e2c" +checksum = "a78f66c04ccc83dd4486fd46c33896f4e17b24a7a3a6400dedc48ed0ddd72320" dependencies = [ "der", - "pem-rfc7468", + "pkcs8", "zeroize", ] [[package]] name = "pkcs8" -version = "0.7.6" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee3ef9b64d26bad0536099c816c6734379e45bbd5f14798def6809e5cc350447" +checksum = "7cabda3fb821068a9a4fab19a683eac3af12edf0f34b94a8be53c4972b8149d0" dependencies = [ "der", - "pem-rfc7468", - "pkcs1", "spki", "zeroize", ] @@ -3075,20 +3263,20 @@ dependencies = [ [[package]] name = "rsa" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e05c2603e2823634ab331437001b411b9ed11660fbc4066f3908c84a9439260d" +checksum = "4cf22754c49613d2b3b119f0e5d46e34a2c628a937e3024b8762de4e7d8c710b" dependencies = [ "byteorder", - "digest", - "lazy_static", + "digest 0.10.6", "num-bigint-dig", "num-integer", "num-iter", "num-traits", "pkcs1", "pkcs8", - "rand 0.8.5", + "rand_core 0.6.3", + "smallvec", "subtle", "zeroize", ] @@ -3184,6 +3372,12 @@ dependencies = [ "base64", ] +[[package]] +name = "rustversion" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97477e48b4cf8603ad5f7aaf897467cf42ab4218a38ef76fb14c2d6773a6d6a8" + [[package]] name = "ryu" version = "1.0.10" @@ -3206,6 +3400,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "scratch" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898" + [[package]] name = "sct" version = "0.6.1" @@ -3227,24 +3427,64 @@ dependencies = [ ] [[package]] -name = "sea-query" -version = "0.25.2" +name = "sea-orm" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6afb1ec318e6cfd4a7586bc9591afc0667783e335b0a3394661e1b0258fd84e" +checksum = "8744afc95ca462de12c2cea5a56d7e406f3be2b2683d3b05066e1afdba898bc5" +dependencies = [ + "async-stream", + "async-trait", + "chrono", + "futures", + "futures-util", + "log", + "ouroboros", + "sea-orm-macros", + "sea-query", + "sea-query-binder", + "sea-strum", + "serde", + "sqlx", + "thiserror", + "tracing", + "url", + "uuid 1.2.2", +] + +[[package]] +name = "sea-orm-macros" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca4d01381fdcabc3818b6d39c5f1f0c885900af90da638e4001406907462784" +dependencies = [ + "bae", + "heck 0.3.3", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sea-query" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f0fc4d8e44e1d51c739a68d336252a18bc59553778075d5e32649be6ec92ed" dependencies = [ "chrono", "sea-query-derive", - "sea-query-driver", + "uuid 1.2.2", ] [[package]] name = "sea-query-binder" -version = "0.1.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a290a2d8cdb3b79a2a7f867385eff441e27533727aecf4d505baa58759c49ba9" +checksum = "9c2585b89c985cfacfe0ec9fc9e7bb055b776c1a2581c4e3c6185af2b8bf8865" dependencies = [ + "chrono", "sea-query", "sqlx", + "uuid 1.2.2", ] [[package]] @@ -3261,13 +3501,24 @@ dependencies = [ ] [[package]] -name = "sea-query-driver" -version = "0.1.1" +name = "sea-strum" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e3953baee94dcb90f0e19e8b4b91b91e9394867b0fc1886d0221cfc6d0439f5" +checksum = "391d06a6007842cfe79ac6f7f53911b76dfd69fc9a6769f1cf6569d12ce20e1b" dependencies = [ + "sea-strum_macros", +] + +[[package]] +name = "sea-strum_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69b4397b825df6ccf1e98bcdabef3bbcfc47ff5853983467850eeab878384f21" +dependencies = [ + "heck 0.3.3", "proc-macro2", "quote", + "rustversion", "syn", ] @@ -3384,10 +3635,10 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" dependencies = [ - "block-buffer", + "block-buffer 0.9.0", "cfg-if", "cpufeatures", - "digest", + "digest 0.9.0", "opaque-debug", ] @@ -3400,6 +3651,17 @@ dependencies = [ "sha1_smol", ] +[[package]] +name = "sha1" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.6", +] + [[package]] name = "sha1_smol" version = "1.0.0" @@ -3412,13 +3674,24 @@ version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" dependencies = [ - "block-buffer", + "block-buffer 0.9.0", "cfg-if", "cpufeatures", - "digest", + "digest 0.9.0", "opaque-debug", ] +[[package]] +name = "sha2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.6", +] + [[package]] name = "sharded-slab" version = "0.1.4" @@ -3466,9 +3739,9 @@ checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" [[package]] name = "smallvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fd0db749597d91ff862fd1d55ea87f7855a744a8425a64695b6fca237d1dad1" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "smartstring" @@ -3512,18 +3785,19 @@ dependencies = [ [[package]] name = "spki" -version = "0.4.1" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c01a0c15da1b0b0e1494112e7af814a678fec9bd157881b49beac661e9b6f32" +checksum = "44d01ac02a6ccf3e07db148d2be087da624fea0221a16152ed01f0496a6b0a27" dependencies = [ + "base64ct", "der", ] [[package]] name = "sqlformat" -version = "0.1.8" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4b7922be017ee70900be125523f38bdd644f4f06a1b16e8fa5a8ee8c34bffd4" +checksum = "f87e292b4291f154971a43c3774364e2cbcaec599d3f5bf6fa9d122885dbc38a" dependencies = [ "itertools", "nom 7.1.1", @@ -3532,9 +3806,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.5.11" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc15591eb44ffb5816a4a70a7efd5dd87bfd3aa84c4c200401c4396140525826" +checksum = "9249290c05928352f71c077cc44a464d880c63f26f7534728cca008e135c0428" dependencies = [ "sqlx-core", "sqlx-macros", @@ -3542,9 +3816,9 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.5.11" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "195183bf6ff8328bb82c0511a83faf60aacf75840103388851db61d7a9854ae3" +checksum = "dcbc16ddba161afc99e14d1713a453747a2b07fc097d2009f4c300ec99286105" dependencies = [ "ahash", "atoi", @@ -3555,9 +3829,11 @@ dependencies = [ "chrono", "crc", "crossbeam-queue", - "digest", + "digest 0.10.6", "dirs", + "dotenvy", "either", + "event-listener", "flume", "futures-channel", "futures-core", @@ -3567,7 +3843,8 @@ dependencies = [ "generic-array", "hashlink", "hex", - "hmac 0.11.0", + "hkdf 0.12.3", + "hmac 0.12.1", "indexmap", "itoa 1.0.2", "libc", @@ -3575,17 +3852,18 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint 0.3.3", + "num-bigint", "once_cell", "paste", "percent-encoding", "rand 0.8.5", "rsa", - "rustls 0.19.1", + "rustls 0.20.6", + "rustls-pemfile", "serde", "serde_json", - "sha-1", - "sha2", + "sha1 0.10.5", + "sha2 0.10.6", "smallvec", "sqlformat", "sqlx-rt", @@ -3593,24 +3871,24 @@ dependencies = [ "thiserror", "tokio-stream", "url", - "webpki 0.21.4", - "webpki-roots 0.21.1", + "uuid 1.2.2", + "webpki-roots 0.22.4", "whoami", ] [[package]] name = "sqlx-macros" -version = "0.5.11" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee35713129561f5e55c554bba1c378e2a7e67f81257b7311183de98c50e6f94" +checksum = "b850fa514dc11f2ee85be9d055c512aa866746adfacd1cb42d867d68e6a5b0d9" dependencies = [ - "dotenv", + "dotenvy", "either", - "heck 0.3.3", + "heck 0.4.0", "once_cell", "proc-macro2", "quote", - "sha2", + "sha2 0.10.6", "sqlx-core", "sqlx-rt", "syn", @@ -3619,14 +3897,13 @@ dependencies = [ [[package]] name = "sqlx-rt" -version = "0.5.13" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4db708cd3e459078f85f39f96a00960bd841f66ee2a669e90bf36907f5a79aae" +checksum = "24c5b2d25fa654cc5f841750b8e1cdedbe21189bf9a9382ee90bfa9dd3562396" dependencies = [ - "actix-rt", "once_cell", "tokio", - "tokio-rustls 0.22.0", + "tokio-rustls 0.23.4", ] [[package]] @@ -3683,7 +3960,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", - "sha1", + "sha1 0.6.1", "syn", ] @@ -4003,6 +4280,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a400e31aa60b9d44a52a8ee0343b5b18566b03a8321e0d321f695cf56e940160" dependencies = [ "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -4018,7 +4296,7 @@ dependencies = [ "futures", "tracing", "tracing-futures", - "uuid", + "uuid 0.8.2", ] [[package]] @@ -4222,7 +4500,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" dependencies = [ "getrandom 0.2.7", - "md5", +] + +[[package]] +name = "uuid" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c" +dependencies = [ + "getrandom 0.2.7", + "md-5", + "serde", ] [[package]] @@ -4671,9 +4959,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.1.1" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f33972566adbd2d3588b0491eb94b98b43695c4ef897903470ede4f3f5a28a" +checksum = "c394b5bd0c6f669e7275d9c20aa90ae064cb22e75a1cad54e1b34088034b149f" dependencies = [ "zeroize_derive", ] diff --git a/Cargo.toml b/Cargo.toml index 71c003e..f2283cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,7 @@ default-members = ["server"] [patch.crates-io.ldap3_proto] git = 'https://github.com/nitnelave/ldap3_server/' rev = '7b50b2b82c383f5f70e02e11072bb916629ed2bc' + +[patch.crates-io.opaque-ke] +git = 'https://github.com/nitnelave/opaque-ke/' +branch = 'zeroize_1.5' diff --git a/app/Cargo.toml b/app/Cargo.toml index 1769536..65fa335 100644 --- a/app/Cargo.toml +++ b/app/Cargo.toml @@ -3,6 +3,7 @@ name = "lldap_app" version = "0.4.2-alpha" authors = ["Valentin Tolmer "] edition = "2021" +include = ["src/**/*", "queries/**/*", "Cargo.toml", "../schema.graphql"] [dependencies] anyhow = "1" diff --git a/app/src/components/create_user.rs b/app/src/components/create_user.rs index b167bef..b7ece1c 100644 --- a/app/src/components/create_user.rs +++ b/app/src/components/create_user.rs @@ -38,7 +38,6 @@ pub struct CreateUserModel { username: String, #[validate(email(message = "A valid email is required"))] email: String, - #[validate(length(min = 1, message = "Display name is required"))] display_name: String, first_name: String, last_name: String, @@ -244,9 +243,7 @@ impl Component for CreateUserForm {
- {g.creation_date.date().naive_local()} + {g.creation_date.naive_local().date()}
diff --git a/app/src/components/group_table.rs b/app/src/components/group_table.rs index 72d9145..5087610 100644 --- a/app/src/components/group_table.rs +++ b/app/src/components/group_table.rs @@ -124,7 +124,7 @@ impl GroupTable { - {&group.creation_date.date().naive_local()} + {&group.creation_date.naive_local().date()} ; let avatar_base64 = maybe_to_base64(&self.avatar).unwrap_or_default(); - let avatar_string = avatar_base64.as_ref().unwrap_or(&self.common.user.avatar); + let avatar_string = avatar_base64 + .as_deref() + .or(self.common.user.avatar.as_deref()) + .unwrap_or(""); html! {
@@ -195,7 +197,7 @@ impl Component for UserDetailsForm { {"Creation date: "}
- {&self.common.user.creation_date.date().naive_local()} + {&self.common.user.creation_date.naive_local().date()}
@@ -231,9 +233,7 @@ impl Component for UserDetailsForm {
{&user.display_name} {&user.first_name} {&user.last_name} - {&user.creation_date.date().naive_local()} + {&user.creation_date.naive_local().date()} Result> { pub fn delete_cookie(cookie_name: &str) -> Result<()> { if get_cookie(cookie_name)?.is_some() { - set_cookie(cookie_name, "", &Utc.ymd(1970, 1, 1).and_hms(0, 0, 0)) + set_cookie( + cookie_name, + "", + &Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(), + ) } else { Ok(()) } diff --git a/schema.graphql b/schema.graphql index b776344..4c88bd7 100644 --- a/schema.graphql +++ b/schema.graphql @@ -69,7 +69,7 @@ type User { displayName: String! firstName: String! lastName: String! - avatar: String! + avatar: String creationDate: DateTimeUtc! uuid: String! "The groups to which this user belongs." diff --git a/server/Cargo.toml b/server/Cargo.toml index 44a6a32..9c38656 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -35,7 +35,6 @@ rustls = "0.20" serde = "*" serde_json = "1" sha2 = "0.9" -sqlx-core = "0.5.11" thiserror = "*" time = "0.2" tokio-rustls = "0.23" @@ -70,28 +69,12 @@ features = ["builder", "serde", "smtp-transport", "tokio1-rustls-tls"] default-features = false version = "0.10.0-rc.3" -[dependencies.sqlx] -version = "0.5.11" -features = [ - "any", - "chrono", - "macros", - "mysql", - "postgres", - "runtime-actix-rustls", - "sqlite", -] - [dependencies.lldap_auth] path = "../auth" [dependencies.sea-query] -version = "^0.25" -features = ["with-chrono", "sqlx-sqlite"] - -[dependencies.sea-query-binder] -version = "0.1" -features = ["with-chrono", "sqlx-sqlite", "sqlx-any"] +version = "*" +features = ["with-chrono"] [dependencies.opaque-ke] version = "0.6" @@ -125,6 +108,11 @@ features = ["jpeg"] default-features = false version = "0.24" +[dependencies.sea-orm] +version= "0.10.3" +default-features = false +features = ["macros", "with-chrono", "with-uuid", "sqlx-all", "runtime-actix-rustls"] + [dependencies.reqwest] version = "0.11" default-features = false diff --git a/server/src/domain/error.rs b/server/src/domain/error.rs index 5c9e38d..d5ec981 100644 --- a/server/src/domain/error.rs +++ b/server/src/domain/error.rs @@ -6,7 +6,7 @@ pub enum DomainError { #[error("Authentication error: `{0}`")] AuthenticationError(String), #[error("Database error: `{0}`")] - DatabaseError(#[from] sqlx::Error), + DatabaseError(#[from] sea_orm::DbErr), #[error("Authentication protocol error for `{0}`")] AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError), #[error("Unknown crypto error: `{0}`")] @@ -15,6 +15,8 @@ pub enum DomainError { BinarySerializationError(#[from] bincode::Error), #[error("Invalid base64: `{0}`")] Base64DecodeError(#[from] base64::DecodeError), + #[error("Entity not found: `{0}`")] + EntityNotFound(String), #[error("Internal error: `{0}`")] InternalError(String), } diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index 8b1a622..9eb0d52 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -1,13 +1,12 @@ -use super::{error::*, sql_tables::UserColumn}; +use super::error::*; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::collections::HashSet; -#[derive( - PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::FromRow, sqlx::Type, -)] +pub use super::model::{GroupColumn, UserColumn}; + +#[derive(PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize)] #[serde(try_from = "&str")] -#[sqlx(transparent)] pub struct Uuid(String); impl Uuid { @@ -43,17 +42,26 @@ impl std::string::ToString for Uuid { } } +impl sea_orm::TryGetable for Uuid { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> std::result::Result { + Ok(Uuid(String::try_get(res, pre, col)?)) + } +} + #[cfg(test)] #[macro_export] macro_rules! uuid { ($s:literal) => { - $crate::domain::handler::Uuid::try_from($s).unwrap() + <$crate::domain::handler::Uuid as std::convert::TryFrom<_>>::try_from($s).unwrap() }; } -#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::Type)] +#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)] #[serde(from = "String")] -#[sqlx(transparent)] pub struct UserId(String); impl UserId { @@ -82,17 +90,22 @@ impl From for UserId { } } -#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize, sqlx::Type)] -#[sqlx(transparent)] +#[derive(PartialEq, Eq, Clone, Debug, Serialize, Deserialize)] pub struct JpegPhoto(#[serde(with = "serde_bytes")] Vec); -impl From for sea_query::Value { +impl JpegPhoto { + pub fn null() -> Self { + Self(vec![]) + } +} + +impl From for sea_orm::Value { fn from(photo: JpegPhoto) -> Self { photo.0.into() } } -impl From<&JpegPhoto> for sea_query::Value { +impl From<&JpegPhoto> for sea_orm::Value { fn from(photo: &JpegPhoto) -> Self { photo.0.as_slice().into() } @@ -101,6 +114,9 @@ impl From<&JpegPhoto> for sea_query::Value { impl TryFrom<&[u8]> for JpegPhoto { type Error = anyhow::Error; fn try_from(bytes: &[u8]) -> anyhow::Result { + if bytes.is_empty() { + return Ok(JpegPhoto::null()); + } // Confirm that it's a valid Jpeg, then store only the bytes. image::io::Reader::with_format(std::io::Cursor::new(bytes), image::ImageFormat::Jpeg) .decode()?; @@ -111,6 +127,9 @@ impl TryFrom<&[u8]> for JpegPhoto { impl TryFrom> for JpegPhoto { type Error = anyhow::Error; fn try_from(bytes: Vec) -> anyhow::Result { + if bytes.is_empty() { + return Ok(JpegPhoto::null()); + } // Confirm that it's a valid Jpeg, then store only the bytes. image::io::Reader::with_format( std::io::Cursor::new(bytes.as_slice()), @@ -160,14 +179,14 @@ impl JpegPhoto { } } -#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, sea_orm::FromQueryResult)] pub struct User { pub user_id: UserId, pub email: String, - pub display_name: String, - pub first_name: String, - pub last_name: String, - pub avatar: JpegPhoto, + pub display_name: Option, + pub first_name: Option, + pub last_name: Option, + pub avatar: Option, pub creation_date: chrono::DateTime, pub uuid: Uuid, } @@ -176,14 +195,14 @@ pub struct User { impl Default for User { fn default() -> Self { use chrono::TimeZone; - let epoch = chrono::Utc.timestamp(0, 0); + let epoch = chrono::Utc.timestamp_opt(0, 0).unwrap(); User { user_id: UserId::default(), email: String::new(), - display_name: String::new(), - first_name: String::new(), - last_name: String::new(), - avatar: JpegPhoto::default(), + display_name: None, + first_name: None, + last_name: None, + avatar: None, creation_date: epoch, uuid: Uuid::from_name_and_date("", &epoch), } @@ -263,11 +282,10 @@ pub trait LoginHandler: Clone + Send { async fn bind(&self, request: BindRequest) -> Result<()>; } -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] -#[sqlx(transparent)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct GroupId(pub i32); -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::FromRow)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sea_orm::FromQueryResult)] pub struct GroupDetails { pub group_id: GroupId, pub display_name: String, @@ -349,8 +367,8 @@ mod tests { fn test_uuid_time() { use chrono::prelude::*; let user_id = "bob"; - let date1 = Utc.ymd(2014, 7, 8).and_hms(9, 10, 11); - let date2 = Utc.ymd(2014, 7, 8).and_hms(9, 10, 12); + let date1 = Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11).unwrap(); + let date2 = Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 12).unwrap(); assert_ne!( Uuid::from_name_and_date(user_id, &date1), Uuid::from_name_and_date(user_id, &date2) diff --git a/server/src/domain/ldap/group.rs b/server/src/domain/ldap/group.rs index 922ffca..cd4759a 100644 --- a/server/src/domain/ldap/group.rs +++ b/server/src/domain/ldap/group.rs @@ -4,9 +4,8 @@ use ldap3_proto::{ use tracing::{debug, info, instrument, warn}; use crate::domain::{ - handler::{BackendHandler, Group, GroupRequestFilter, UserId, Uuid}, + handler::{BackendHandler, Group, GroupColumn, GroupRequestFilter, UserId, Uuid}, ldap::error::LdapError, - sql_tables::GroupColumn, }; use super::{ diff --git a/server/src/domain/ldap/user.rs b/server/src/domain/ldap/user.rs index 060d83b..35919e1 100644 --- a/server/src/domain/ldap/user.rs +++ b/server/src/domain/ldap/user.rs @@ -4,9 +4,8 @@ use ldap3_proto::{ use tracing::{debug, info, instrument, warn}; use crate::domain::{ - handler::{BackendHandler, GroupDetails, User, UserId, UserRequestFilter}, + handler::{BackendHandler, GroupDetails, User, UserColumn, UserId, UserRequestFilter}, ldap::{error::LdapError, utils::expand_attribute_wildcards}, - sql_tables::UserColumn, }; use super::{ @@ -34,9 +33,9 @@ fn get_user_attribute( "uid" => vec![user.user_id.to_string().into_bytes()], "entryuuid" => vec![user.uuid.to_string().into_bytes()], "mail" => vec![user.email.clone().into_bytes()], - "givenname" => vec![user.first_name.clone().into_bytes()], - "sn" => vec![user.last_name.clone().into_bytes()], - "jpegphoto" => vec![user.avatar.clone().into_bytes()], + "givenname" => vec![user.first_name.clone()?.into_bytes()], + "sn" => vec![user.last_name.clone()?.into_bytes()], + "jpegphoto" => vec![user.avatar.clone()?.into_bytes()], "memberof" => groups .into_iter() .flatten() @@ -48,7 +47,7 @@ fn get_user_attribute( .into_bytes() }) .collect(), - "cn" | "displayname" => vec![user.display_name.clone().into_bytes()], + "cn" | "displayname" => vec![user.display_name.clone()?.into_bytes()], "createtimestamp" | "modifytimestamp" => vec![user.creation_date.to_rfc3339().into_bytes()], "1.1" => return None, // We ignore the operational attribute wildcard. diff --git a/server/src/domain/ldap/utils.rs b/server/src/domain/ldap/utils.rs index 05fefb4..b1de711 100644 --- a/server/src/domain/ldap/utils.rs +++ b/server/src/domain/ldap/utils.rs @@ -2,10 +2,7 @@ use itertools::Itertools; use ldap3_proto::LdapResultCode; use tracing::{debug, instrument, warn}; -use crate::domain::{ - handler::UserId, - sql_tables::{GroupColumn, UserColumn}, -}; +use crate::domain::handler::{GroupColumn, UserColumn, UserId}; use super::error::{LdapError, LdapResult}; diff --git a/server/src/domain/mod.rs b/server/src/domain/mod.rs index b139c59..8047331 100644 --- a/server/src/domain/mod.rs +++ b/server/src/domain/mod.rs @@ -1,6 +1,7 @@ pub mod error; pub mod handler; pub mod ldap; +pub mod model; pub mod opaque_handler; pub mod sql_backend_handler; pub mod sql_group_backend_handler; diff --git a/server/src/domain/model/groups.rs b/server/src/domain/model/groups.rs new file mode 100644 index 0000000..eb5342e --- /dev/null +++ b/server/src/domain/model/groups.rs @@ -0,0 +1,53 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::handler::{GroupId, Uuid}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "groups")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub group_id: GroupId, + pub display_name: String, + pub creation_date: chrono::DateTime, + pub uuid: Uuid, +} + +impl From for crate::domain::handler::Group { + fn from(group: Model) -> Self { + Self { + id: group.group_id, + display_name: group.display_name, + creation_date: group.creation_date, + uuid: group.uuid, + users: vec![], + } + } +} + +impl From for crate::domain::handler::GroupDetails { + fn from(group: Model) -> Self { + Self { + group_id: group.group_id, + display_name: group.display_name, + creation_date: group.creation_date, + uuid: group.uuid, + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::memberships::Entity")] + Memberships, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Memberships.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/jwt_refresh_storage.rs b/server/src/domain/model/jwt_refresh_storage.rs new file mode 100644 index 0000000..16b35ef --- /dev/null +++ b/server/src/domain/model/jwt_refresh_storage.rs @@ -0,0 +1,35 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::handler::UserId; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "jwt_refresh_storage")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub refresh_token_hash: i64, + pub user_id: UserId, + pub expiry_date: chrono::DateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::users::Entity", + from = "Column::UserId", + to = "super::users::Column::UserId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Users, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Users.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/jwt_storage.rs b/server/src/domain/model/jwt_storage.rs new file mode 100644 index 0000000..0df0144 --- /dev/null +++ b/server/src/domain/model/jwt_storage.rs @@ -0,0 +1,36 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::handler::UserId; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "jwt_storage")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub jwt_hash: i64, + pub user_id: UserId, + pub expiry_date: chrono::DateTime, + pub blacklisted: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::users::Entity", + from = "Column::UserId", + to = "super::users::Column::UserId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Users, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Users.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/memberships.rs b/server/src/domain/model/memberships.rs new file mode 100644 index 0000000..aff6b3e --- /dev/null +++ b/server/src/domain/model/memberships.rs @@ -0,0 +1,73 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::handler::{GroupId, UserId}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "memberships")] +pub struct Model { + #[sea_orm(primary_key)] + pub user_id: UserId, + #[sea_orm(primary_key)] + pub group_id: GroupId, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::groups::Entity", + from = "Column::GroupId", + to = "super::groups::Column::GroupId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Groups, + #[sea_orm( + belongs_to = "super::users::Entity", + from = "Column::UserId", + to = "super::users::Column::UserId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Users, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Groups.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Users.def() + } +} + +#[derive(Debug)] +pub struct UserToGroup; +impl Linked for UserToGroup { + type FromEntity = super::User; + + type ToEntity = super::Group; + + fn link(&self) -> Vec { + vec![Relation::Users.def().rev(), Relation::Groups.def()] + } +} + +#[derive(Debug)] +pub struct GroupToUser; +impl Linked for GroupToUser { + type FromEntity = super::Group; + + type ToEntity = super::User; + + fn link(&self) -> Vec { + vec![Relation::Groups.def().rev(), Relation::Users.def()] + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/mod.rs b/server/src/domain/model/mod.rs new file mode 100644 index 0000000..36b1060 --- /dev/null +++ b/server/src/domain/model/mod.rs @@ -0,0 +1,12 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +pub mod prelude; + +pub mod groups; +pub mod jwt_refresh_storage; +pub mod jwt_storage; +pub mod memberships; +pub mod password_reset_tokens; +pub mod users; + +pub use prelude::*; diff --git a/server/src/domain/model/password_reset_tokens.rs b/server/src/domain/model/password_reset_tokens.rs new file mode 100644 index 0000000..03ee09b --- /dev/null +++ b/server/src/domain/model/password_reset_tokens.rs @@ -0,0 +1,35 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::handler::UserId; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "password_reset_tokens")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub token: String, + pub user_id: UserId, + pub expiry_date: chrono::DateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::users::Entity", + from = "Column::UserId", + to = "super::users::Column::UserId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Users, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Users.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/prelude.rs b/server/src/domain/model/prelude.rs new file mode 100644 index 0000000..a25ffe6 --- /dev/null +++ b/server/src/domain/model/prelude.rs @@ -0,0 +1,14 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +pub use super::groups::Column as GroupColumn; +pub use super::groups::Entity as Group; +pub use super::jwt_refresh_storage::Column as JwtRefreshStorageColumn; +pub use super::jwt_refresh_storage::Entity as JwtRefreshStorage; +pub use super::jwt_storage::Column as JwtStorageColumn; +pub use super::jwt_storage::Entity as JwtStorage; +pub use super::memberships::Column as MembershipColumn; +pub use super::memberships::Entity as Membership; +pub use super::password_reset_tokens::Column as PasswordResetTokensColumn; +pub use super::password_reset_tokens::Entity as PasswordResetTokens; +pub use super::users::Column as UserColumn; +pub use super::users::Entity as User; diff --git a/server/src/domain/model/users.rs b/server/src/domain/model/users.rs new file mode 100644 index 0000000..6421084 --- /dev/null +++ b/server/src/domain/model/users.rs @@ -0,0 +1,134 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::handler::{JpegPhoto, UserId, Uuid}; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +#[derive(Clone, Debug, PartialEq, DeriveModel, Eq, Serialize, Deserialize, DeriveActiveModel)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub user_id: UserId, + pub email: String, + pub display_name: Option, + pub first_name: Option, + pub last_name: Option, + pub avatar: Option, + pub creation_date: chrono::DateTime, + pub password_hash: Option>, + pub totp_secret: Option, + pub mfa_type: Option, + pub uuid: Uuid, +} + +impl EntityName for Entity { + fn table_name(&self) -> &str { + "users" + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn, PartialEq, Eq, Serialize, Deserialize)] +pub enum Column { + UserId, + Email, + DisplayName, + FirstName, + LastName, + Avatar, + CreationDate, + PasswordHash, + TotpSecret, + MfaType, + Uuid, +} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Column::UserId => ColumnType::String(Some(255)), + Column::Email => ColumnType::String(Some(255)), + Column::DisplayName => ColumnType::String(Some(255)), + Column::FirstName => ColumnType::String(Some(255)), + Column::LastName => ColumnType::String(Some(255)), + Column::Avatar => ColumnType::Binary, + Column::CreationDate => ColumnType::DateTime, + Column::PasswordHash => ColumnType::Binary, + Column::TotpSecret => ColumnType::String(Some(64)), + Column::MfaType => ColumnType::String(Some(64)), + Column::Uuid => ColumnType::String(Some(36)), + } + .def() + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::memberships::Entity")] + Memberships, + #[sea_orm(has_many = "super::jwt_refresh_storage::Entity")] + JwtRefreshStorage, + #[sea_orm(has_many = "super::jwt_storage::Entity")] + JwtStorage, + #[sea_orm(has_many = "super::password_reset_tokens::Entity")] + PasswordResetTokens, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + UserId, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = UserId; + + fn auto_increment() -> bool { + false + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Memberships.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::JwtRefreshStorage.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::JwtStorage.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::PasswordResetTokens.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} + +impl From for crate::domain::handler::User { + fn from(user: Model) -> Self { + Self { + user_id: user.user_id, + email: user.email, + display_name: user.display_name, + first_name: user.first_name, + last_name: user.last_name, + creation_date: user.creation_date, + uuid: user.uuid, + avatar: user.avatar, + } + } +} diff --git a/server/src/domain/sql_backend_handler.rs b/server/src/domain/sql_backend_handler.rs index 0b8478d..7640510 100644 --- a/server/src/domain/sql_backend_handler.rs +++ b/server/src/domain/sql_backend_handler.rs @@ -5,11 +5,11 @@ use async_trait::async_trait; #[derive(Clone)] pub struct SqlBackendHandler { pub(crate) config: Configuration, - pub(crate) sql_pool: Pool, + pub(crate) sql_pool: DbConnection, } impl SqlBackendHandler { - pub fn new(config: Configuration, sql_pool: Pool) -> Self { + pub fn new(config: Configuration, sql_pool: DbConnection) -> Self { SqlBackendHandler { config, sql_pool } } } @@ -23,16 +23,23 @@ pub mod tests { use crate::domain::sql_tables::init_table; use crate::infra::configuration::ConfigurationBuilder; use lldap_auth::{opaque, registration}; + use sea_orm::Database; pub fn get_default_config() -> Configuration { ConfigurationBuilder::for_tests() } - pub async fn get_in_memory_db() -> Pool { - PoolOptions::new().connect("sqlite::memory:").await.unwrap() + pub async fn get_in_memory_db() -> DbConnection { + crate::infra::logging::init_for_tests(); + let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned()); + sql_opt + .max_connections(1) + .sqlx_logging(true) + .sqlx_logging_level(log::LevelFilter::Debug); + Database::connect(sql_opt).await.unwrap() } - pub async fn get_initialized_db() -> Pool { + pub async fn get_initialized_db() -> DbConnection { let sql_pool = get_in_memory_db().await; init_table(&sql_pool).await.unwrap(); sql_pool diff --git a/server/src/domain/sql_group_backend_handler.rs b/server/src/domain/sql_group_backend_handler.rs index cec1b90..44e5863 100644 --- a/server/src/domain/sql_group_backend_handler.rs +++ b/server/src/domain/sql_group_backend_handler.rs @@ -1,21 +1,22 @@ +use crate::domain::handler::Uuid; + use super::{ - error::Result, + error::{DomainError, Result}, handler::{ Group, GroupBackendHandler, GroupDetails, GroupId, GroupRequestFilter, UpdateGroupRequest, - UserId, }, + model::{self, GroupColumn, MembershipColumn}, sql_backend_handler::SqlBackendHandler, - sql_tables::{DbQueryBuilder, Groups, Memberships}, }; use async_trait::async_trait; -use sea_query::{Cond, Expr, Iden, Order, Query, SimpleExpr}; -use sea_query_binder::SqlxBinder; -use sqlx::{query_as_with, query_with, FromRow, Row}; +use sea_orm::{ + ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect, + QueryTrait, +}; +use sea_query::{Cond, IntoCondition, SimpleExpr}; use tracing::{debug, instrument}; -// Returns the condition for the SQL query, and whether it requires joining with the groups table. fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { - use sea_query::IntoCondition; use GroupRequestFilter::*; match filter { And(fs) => { @@ -35,23 +36,17 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond { } } Not(f) => get_group_filter_expr(*f).not(), - DisplayName(name) => Expr::col((Groups::Table, Groups::DisplayName)) - .eq(name) - .into_condition(), - GroupId(id) => Expr::col((Groups::Table, Groups::GroupId)) - .eq(id.0) - .into_condition(), - Uuid(uuid) => Expr::col((Groups::Table, Groups::Uuid)) - .eq(uuid.to_string()) - .into_condition(), + DisplayName(name) => GroupColumn::DisplayName.eq(name).into_condition(), + GroupId(id) => GroupColumn::GroupId.eq(id.0).into_condition(), + Uuid(uuid) => GroupColumn::Uuid.eq(uuid.to_string()).into_condition(), // WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user)) - Member(user) => Expr::col((Memberships::Table, Memberships::GroupId)) + Member(user) => GroupColumn::GroupId .in_subquery( - Query::select() - .column(Memberships::GroupId) - .from(Memberships::Table) - .cond_where(Expr::col(Memberships::UserId).eq(user)) - .take(), + model::Membership::find() + .select_only() + .column(MembershipColumn::GroupId) + .filter(MembershipColumn::UserId.eq(user)) + .into_query(), ) .into_condition(), } @@ -62,94 +57,67 @@ impl GroupBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret, err)] async fn list_groups(&self, filters: Option) -> Result> { debug!(?filters); - let (query, values) = { - let mut query_builder = Query::select() - .column((Groups::Table, Groups::GroupId)) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .column(Groups::Uuid) - .column(Memberships::UserId) - .from(Groups::Table) - .left_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .order_by(Groups::DisplayName, Order::Asc) - .order_by(Memberships::UserId, Order::Asc) - .to_owned(); - - if let Some(filter) = filters { - query_builder.cond_where(get_group_filter_expr(filter)); - } - - query_builder.build_sqlx(DbQueryBuilder {}) - }; - debug!(%query); - - // For group_by. - use itertools::Itertools; - let mut groups = Vec::new(); - // The rows are returned sorted by display_name, equivalent to group_id. We group them by - // this key which gives us one element (`rows`) per group. - for (group_details, rows) in &query_with(&query, values) - .fetch_all(&self.sql_pool) - .await? + let results = model::Group::find() + // The order_by must be before find_with_related otherwise the primary order is by group_id. + .order_by_asc(GroupColumn::DisplayName) + .find_with_related(model::Membership) + .filter( + filters + .map(|f| { + GroupColumn::GroupId + .in_subquery( + model::Group::find() + .find_also_linked(model::memberships::GroupToUser) + .select_only() + .column(GroupColumn::GroupId) + .filter(get_group_filter_expr(f)) + .into_query(), + ) + .into_condition() + }) + .unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()), + ) + .all(&self.sql_pool) + .await?; + Ok(results .into_iter() - .group_by(|row| GroupDetails::from_row(row).unwrap()) - { - groups.push(Group { - id: group_details.group_id, - display_name: group_details.display_name, - creation_date: group_details.creation_date, - uuid: group_details.uuid, - users: rows - .map(|row| row.get::(&*Memberships::UserId.to_string())) - // If a group has no users, an empty string is returned because of the left - // join. - .filter(|s| !s.as_str().is_empty()) - .collect(), - }); - } - Ok(groups) + .map(|(group, users)| { + let users: Vec<_> = users.into_iter().map(|u| u.user_id).collect(); + Group { + users, + ..group.into() + } + }) + .collect()) } #[instrument(skip_all, level = "debug", ret, err)] async fn get_group_details(&self, group_id: GroupId) -> Result { debug!(?group_id); - let (query, values) = Query::select() - .column(Groups::GroupId) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .column(Groups::Uuid) - .from(Groups::Table) - .cond_where(Expr::col(Groups::GroupId).eq(group_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - - Ok(query_as_with::<_, GroupDetails, _>(&query, values) - .fetch_one(&self.sql_pool) - .await?) + model::Group::find_by_id(group_id) + .into_model::() + .one(&self.sql_pool) + .await? + .ok_or_else(|| DomainError::EntityNotFound(format!("{:?}", group_id))) } #[instrument(skip_all, level = "debug", err)] async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> { debug!(?request.group_id); - let mut values = Vec::new(); - if let Some(display_name) = request.display_name { - values.push((Groups::DisplayName, display_name.into())); - } - if values.is_empty() { - return Ok(()); - } - let (query, values) = Query::update() - .table(Groups::Table) - .values(values) - .cond_where(Expr::col(Groups::GroupId).eq(request.group_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) + let update_group = model::groups::ActiveModel { + display_name: request + .display_name + .map(ActiveValue::Set) + .unwrap_or_default(), + ..Default::default() + }; + model::Group::update_many() + .set(update_group) + .filter(sea_orm::ColumnTrait::eq( + &GroupColumn::GroupId, + request.group_id, + )) + .exec(&self.sql_pool) .await?; Ok(()) } @@ -157,30 +125,29 @@ impl GroupBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret, err)] async fn create_group(&self, group_name: &str) -> Result { debug!(?group_name); - crate::domain::sql_tables::create_group(group_name, &self.sql_pool).await?; - let (query, values) = Query::select() - .column(Groups::GroupId) - .from(Groups::Table) - .cond_where(Expr::col(Groups::DisplayName).eq(group_name)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - let row = query_with(query.as_str(), values) - .fetch_one(&self.sql_pool) - .await?; - Ok(GroupId(row.get::(&*Groups::GroupId.to_string()))) + let now = chrono::Utc::now(); + let uuid = Uuid::from_name_and_date(group_name, &now); + let new_group = model::groups::ActiveModel { + display_name: ActiveValue::Set(group_name.to_owned()), + creation_date: ActiveValue::Set(now), + uuid: ActiveValue::Set(uuid), + ..Default::default() + }; + Ok(new_group.insert(&self.sql_pool).await?.group_id) } #[instrument(skip_all, level = "debug", err)] async fn delete_group(&self, group_id: GroupId) -> Result<()> { debug!(?group_id); - let (query, values) = Query::delete() - .from_table(Groups::Table) - .cond_where(Expr::col(Groups::GroupId).eq(group_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) + let res = model::Group::delete_by_id(group_id) + .exec(&self.sql_pool) .await?; + if res.rows_affected == 0 { + return Err(DomainError::EntityNotFound(format!( + "No such group: '{:?}'", + group_id + ))); + } Ok(()) } } @@ -188,7 +155,7 @@ impl GroupBackendHandler for SqlBackendHandler { #[cfg(test)] mod tests { use super::*; - use crate::domain::sql_backend_handler::tests::*; + use crate::domain::{handler::UserId, sql_backend_handler::tests::*}; async fn get_group_ids( handler: &SqlBackendHandler, @@ -203,12 +170,29 @@ mod tests { .collect::>() } + async fn get_group_names( + handler: &SqlBackendHandler, + filters: Option, + ) -> Vec { + handler + .list_groups(filters) + .await + .unwrap() + .into_iter() + .map(|g| g.display_name) + .collect::>() + } + #[tokio::test] async fn test_list_groups_no_filter() { let fixture = TestFixture::new().await; assert_eq!( - get_group_ids(&fixture.handler, None).await, - vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]] + get_group_names(&fixture.handler, None).await, + vec![ + "Best Group".to_owned(), + "Empty Group".to_owned(), + "Worst Group".to_owned() + ] ); } @@ -216,15 +200,15 @@ mod tests { async fn test_list_groups_simple_filter() { let fixture = TestFixture::new().await; assert_eq!( - get_group_ids( + get_group_names( &fixture.handler, Some(GroupRequestFilter::Or(vec![ - GroupRequestFilter::DisplayName("Empty Group".to_string()), + GroupRequestFilter::DisplayName("Empty Group".to_owned()), GroupRequestFilter::Member(UserId::new("bob")), ])) ) .await, - vec![fixture.groups[0], fixture.groups[2]] + vec!["Best Group".to_owned(), "Empty Group".to_owned()] ); } @@ -236,7 +220,7 @@ mod tests { &fixture.handler, Some(GroupRequestFilter::And(vec![ GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName( - "value".to_string() + "value".to_owned() ))), GroupRequestFilter::GroupId(fixture.groups[0]), ])) @@ -273,7 +257,7 @@ mod tests { .handler .update_group(UpdateGroupRequest { group_id: fixture.groups[0], - display_name: Some("Awesomest Group".to_string()), + display_name: Some("Awesomest Group".to_owned()), }) .await .unwrap(); @@ -288,6 +272,10 @@ mod tests { #[tokio::test] async fn test_delete_group() { let fixture = TestFixture::new().await; + assert_eq!( + get_group_ids(&fixture.handler, None).await, + vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]] + ); fixture .handler .delete_group(fixture.groups[0]) diff --git a/server/src/domain/sql_migrations.rs b/server/src/domain/sql_migrations.rs index 85d2b82..7d252b8 100644 --- a/server/src/domain/sql_migrations.rs +++ b/server/src/domain/sql_migrations.rs @@ -1,302 +1,338 @@ use super::{ handler::{GroupId, UserId, Uuid}, - sql_tables::{ - DbQueryBuilder, DbRow, Groups, Memberships, Metadata, Pool, SchemaVersion, Users, - }, + sql_tables::{DbConnection, SchemaVersion}, }; -use sea_query::*; -use sea_query_binder::SqlxBinder; -use sqlx::Row; -use tracing::{debug, warn}; +use sea_orm::{ConnectionTrait, FromQueryResult, Statement}; +use sea_query::{ColumnDef, Expr, ForeignKey, ForeignKeyAction, Iden, Query, Table, Value}; +use serde::{Deserialize, Serialize}; +use tracing::{instrument, warn}; -pub async fn create_group(group_name: &str, pool: &Pool) -> sqlx::Result<()> { - let now = chrono::Utc::now(); - let (query, values) = Query::insert() - .into_table(Groups::Table) - .columns(vec![ - Groups::DisplayName, - Groups::CreationDate, - Groups::Uuid, - ]) - .values_panic(vec![ - group_name.into(), - now.naive_utc().into(), - Uuid::from_name_and_date(group_name, &now).into(), - ]) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - sqlx::query_with(query.as_str(), values) - .execute(pool) - .await - .map(|_| ()) +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] +pub enum Users { + Table, + UserId, + Email, + DisplayName, + FirstName, + LastName, + Avatar, + CreationDate, + PasswordHash, + TotpSecret, + MfaType, + Uuid, } -pub async fn get_schema_version(pool: &Pool) -> Option { - sqlx::query( - &Query::select() - .from(Metadata::Table) - .column(Metadata::Version) - .to_string(DbQueryBuilder {}), +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] +pub enum Groups { + Table, + GroupId, + DisplayName, + CreationDate, + Uuid, +} + +#[derive(Iden)] +pub enum Memberships { + Table, + UserId, + GroupId, +} + +// Metadata about the SQL DB. +#[derive(Iden)] +pub enum Metadata { + Table, + // Which version of the schema we're at. + Version, +} + +#[derive(FromQueryResult, PartialEq, Eq, Debug)] +pub struct JustSchemaVersion { + pub version: SchemaVersion, +} + +#[instrument(skip_all, level = "debug", ret)] +pub async fn get_schema_version(pool: &DbConnection) -> Option { + JustSchemaVersion::find_by_statement( + pool.get_database_backend().build( + Query::select() + .from(Metadata::Table) + .column(Metadata::Version), + ), ) - .map(|row: DbRow| row.get::(&*Metadata::Version.to_string())) - .fetch_one(pool) + .one(pool) .await .ok() + .flatten() + .map(|j| j.version) } -pub async fn upgrade_to_v1(pool: &Pool) -> sqlx::Result<()> { +pub async fn upgrade_to_v1(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> { + let builder = pool.get_database_backend(); // SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the // error. - let _ = sqlx::query("PRAGMA foreign_keys = ON").execute(pool).await; - sqlx::query( - &Table::create() - .table(Users::Table) - .if_not_exists() - .col( - ColumnDef::new(Users::UserId) - .string_len(255) - .not_null() - .primary_key(), - ) - .col(ColumnDef::new(Users::Email).string_len(255).not_null()) - .col( - ColumnDef::new(Users::DisplayName) - .string_len(255) - .not_null(), - ) - .col(ColumnDef::new(Users::FirstName).string_len(255).not_null()) - .col(ColumnDef::new(Users::LastName).string_len(255).not_null()) - .col(ColumnDef::new(Users::Avatar).binary()) - .col(ColumnDef::new(Users::CreationDate).date_time().not_null()) - .col(ColumnDef::new(Users::PasswordHash).binary()) - .col(ColumnDef::new(Users::TotpSecret).string_len(64)) - .col(ColumnDef::new(Users::MfaType).string_len(64)) - .col(ColumnDef::new(Users::Uuid).string_len(36).not_null()) - .to_string(DbQueryBuilder {}), + let _ = pool + .execute(Statement::from_string( + builder, + "PRAGMA foreign_keys = ON".to_owned(), + )) + .await; + + pool.execute( + builder.build( + Table::create() + .table(Users::Table) + .if_not_exists() + .col( + ColumnDef::new(Users::UserId) + .string_len(255) + .not_null() + .primary_key(), + ) + .col(ColumnDef::new(Users::Email).string_len(255).not_null()) + .col( + ColumnDef::new(Users::DisplayName) + .string_len(255) + .not_null(), + ) + .col(ColumnDef::new(Users::FirstName).string_len(255)) + .col(ColumnDef::new(Users::LastName).string_len(255)) + .col(ColumnDef::new(Users::Avatar).binary()) + .col(ColumnDef::new(Users::CreationDate).date_time().not_null()) + .col(ColumnDef::new(Users::PasswordHash).binary()) + .col(ColumnDef::new(Users::TotpSecret).string_len(64)) + .col(ColumnDef::new(Users::MfaType).string_len(64)) + .col(ColumnDef::new(Users::Uuid).string_len(36).not_null()), + ), ) - .execute(pool) .await?; - sqlx::query( - &Table::create() - .table(Groups::Table) - .if_not_exists() - .col( - ColumnDef::new(Groups::GroupId) - .integer() - .not_null() - .primary_key(), - ) - .col( - ColumnDef::new(Groups::DisplayName) - .string_len(255) - .unique_key() - .not_null(), - ) - .col(ColumnDef::new(Users::CreationDate).date_time().not_null()) - .col(ColumnDef::new(Users::Uuid).string_len(36).not_null()) - .to_string(DbQueryBuilder {}), + pool.execute( + builder.build( + Table::create() + .table(Groups::Table) + .if_not_exists() + .col( + ColumnDef::new(Groups::GroupId) + .integer() + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(Groups::DisplayName) + .string_len(255) + .unique_key() + .not_null(), + ) + .col(ColumnDef::new(Users::CreationDate).date_time().not_null()) + .col(ColumnDef::new(Users::Uuid).string_len(36).not_null()), + ), ) - .execute(pool) .await?; // If the creation_date column doesn't exist, add it. - if sqlx::query( - &Table::alter() - .table(Groups::Table) - .add_column( - ColumnDef::new(Groups::CreationDate) - .date_time() - .not_null() - .default(chrono::Utc::now().naive_utc()), - ) - .to_string(DbQueryBuilder {}), - ) - .execute(pool) - .await - .is_ok() + if pool + .execute( + builder.build( + Table::alter().table(Groups::Table).add_column( + ColumnDef::new(Groups::CreationDate) + .date_time() + .not_null() + .default(chrono::Utc::now().naive_utc()), + ), + ), + ) + .await + .is_ok() { warn!("`creation_date` column not found in `groups`, creating it"); } // If the uuid column doesn't exist, add it. - if sqlx::query( - &Table::alter() - .table(Groups::Table) - .add_column( - ColumnDef::new(Groups::Uuid) - .string_len(36) - .not_null() - .default(""), - ) - .to_string(DbQueryBuilder {}), - ) - .execute(pool) - .await - .is_ok() + if pool + .execute( + builder.build( + Table::alter().table(Groups::Table).add_column( + ColumnDef::new(Groups::Uuid) + .string_len(36) + .not_null() + .default(""), + ), + ), + ) + .await + .is_ok() { warn!("`uuid` column not found in `groups`, creating it"); - for row in sqlx::query( - &Query::select() - .from(Groups::Table) - .column(Groups::GroupId) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .to_string(DbQueryBuilder {}), + #[derive(FromQueryResult)] + struct ShortGroupDetails { + group_id: GroupId, + display_name: String, + creation_date: chrono::DateTime, + } + for result in ShortGroupDetails::find_by_statement( + builder.build( + Query::select() + .from(Groups::Table) + .column(Groups::GroupId) + .column(Groups::DisplayName) + .column(Groups::CreationDate), + ), ) - .fetch_all(pool) + .all(pool) .await? { - sqlx::query( - &Query::update() - .table(Groups::Table) - .value( - Groups::Uuid, - Uuid::from_name_and_date( - &row.get::(&*Groups::DisplayName.to_string()), - &row.get::, _>( - &*Groups::CreationDate.to_string(), - ), + pool.execute( + builder.build( + Query::update() + .table(Groups::Table) + .value( + Groups::Uuid, + Value::from(Uuid::from_name_and_date( + &result.display_name, + &result.creation_date, + )), ) - .into(), - ) - .and_where( - Expr::col(Groups::GroupId) - .eq(row.get::(&*Groups::GroupId.to_string())), - ) - .to_string(DbQueryBuilder {}), + .and_where(Expr::col(Groups::GroupId).eq(result.group_id)), + ), ) - .execute(pool) .await?; } } - if sqlx::query( - &Table::alter() - .table(Users::Table) - .add_column( - ColumnDef::new(Users::Uuid) - .string_len(36) - .not_null() - .default(""), - ) - .to_string(DbQueryBuilder {}), - ) - .execute(pool) - .await - .is_ok() + if pool + .execute( + builder.build( + Table::alter().table(Users::Table).add_column( + ColumnDef::new(Users::Uuid) + .string_len(36) + .not_null() + .default(""), + ), + ), + ) + .await + .is_ok() { warn!("`uuid` column not found in `users`, creating it"); - for row in sqlx::query( - &Query::select() - .from(Users::Table) - .column(Users::UserId) - .column(Users::CreationDate) - .to_string(DbQueryBuilder {}), + #[derive(FromQueryResult)] + struct ShortUserDetails { + user_id: UserId, + creation_date: chrono::DateTime, + } + for result in ShortUserDetails::find_by_statement( + builder.build( + Query::select() + .from(Users::Table) + .column(Users::UserId) + .column(Users::CreationDate), + ), ) - .fetch_all(pool) + .all(pool) .await? { - let user_id = row.get::(&*Users::UserId.to_string()); - sqlx::query( - &Query::update() - .table(Users::Table) - .value( - Users::Uuid, - Uuid::from_name_and_date( - user_id.as_str(), - &row.get::, _>( - &*Users::CreationDate.to_string(), - ), + pool.execute( + builder.build( + Query::update() + .table(Users::Table) + .value( + Users::Uuid, + Value::from(Uuid::from_name_and_date( + result.user_id.as_str(), + &result.creation_date, + )), ) - .into(), - ) - .and_where(Expr::col(Users::UserId).eq(user_id)) - .to_string(DbQueryBuilder {}), + .and_where(Expr::col(Users::UserId).eq(result.user_id)), + ), ) - .execute(pool) .await?; } } - sqlx::query( - &Table::create() - .table(Memberships::Table) - .if_not_exists() - .col( - ColumnDef::new(Memberships::UserId) - .string_len(255) - .not_null(), - ) - .col(ColumnDef::new(Memberships::GroupId).integer().not_null()) - .foreign_key( - ForeignKey::create() - .name("MembershipUserForeignKey") - .from(Memberships::Table, Memberships::UserId) - .to(Users::Table, Users::UserId) - .on_delete(ForeignKeyAction::Cascade) - .on_update(ForeignKeyAction::Cascade), - ) - .foreign_key( - ForeignKey::create() - .name("MembershipGroupForeignKey") - .from(Memberships::Table, Memberships::GroupId) - .to(Groups::Table, Groups::GroupId) - .on_delete(ForeignKeyAction::Cascade) - .on_update(ForeignKeyAction::Cascade), - ) - .to_string(DbQueryBuilder {}), + pool.execute( + builder.build( + Table::create() + .table(Memberships::Table) + .if_not_exists() + .col( + ColumnDef::new(Memberships::UserId) + .string_len(255) + .not_null(), + ) + .col(ColumnDef::new(Memberships::GroupId).integer().not_null()) + .foreign_key( + ForeignKey::create() + .name("MembershipUserForeignKey") + .from(Memberships::Table, Memberships::UserId) + .to(Users::Table, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .foreign_key( + ForeignKey::create() + .name("MembershipGroupForeignKey") + .from(Memberships::Table, Memberships::GroupId) + .to(Groups::Table, Groups::GroupId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ), + ), ) - .execute(pool) .await?; - if sqlx::query( - &Query::select() - .from(Groups::Table) - .column(Groups::DisplayName) - .cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")) - .to_string(DbQueryBuilder {}), - ) - .fetch_one(pool) - .await - .is_ok() - { - sqlx::query( - &Query::update() - .table(Groups::Table) - .values(vec![(Groups::DisplayName, "lldap_password_manager".into())]) - .cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")) - .to_string(DbQueryBuilder {}), + if pool + .query_one( + builder.build( + Query::select() + .from(Groups::Table) + .column(Groups::DisplayName) + .cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")), + ), + ) + .await + .is_ok() + { + pool.execute( + builder.build( + Query::update() + .table(Groups::Table) + .values(vec![(Groups::DisplayName, "lldap_password_manager".into())]) + .cond_where(Expr::col(Groups::DisplayName).eq("lldap_readonly")), + ), ) - .execute(pool) .await?; - create_group("lldap_strict_readonly", pool).await? } - sqlx::query( - &Table::create() - .table(Metadata::Table) - .if_not_exists() - .col(ColumnDef::new(Metadata::Version).tiny_integer().not_null()) - .to_string(DbQueryBuilder {}), + pool.execute( + builder.build( + Table::create() + .table(Metadata::Table) + .if_not_exists() + .col(ColumnDef::new(Metadata::Version).tiny_integer()), + ), ) - .execute(pool) .await?; - sqlx::query( - &Query::insert() - .into_table(Metadata::Table) - .columns(vec![Metadata::Version]) - .values_panic(vec![SchemaVersion(1).into()]) - .to_string(DbQueryBuilder {}), + pool.execute( + builder.build( + Query::insert() + .into_table(Metadata::Table) + .columns(vec![Metadata::Version]) + .values_panic(vec![SchemaVersion(1).into()]), + ), ) - .execute(pool) .await?; + assert_eq!(get_schema_version(pool).await.unwrap().0, 1); + Ok(()) } -pub async fn migrate_from_version(_pool: &Pool, version: SchemaVersion) -> anyhow::Result<()> { +pub async fn migrate_from_version( + _pool: &DbConnection, + version: SchemaVersion, +) -> anyhow::Result<()> { if version.0 > 1 { anyhow::bail!("DB version downgrading is not supported"); } diff --git a/server/src/domain/sql_opaque_handler.rs b/server/src/domain/sql_opaque_handler.rs index 3df4aed..dda433d 100644 --- a/server/src/domain/sql_opaque_handler.rs +++ b/server/src/domain/sql_opaque_handler.rs @@ -1,16 +1,14 @@ use super::{ error::{DomainError, Result}, handler::{BindRequest, LoginHandler, UserId}, + model::{self, UserColumn}, opaque_handler::{login, registration, OpaqueHandler}, sql_backend_handler::SqlBackendHandler, - sql_tables::{DbQueryBuilder, Users}, }; use async_trait::async_trait; use lldap_auth::opaque; -use sea_query::{Expr, Iden, Query}; -use sea_query_binder::SqlxBinder; +use sea_orm::{ActiveValue, EntityTrait, FromQueryResult, QuerySelect}; use secstr::SecUtf8; -use sqlx::Row; use tracing::{debug, instrument}; type SqlOpaqueHandler = SqlBackendHandler; @@ -50,39 +48,19 @@ impl SqlBackendHandler { } #[instrument(skip_all, level = "debug", err)] - async fn get_password_file_for_user( - &self, - username: &str, - ) -> Result> { + async fn get_password_file_for_user(&self, user_id: UserId) -> Result>> { + #[derive(FromQueryResult)] + struct OnlyPasswordHash { + password_hash: Option>, + } // Fetch the previously registered password file from the DB. - let password_file_bytes = { - let (query, values) = Query::select() - .column(Users::PasswordHash) - .from(Users::Table) - .cond_where(Expr::col(Users::UserId).eq(username)) - .build_sqlx(DbQueryBuilder {}); - if let Some(row) = sqlx::query_with(query.as_str(), values) - .fetch_optional(&self.sql_pool) - .await? - { - if let Some(bytes) = - row.get::>, _>(&*Users::PasswordHash.to_string()) - { - bytes - } else { - // No password set. - return Ok(None); - } - } else { - // No such user. - return Ok(None); - } - }; - opaque::server::ServerRegistration::deserialize(&password_file_bytes) - .map(Option::Some) - .map_err(|_| { - DomainError::InternalError(format!("Corrupted password file for {}", username)) - }) + Ok(model::User::find_by_id(user_id) + .select_only() + .column(UserColumn::PasswordHash) + .into_model::() + .one(&self.sql_pool) + .await? + .and_then(|u| u.password_hash)) } } @@ -90,33 +68,25 @@ impl SqlBackendHandler { impl LoginHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", err)] async fn bind(&self, request: BindRequest) -> Result<()> { - let (query, values) = Query::select() - .column(Users::PasswordHash) - .from(Users::Table) - .cond_where(Expr::col(Users::UserId).eq(&request.name)) - .build_sqlx(DbQueryBuilder {}); - if let Ok(row) = sqlx::query_with(&query, values) - .fetch_one(&self.sql_pool) - .await + if let Some(password_hash) = self + .get_password_file_for_user(request.name.clone()) + .await? { - if let Some(password_hash) = - row.get::>, _>(&*Users::PasswordHash.to_string()) - { - if let Err(e) = passwords_match( - &password_hash, - &request.password, - self.config.get_server_setup(), - &request.name, - ) { - debug!(r#"Invalid password for "{}": {}"#, &request.name, e); - } else { - return Ok(()); - } + if let Err(e) = passwords_match( + &password_hash, + &request.password, + self.config.get_server_setup(), + &request.name, + ) { + debug!(r#"Invalid password for "{}": {}"#, &request.name, e); } else { - debug!(r#"User "{}" has no password"#, &request.name); + return Ok(()); } } else { - debug!(r#"No user found for "{}""#, &request.name); + debug!( + r#"User "{}" doesn't exist or has no password"#, + &request.name + ); } Err(DomainError::AuthenticationError(format!( " for user '{}'", @@ -132,7 +102,18 @@ impl OpaqueHandler for SqlOpaqueHandler { &self, request: login::ClientLoginStartRequest, ) -> Result { - let maybe_password_file = self.get_password_file_for_user(&request.username).await?; + let maybe_password_file = self + .get_password_file_for_user(UserId::new(&request.username)) + .await? + .map(|bytes| { + opaque::server::ServerRegistration::deserialize(&bytes).map_err(|_| { + DomainError::InternalError(format!( + "Corrupted password file for {}", + &request.username + )) + }) + }) + .transpose()?; let mut rng = rand::rngs::OsRng; // Get the CredentialResponse for the user, or a dummy one if no user/no password. @@ -210,17 +191,16 @@ impl OpaqueHandler for SqlOpaqueHandler { let password_file = opaque::server::registration::get_password_file(request.registration_upload); - { - // Set the user password to the new password. - let (update_query, values) = Query::update() - .table(Users::Table) - .value(Users::PasswordHash, password_file.serialize().into()) - .cond_where(Expr::col(Users::UserId).eq(username)) - .build_sqlx(DbQueryBuilder {}); - sqlx::query_with(update_query.as_str(), values) - .execute(&self.sql_pool) - .await?; - } + // Set the user password to the new password. + let user_update = model::users::ActiveModel { + user_id: ActiveValue::Set(UserId::new(&username)), + password_hash: ActiveValue::Set(Some(password_file.serialize())), + ..Default::default() + }; + model::User::update_many() + .set(user_update) + .exec(&self.sql_pool) + .await?; Ok(()) } } diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs index b409a1b..1e09ae3 100644 --- a/server/src/domain/sql_tables.rs +++ b/server/src/domain/sql_tables.rs @@ -1,39 +1,116 @@ use super::{ - handler::{GroupId, UserId, Uuid}, + handler::{GroupId, JpegPhoto, UserId, Uuid}, sql_migrations::{get_schema_version, migrate_from_version, upgrade_to_v1}, }; -use sea_query::*; -use serde::{Deserialize, Serialize}; +use sea_orm::{DbErr, Value}; -pub use super::sql_migrations::create_group; +pub type DbConnection = sea_orm::DatabaseConnection; -pub type Pool = sqlx::sqlite::SqlitePool; -pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions; -pub type DbRow = sqlx::sqlite::SqliteRow; -pub type DbQueryBuilder = SqliteQueryBuilder; - -#[derive(Copy, PartialEq, Eq, Debug, Clone, sqlx::FromRow, sqlx::Type)] -#[sqlx(transparent)] +#[derive(Copy, PartialEq, Eq, Debug, Clone)] pub struct SchemaVersion(pub u8); -impl From for Value { +impl sea_orm::TryGetable for SchemaVersion { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(SchemaVersion(u8::try_get(res, pre, col)?)) + } +} + +impl From for sea_orm::Value { fn from(group_id: GroupId) -> Self { group_id.0.into() } } -impl From for sea_query::Value { +impl sea_orm::TryGetable for GroupId { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(GroupId(i32::try_get(res, pre, col)?)) + } +} + +impl sea_orm::sea_query::value::ValueType for GroupId { + fn try_from(v: sea_orm::Value) -> Result { + Ok(GroupId(::try_from( + v, + )?)) + } + + fn type_name() -> String { + "GroupId".to_owned() + } + + fn array_type() -> sea_orm::sea_query::ArrayType { + sea_orm::sea_query::ArrayType::Int + } + + fn column_type() -> sea_orm::sea_query::ColumnType { + sea_orm::sea_query::ColumnType::Integer(None) + } +} + +impl sea_orm::TryFromU64 for GroupId { + fn try_from_u64(n: u64) -> Result { + Ok(GroupId(i32::try_from_u64(n)?)) + } +} + +impl From for sea_orm::Value { fn from(user_id: UserId) -> Self { user_id.into_string().into() } } -impl From<&UserId> for sea_query::Value { +impl From<&UserId> for sea_orm::Value { fn from(user_id: &UserId) -> Self { user_id.as_str().into() } } +impl sea_orm::TryGetable for UserId { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(UserId::new(&String::try_get(res, pre, col)?)) + } +} + +impl sea_orm::TryFromU64 for UserId { + fn try_from_u64(_n: u64) -> Result { + Err(sea_orm::DbErr::ConvertFromU64( + "UserId cannot be constructed from u64", + )) + } +} + +impl sea_orm::sea_query::value::ValueType for UserId { + fn try_from(v: sea_orm::Value) -> Result { + Ok(UserId::new( + ::try_from(v)?.as_str(), + )) + } + + fn type_name() -> String { + "UserId".to_owned() + } + + fn array_type() -> sea_orm::sea_query::ArrayType { + sea_orm::sea_query::ArrayType::String + } + + fn column_type() -> sea_orm::sea_query::ColumnType { + sea_orm::sea_query::ColumnType::String(Some(255)) + } +} + impl From for sea_query::Value { fn from(uuid: Uuid) -> Self { uuid.as_str().into() @@ -46,57 +123,84 @@ impl From<&Uuid> for sea_query::Value { } } +impl sea_orm::TryGetable for JpegPhoto { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + >>::try_from(Vec::::try_get(res, pre, col)?) + .map_err(|e| { + sea_orm::TryGetError::DbErr(DbErr::TryIntoErr { + from: "[u8]", + into: "JpegPhoto", + source: e.into(), + }) + }) + } +} + +impl sea_orm::sea_query::value::ValueType for JpegPhoto { + fn try_from(v: sea_orm::Value) -> Result { + >::try_from( + as sea_orm::sea_query::ValueType>::try_from(v)?.as_slice(), + ) + .map_err(|_| sea_orm::sea_query::ValueTypeErr {}) + } + + fn type_name() -> String { + "JpegPhoto".to_owned() + } + + fn array_type() -> sea_orm::sea_query::ArrayType { + sea_orm::sea_query::ArrayType::Bytes + } + + fn column_type() -> sea_orm::sea_query::ColumnType { + sea_orm::sea_query::ColumnType::Binary(sea_orm::sea_query::BlobSize::Long) + } +} + +impl sea_orm::sea_query::Nullable for JpegPhoto { + fn null() -> sea_orm::Value { + JpegPhoto::null().into() + } +} + +impl sea_orm::entity::IntoActiveValue for JpegPhoto { + fn into_active_value(self) -> sea_orm::ActiveValue { + sea_orm::ActiveValue::Set(self) + } +} + +impl sea_orm::sea_query::value::ValueType for Uuid { + fn try_from(v: sea_orm::Value) -> Result { + >::try_from( + ::try_from(v)?.as_str(), + ) + .map_err(|_| sea_orm::sea_query::ValueTypeErr {}) + } + + fn type_name() -> String { + "Uuid".to_owned() + } + + fn array_type() -> sea_orm::sea_query::ArrayType { + sea_orm::sea_query::ArrayType::String + } + + fn column_type() -> sea_orm::sea_query::ColumnType { + sea_orm::sea_query::ColumnType::String(Some(36)) + } +} + impl From for Value { fn from(version: SchemaVersion) -> Self { version.0.into() } } -#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] -pub enum Users { - Table, - UserId, - Email, - DisplayName, - FirstName, - LastName, - Avatar, - CreationDate, - PasswordHash, - TotpSecret, - MfaType, - Uuid, -} - -pub type UserColumn = Users; - -#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] -pub enum Groups { - Table, - GroupId, - DisplayName, - CreationDate, - Uuid, -} - -pub type GroupColumn = Groups; - -#[derive(Iden)] -pub enum Memberships { - Table, - UserId, - GroupId, -} - -// Metadata about the SQL DB. -#[derive(Iden)] -pub enum Metadata { - Table, - // Which version of the schema we're at. - Version, -} - -pub async fn init_table(pool: &Pool) -> anyhow::Result<()> { +pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> { let version = { if let Some(version) = get_schema_version(pool).await { version @@ -111,33 +215,55 @@ pub async fn init_table(pool: &Pool) -> anyhow::Result<()> { #[cfg(test)] mod tests { + use crate::domain::sql_migrations; + use super::*; use chrono::prelude::*; - use sqlx::{Column, Row}; + use sea_orm::{ConnectionTrait, Database, DbBackend, FromQueryResult}; + + async fn get_in_memory_db() -> DbConnection { + let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned()); + sql_opt.max_connections(1).sqlx_logging(false); + Database::connect(sql_opt).await.unwrap() + } + + fn raw_statement(sql: &str) -> sea_orm::Statement { + sea_orm::Statement::from_string(DbBackend::Sqlite, sql.to_owned()) + } #[tokio::test] async fn test_init_table() { - let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); + let sql_pool = get_in_memory_db().await; init_table(&sql_pool).await.unwrap(); - sqlx::query(r#"INSERT INTO users + sql_pool.execute(raw_statement( + r#"INSERT INTO users (user_id, email, display_name, first_name, last_name, creation_date, password_hash, uuid) - VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#).execute(&sql_pool).await.unwrap(); - let row = - sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#) - .fetch_one(&sql_pool) - .await - .unwrap(); - assert_eq!(row.column(0).name(), "display_name"); - assert_eq!(row.get::("display_name"), "Bob Bobbersön"); + VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#)).await.unwrap(); + #[derive(FromQueryResult, PartialEq, Eq, Debug)] + struct ShortUserDetails { + display_name: String, + creation_date: chrono::DateTime, + } + let result = ShortUserDetails::find_by_statement(raw_statement( + r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#, + )) + .one(&sql_pool) + .await + .unwrap() + .unwrap(); assert_eq!( - row.get::, _>("creation_date"), - Utc.timestamp(0, 0), + result, + ShortUserDetails { + display_name: "Bob Bobbersön".to_owned(), + creation_date: Utc.timestamp_opt(0, 0).unwrap() + } ); } #[tokio::test] async fn test_already_init_table() { - let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); + crate::infra::logging::init_for_tests(); + let sql_pool = get_in_memory_db().await; init_table(&sql_pool).await.unwrap(); init_table(&sql_pool).await.unwrap(); } @@ -145,89 +271,111 @@ mod tests { #[tokio::test] async fn test_migrate_tables() { // Test that we add the column creation_date to groups and uuid to users and groups. - let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); - sqlx::query(r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#) - .execute(&sql_pool) + let sql_pool = get_in_memory_db().await; + sql_pool + .execute(raw_statement( + r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#, + )) .await .unwrap(); - sqlx::query( - r#"INSERT INTO users (user_id, creation_date) + sql_pool + .execute(raw_statement( + r#"INSERT INTO users (user_id, creation_date) VALUES ("bôb", "1970-01-01 00:00:00")"#, - ) - .execute(&sql_pool) - .await - .unwrap(); - sqlx::query(r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#) - .execute(&sql_pool) + )) .await .unwrap(); - sqlx::query( - r#"INSERT INTO groups (display_name) + sql_pool + .execute(raw_statement( + r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#, + )) + .await + .unwrap(); + sql_pool + .execute(raw_statement( + r#"INSERT INTO groups (display_name) VALUES ("lldap_admin"), ("lldap_readonly")"#, - ) - .execute(&sql_pool) - .await - .unwrap(); + )) + .await + .unwrap(); init_table(&sql_pool).await.unwrap(); - sqlx::query( - r#"INSERT INTO groups (display_name, creation_date, uuid) + sql_pool + .execute(raw_statement( + r#"INSERT INTO groups (display_name, creation_date, uuid) VALUES ("test", "1970-01-01 00:00:00", "abc")"#, - ) - .execute(&sql_pool) - .await - .unwrap(); + )) + .await + .unwrap(); + #[derive(FromQueryResult, PartialEq, Eq, Debug)] + struct JustUuid { + uuid: Uuid, + } assert_eq!( - sqlx::query(r#"SELECT uuid FROM users"#) - .fetch_all(&sql_pool) + JustUuid::find_by_statement(raw_statement(r#"SELECT uuid FROM users"#)) + .all(&sql_pool) .await - .unwrap() - .into_iter() - .map(|row| row.get::("uuid")) - .collect::>(), - vec![crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")] + .unwrap(), + vec![JustUuid { + uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04") + }] ); + #[derive(FromQueryResult, PartialEq, Eq, Debug)] + struct ShortGroupDetails { + group_id: GroupId, + display_name: String, + } assert_eq!( - sqlx::query(r#"SELECT group_id, display_name FROM groups"#) - .fetch_all(&sql_pool) - .await - .unwrap() - .into_iter() - .map(|row| ( - row.get::("group_id"), - row.get::("display_name") - )) - .collect::>(), + ShortGroupDetails::find_by_statement(raw_statement( + r#"SELECT group_id, display_name, creation_date FROM groups"# + )) + .all(&sql_pool) + .await + .unwrap(), vec![ - (GroupId(1), "lldap_admin".to_string()), - (GroupId(2), "lldap_password_manager".to_string()), - (GroupId(3), "lldap_strict_readonly".to_string()), - (GroupId(4), "test".to_string()) + ShortGroupDetails { + group_id: GroupId(1), + display_name: "lldap_admin".to_string() + }, + ShortGroupDetails { + group_id: GroupId(2), + display_name: "lldap_password_manager".to_string() + }, + ShortGroupDetails { + group_id: GroupId(3), + display_name: "test".to_string() + } ] ); assert_eq!( - sqlx::query(r#"SELECT version FROM metadata"#) - .map(|row: DbRow| row.get::("version")) - .fetch_one(&sql_pool) - .await - .unwrap(), - SchemaVersion(1) + sql_migrations::JustSchemaVersion::find_by_statement(raw_statement( + r#"SELECT version FROM metadata"# + )) + .one(&sql_pool) + .await + .unwrap() + .unwrap(), + sql_migrations::JustSchemaVersion { + version: SchemaVersion(1) + } ); } #[tokio::test] async fn test_too_high_version() { - let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap(); - sqlx::query(r#"CREATE TABLE metadata ( version INTEGER);"#) - .execute(&sql_pool) + let sql_pool = get_in_memory_db().await; + sql_pool + .execute(raw_statement( + r#"CREATE TABLE metadata ( version INTEGER);"#, + )) .await .unwrap(); - sqlx::query( - r#"INSERT INTO metadata (version) + sql_pool + .execute(raw_statement( + r#"INSERT INTO metadata (version) VALUES (127)"#, - ) - .execute(&sql_pool) - .await - .unwrap(); + )) + .await + .unwrap(); assert!(init_table(&sql_pool).await.is_err()); } } diff --git a/server/src/domain/sql_user_backend_handler.rs b/server/src/domain/sql_user_backend_handler.rs index ef3298f..8b4e65e 100644 --- a/server/src/domain/sql_user_backend_handler.rs +++ b/server/src/domain/sql_user_backend_handler.rs @@ -1,136 +1,68 @@ use super::{ - error::Result, + error::{DomainError, Result}, handler::{ CreateUserRequest, GroupDetails, GroupId, UpdateUserRequest, User, UserAndGroups, UserBackendHandler, UserId, UserRequestFilter, Uuid, }, + model::{self, GroupColumn, UserColumn}, sql_backend_handler::SqlBackendHandler, - sql_tables::{DbQueryBuilder, Groups, Memberships, Users}, }; use async_trait::async_trait; -use sea_query::{Alias, Cond, Expr, Iden, Order, Query, SimpleExpr}; -use sea_query_binder::{SqlxBinder, SqlxValues}; -use sqlx::{query_as_with, query_with, FromRow, Row}; +use sea_orm::{ + entity::IntoActiveValue, + sea_query::{Cond, Expr, IntoCondition, SimpleExpr}, + ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, QueryOrder, + QuerySelect, QueryTrait, Set, +}; +use sea_query::{Alias, IntoColumnRef}; use std::collections::HashSet; use tracing::{debug, instrument}; -struct RequiresGroup(bool); - -// Returns the condition for the SQL query, and whether it requires joining with the groups table. -fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, Cond) { - use sea_query::IntoCondition; +fn get_user_filter_expr(filter: UserRequestFilter) -> Cond { use UserRequestFilter::*; + let group_table = Alias::new("r1"); fn get_repeated_filter( fs: Vec, condition: Cond, default_value: bool, - ) -> (RequiresGroup, Cond) { + ) -> Cond { if fs.is_empty() { - return ( - RequiresGroup(false), - SimpleExpr::Value(default_value.into()).into_condition(), - ); + SimpleExpr::Value(default_value.into()).into_condition() + } else { + fs.into_iter() + .map(get_user_filter_expr) + .fold(condition, Cond::add) } - let mut requires_group = false; - let filter = fs.into_iter().fold(condition, |c, f| { - let (group, filters) = get_user_filter_expr(f); - requires_group |= group.0; - c.add(filters) - }); - (RequiresGroup(requires_group), filter) } match filter { And(fs) => get_repeated_filter(fs, Cond::all(), true), Or(fs) => get_repeated_filter(fs, Cond::any(), false), - Not(f) => { - let (requires_group, filters) = get_user_filter_expr(*f); - (requires_group, filters.not()) - } - UserId(user_id) => ( - RequiresGroup(false), - Expr::col((Users::Table, Users::UserId)) - .eq(user_id) - .into_condition(), - ), - Equality(s1, s2) => ( - RequiresGroup(false), - if s1 == Users::UserId { + Not(f) => get_user_filter_expr(*f).not(), + UserId(user_id) => ColumnTrait::eq(&UserColumn::UserId, user_id).into_condition(), + Equality(s1, s2) => { + if s1 == UserColumn::UserId { panic!("User id should be wrapped") } else { - Expr::col((Users::Table, s1)).eq(s2).into_condition() - }, - ), - MemberOf(group) => ( - RequiresGroup(true), - Expr::col((Groups::Table, Groups::DisplayName)) - .eq(group) - .into_condition(), - ), - MemberOfId(group_id) => ( - RequiresGroup(true), - Expr::col((Groups::Table, Groups::GroupId)) - .eq(group_id) - .into_condition(), - ), + ColumnTrait::eq(&s1, s2).into_condition() + } + } + MemberOf(group) => Expr::col((group_table, GroupColumn::DisplayName)) + .eq(group) + .into_condition(), + MemberOfId(group_id) => Expr::col((group_table, GroupColumn::GroupId)) + .eq(group_id) + .into_condition(), } } - -fn get_list_users_query( - filters: Option, - get_groups: bool, -) -> (String, SqlxValues) { - let mut query_builder = Query::select() - .column((Users::Table, Users::UserId)) - .column(Users::Email) - .column((Users::Table, Users::DisplayName)) - .column(Users::FirstName) - .column(Users::LastName) - .column(Users::Avatar) - .column((Users::Table, Users::CreationDate)) - .column((Users::Table, Users::Uuid)) - .from(Users::Table) - .order_by((Users::Table, Users::UserId), Order::Asc) - .to_owned(); - let add_join_group_tables = |builder: &mut sea_query::SelectStatement| { - builder - .left_join( - Memberships::Table, - Expr::tbl(Users::Table, Users::UserId) - .equals(Memberships::Table, Memberships::UserId), - ) - .left_join( - Groups::Table, - Expr::tbl(Memberships::Table, Memberships::GroupId) - .equals(Groups::Table, Groups::GroupId), - ); - }; - if get_groups { - add_join_group_tables(&mut query_builder); - query_builder - .column((Groups::Table, Groups::GroupId)) - .expr_as( - Expr::col((Groups::Table, Groups::DisplayName)), - Alias::new("group_display_name"), - ) - .expr_as( - Expr::col((Groups::Table, Groups::CreationDate)), - sea_query::Alias::new("group_creation_date"), - ) - .expr_as( - Expr::col((Groups::Table, Groups::Uuid)), - sea_query::Alias::new("group_uuid"), - ) - .order_by(Alias::new("group_display_name"), Order::Asc); +fn to_value(opt_name: &Option) -> ActiveValue> { + match opt_name { + None => ActiveValue::NotSet, + Some(name) => ActiveValue::Set(if name.is_empty() { + None + } else { + Some(name.to_owned()) + }), } - if let Some(filter) = filters { - let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter); - query_builder.cond_where(condition); - if requires_group && !get_groups { - add_join_group_tables(&mut query_builder); - } - } - - query_builder.build_sqlx(DbQueryBuilder {}) } #[async_trait] @@ -141,95 +73,86 @@ impl UserBackendHandler for SqlBackendHandler { filters: Option, get_groups: bool, ) -> Result> { - debug!(?filters, get_groups); - let (query, values) = get_list_users_query(filters, get_groups); - - debug!(%query); - - // For group_by. - use itertools::Itertools; - let mut users = Vec::new(); - // The rows are returned sorted by user_id. We group them by - // this key which gives us one element (`rows`) per group. - for (_, rows) in &query_with(&query, values) - .fetch_all(&self.sql_pool) - .await? - .into_iter() - .group_by(|row| row.get::(&*Users::UserId.to_string())) - { - let mut rows = rows.peekable(); - users.push(UserAndGroups { - user: User::from_row(rows.peek().unwrap()).unwrap(), - groups: if get_groups { - Some( - rows.filter_map(|row| { - let display_name = row.get::("group_display_name"); - if display_name.is_empty() { - None - } else { - Some(GroupDetails { - group_id: row.get::(&*Groups::GroupId.to_string()), - display_name, - creation_date: row.get::, _>( - "group_creation_date", - ), - uuid: row.get::("group_uuid"), - }) - } - }) - .collect(), - ) - } else { - None - }, - }); + debug!(?filters); + let query = model::User::find() + .filter( + filters + .map(|f| { + UserColumn::UserId + .in_subquery( + model::User::find() + .find_also_linked(model::memberships::UserToGroup) + .select_only() + .column(UserColumn::UserId) + .filter(get_user_filter_expr(f)) + .into_query(), + ) + .into_condition() + }) + .unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()), + ) + .order_by_asc(UserColumn::UserId); + if !get_groups { + Ok(query + .into_model::() + .all(&self.sql_pool) + .await? + .into_iter() + .map(|u| UserAndGroups { + user: u, + groups: None, + }) + .collect()) + } else { + let results = query + //find_with_linked? + .find_also_linked(model::memberships::UserToGroup) + .order_by_asc(SimpleExpr::Column( + (Alias::new("r1"), GroupColumn::GroupId).into_column_ref(), + )) + .all(&self.sql_pool) + .await?; + use itertools::Itertools; + Ok(results + .iter() + .group_by(|(u, _)| u) + .into_iter() + .map(|(user, groups)| { + let groups: Vec<_> = groups + .into_iter() + .flat_map(|(_, g)| g) + .map(|g| GroupDetails::from(g.clone())) + .collect(); + UserAndGroups { + user: user.clone().into(), + groups: Some(groups), + } + }) + .collect()) } - Ok(users) } #[instrument(skip_all, level = "debug", ret)] async fn get_user_details(&self, user_id: &UserId) -> Result { debug!(?user_id); - let (query, values) = Query::select() - .column(Users::UserId) - .column(Users::Email) - .column(Users::DisplayName) - .column(Users::FirstName) - .column(Users::LastName) - .column(Users::Avatar) - .column(Users::CreationDate) - .column(Users::Uuid) - .from(Users::Table) - .cond_where(Expr::col(Users::UserId).eq(user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - - Ok(query_as_with::<_, User, _>(query.as_str(), values) - .fetch_one(&self.sql_pool) - .await?) + model::User::find_by_id(user_id.to_owned()) + .into_model::() + .one(&self.sql_pool) + .await? + .ok_or_else(|| DomainError::EntityNotFound(user_id.to_string())) } #[instrument(skip_all, level = "debug", ret, err)] async fn get_user_groups(&self, user_id: &UserId) -> Result> { debug!(?user_id); - let (query, values) = Query::select() - .column((Groups::Table, Groups::GroupId)) - .column(Groups::DisplayName) - .column(Groups::CreationDate) - .column(Groups::Uuid) - .from(Groups::Table) - .inner_join( - Memberships::Table, - Expr::tbl(Groups::Table, Groups::GroupId) - .equals(Memberships::Table, Memberships::GroupId), - ) - .cond_where(Expr::col(Memberships::UserId).eq(user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - + let user = model::User::find_by_id(user_id.to_owned()) + .one(&self.sql_pool) + .await? + .ok_or_else(|| DomainError::EntityNotFound(user_id.to_string()))?; Ok(HashSet::from_iter( - query_as_with::<_, GroupDetails, _>(&query, values) - .fetch_all(&self.sql_pool) + user.find_linked(model::memberships::UserToGroup) + .into_model::() + .all(&self.sql_pool) .await?, )) } @@ -237,70 +160,41 @@ impl UserBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", err)] async fn create_user(&self, request: CreateUserRequest) -> Result<()> { debug!(user_id = ?request.user_id); - let columns = vec![ - Users::UserId, - Users::Email, - Users::DisplayName, - Users::FirstName, - Users::LastName, - Users::Avatar, - Users::CreationDate, - Users::Uuid, - ]; let now = chrono::Utc::now(); let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now); - let values = vec![ - request.user_id.into(), - request.email.into(), - request.display_name.unwrap_or_default().into(), - request.first_name.unwrap_or_default().into(), - request.last_name.unwrap_or_default().into(), - request.avatar.unwrap_or_default().into(), - now.naive_utc().into(), - uuid.into(), - ]; - let (query, values) = Query::insert() - .into_table(Users::Table) - .columns(columns) - .values_panic(values) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; + let new_user = model::users::ActiveModel { + user_id: Set(request.user_id), + email: Set(request.email), + display_name: to_value(&request.display_name), + first_name: to_value(&request.first_name), + last_name: to_value(&request.last_name), + avatar: request.avatar.into_active_value(), + creation_date: ActiveValue::Set(now), + uuid: ActiveValue::Set(uuid), + ..Default::default() + }; + new_user.insert(&self.sql_pool).await?; Ok(()) } #[instrument(skip_all, level = "debug", err)] async fn update_user(&self, request: UpdateUserRequest) -> Result<()> { debug!(user_id = ?request.user_id); - let mut values = Vec::new(); - if let Some(email) = request.email { - values.push((Users::Email, email.into())); - } - if let Some(display_name) = request.display_name { - values.push((Users::DisplayName, display_name.into())); - } - if let Some(first_name) = request.first_name { - values.push((Users::FirstName, first_name.into())); - } - if let Some(last_name) = request.last_name { - values.push((Users::LastName, last_name.into())); - } - if let Some(avatar) = request.avatar { - values.push((Users::Avatar, avatar.into())); - } - if values.is_empty() { - return Ok(()); - } - let (query, values) = Query::update() - .table(Users::Table) - .values(values) - .cond_where(Expr::col(Users::UserId).eq(request.user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) + let update_user = model::users::ActiveModel { + email: request.email.map(ActiveValue::Set).unwrap_or_default(), + display_name: to_value(&request.display_name), + first_name: to_value(&request.first_name), + last_name: to_value(&request.last_name), + avatar: request.avatar.into_active_value(), + ..Default::default() + }; + model::User::update_many() + .set(update_user) + .filter(sea_orm::ColumnTrait::eq( + &UserColumn::UserId, + request.user_id, + )) + .exec(&self.sql_pool) .await?; Ok(()) } @@ -308,47 +202,41 @@ impl UserBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", err)] async fn delete_user(&self, user_id: &UserId) -> Result<()> { debug!(?user_id); - let (query, values) = Query::delete() - .from_table(Users::Table) - .cond_where(Expr::col(Users::UserId).eq(user_id)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) + let res = model::User::delete_by_id(user_id.clone()) + .exec(&self.sql_pool) .await?; + if res.rows_affected == 0 { + return Err(DomainError::EntityNotFound(format!( + "No such user: '{}'", + user_id + ))); + } Ok(()) } #[instrument(skip_all, level = "debug", err)] async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { debug!(?user_id, ?group_id); - let (query, values) = Query::insert() - .into_table(Memberships::Table) - .columns(vec![Memberships::UserId, Memberships::GroupId]) - .values_panic(vec![user_id.into(), group_id.into()]) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) - .await?; + let new_membership = model::memberships::ActiveModel { + user_id: ActiveValue::Set(user_id.clone()), + group_id: ActiveValue::Set(group_id), + }; + new_membership.insert(&self.sql_pool).await?; Ok(()) } #[instrument(skip_all, level = "debug", err)] async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> { debug!(?user_id, ?group_id); - let (query, values) = Query::delete() - .from_table(Memberships::Table) - .cond_where( - Cond::all() - .add(Expr::col(Memberships::GroupId).eq(group_id)) - .add(Expr::col(Memberships::UserId).eq(user_id)), - ) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(query.as_str(), values) - .execute(&self.sql_pool) + let res = model::Membership::delete_by_id((user_id.clone(), group_id)) + .exec(&self.sql_pool) .await?; + if res.rows_affected == 0 { + return Err(DomainError::EntityNotFound(format!( + "No such membership: '{}' -> {:?}", + user_id, group_id + ))); + } Ok(()) } } @@ -357,7 +245,8 @@ impl UserBackendHandler for SqlBackendHandler { mod tests { use super::*; use crate::domain::{ - handler::JpegPhoto, sql_backend_handler::tests::*, sql_tables::UserColumn, + handler::{JpegPhoto, UserColumn}, + sql_backend_handler::tests::*, }; #[tokio::test] @@ -526,9 +415,13 @@ mod tests { .map(|u| { ( u.user.user_id.to_string(), - u.user.display_name.to_string(), + u.user + .display_name + .as_deref() + .unwrap_or("") + .to_owned(), u.groups - .unwrap() + .unwrap_or_default() .into_iter() .map(|g| g.group_id) .collect::>(), @@ -571,7 +464,7 @@ mod tests { ( u.user.creation_date, u.groups - .unwrap() + .unwrap_or_default() .into_iter() .map(|g| g.creation_date) .collect::>(), @@ -685,7 +578,7 @@ mod tests { display_name: Some("display_name".to_string()), first_name: Some("first_name".to_string()), last_name: Some("last_name".to_string()), - avatar: Some(JpegPhoto::default()), + avatar: Some(JpegPhoto::for_tests()), }) .await .unwrap(); @@ -696,10 +589,10 @@ mod tests { .await .unwrap(); assert_eq!(user.email, "email"); - assert_eq!(user.display_name, "display_name"); - assert_eq!(user.first_name, "first_name"); - assert_eq!(user.last_name, "last_name"); - assert_eq!(user.avatar, JpegPhoto::default()); + assert_eq!(user.display_name.unwrap(), "display_name"); + assert_eq!(user.first_name.unwrap(), "first_name"); + assert_eq!(user.last_name.unwrap(), "last_name"); + assert_eq!(user.avatar, Some(JpegPhoto::for_tests())); } #[tokio::test] @@ -722,9 +615,10 @@ mod tests { .get_user_details(&UserId::new("bob")) .await .unwrap(); - assert_eq!(user.display_name, "display bob"); - assert_eq!(user.first_name, "first_name"); - assert_eq!(user.last_name, ""); + assert_eq!(user.display_name.unwrap(), "display bob"); + assert_eq!(user.first_name.unwrap(), "first_name"); + assert_eq!(user.last_name, None); + assert_eq!(user.avatar, None); } #[tokio::test] diff --git a/server/src/infra/auth_service.rs b/server/src/infra/auth_service.rs index cdbafb5..207d1fb 100644 --- a/server/src/infra/auth_service.rs +++ b/server/src/infra/auth_service.rs @@ -26,7 +26,7 @@ use crate::domain::handler::UserRequestFilter; use crate::{ domain::{ error::DomainError, - handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserId}, + handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserColumn, UserId}, opaque_handler::OpaqueHandler, }, infra::{ @@ -149,10 +149,7 @@ where .list_users( Some(UserRequestFilter::Or(vec![ UserRequestFilter::UserId(UserId::new(user_string)), - UserRequestFilter::Equality( - crate::domain::sql_tables::UserColumn::Email, - user_string.to_owned(), - ), + UserRequestFilter::Equality(UserColumn::Email, user_string.to_owned()), ])), false, ) @@ -174,7 +171,9 @@ where Some(token) => token, }; if let Err(e) = super::mail::send_password_reset_email( - &user.display_name, + user.display_name + .as_deref() + .unwrap_or_else(|| user.user_id.as_str()), &user.email, &token, &data.server_url, diff --git a/server/src/infra/db_cleaner.rs b/server/src/infra/db_cleaner.rs index a6d4bc9..9e87bde 100644 --- a/server/src/infra/db_cleaner.rs +++ b/server/src/infra/db_cleaner.rs @@ -1,18 +1,17 @@ -use crate::{ - domain::sql_tables::{DbQueryBuilder, Pool}, - infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage}, +use crate::domain::{ + model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn}, + sql_tables::DbConnection, }; -use actix::prelude::*; -use chrono::Local; +use actix::prelude::{Actor, AsyncContext, Context}; use cron::Schedule; -use sea_query::{Expr, Query}; +use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; use std::{str::FromStr, time::Duration}; -use tracing::{debug, error, info, instrument}; +use tracing::{error, info, instrument}; // Define actor pub struct Scheduler { schedule: Schedule, - sql_pool: Pool, + sql_pool: DbConnection, } // Provide Actor implementation for our actor @@ -33,7 +32,7 @@ impl Actor for Scheduler { } impl Scheduler { - pub fn new(cron_expression: &str, sql_pool: Pool) -> Self { + pub fn new(cron_expression: &str, sql_pool: DbConnection) -> Self { let schedule = Schedule::from_str(cron_expression).unwrap(); Self { schedule, sql_pool } } @@ -48,33 +47,35 @@ impl Scheduler { } #[instrument(skip_all)] - async fn cleanup_db(sql_pool: Pool) { + async fn cleanup_db(sql_pool: DbConnection) { info!("Cleaning DB"); - let query = Query::delete() - .from_table(JwtRefreshStorage::Table) - .and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc())) - .to_string(DbQueryBuilder {}); - debug!(%query); - if let Err(e) = sqlx::query(&query).execute(&sql_pool).await { + if let Err(e) = model::JwtRefreshStorage::delete_many() + .filter(JwtRefreshStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc())) + .exec(&sql_pool) + .await + { error!("DB error while cleaning up JWT refresh tokens: {}", e); - }; - if let Err(e) = sqlx::query( - &Query::delete() - .from_table(JwtStorage::Table) - .and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc())) - .to_string(DbQueryBuilder {}), - ) - .execute(&sql_pool) - .await + } + if let Err(e) = model::JwtStorage::delete_many() + .filter(JwtStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc())) + .exec(&sql_pool) + .await { error!("DB error while cleaning up JWT storage: {}", e); }; + if let Err(e) = model::PasswordResetTokens::delete_many() + .filter(PasswordResetTokensColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc())) + .exec(&sql_pool) + .await + { + error!("DB error while cleaning up password reset tokens: {}", e); + }; info!("DB cleaned!"); } fn duration_until_next(&self) -> Duration { - let now = Local::now(); - let next = self.schedule.upcoming(Local).next().unwrap(); + let now = chrono::Utc::now(); + let next = self.schedule.upcoming(chrono::Utc).next().unwrap(); let duration_until = next.signed_duration_since(now); duration_until.to_std().unwrap() } diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index d8e221a..4754ee6 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,7 +1,6 @@ use crate::domain::{ - handler::{BackendHandler, GroupDetails, GroupId, UserId}, + handler::{BackendHandler, GroupDetails, GroupId, UserColumn, UserId}, ldap::utils::map_user_field, - sql_tables::UserColumn, }; use juniper::{graphql_object, FieldResult, GraphQLInputObject}; use serde::{Deserialize, Serialize}; @@ -214,19 +213,19 @@ impl User { } fn display_name(&self) -> &str { - &self.user.display_name + self.user.display_name.as_deref().unwrap_or("") } fn first_name(&self) -> &str { - &self.user.first_name + self.user.first_name.as_deref().unwrap_or("") } fn last_name(&self) -> &str { - &self.user.last_name + self.user.last_name.as_deref().unwrap_or("") } - fn avatar(&self) -> String { - (&self.user.avatar).into() + fn avatar(&self) -> Option { + self.user.avatar.as_ref().map(String::from) } fn creation_date(&self) -> chrono::DateTime { @@ -392,7 +391,7 @@ mod tests { Ok(DomainUser { user_id: UserId::new("bob"), email: "bob@bobbers.on".to_string(), - creation_date: chrono::Utc.timestamp_millis(42), + creation_date: chrono::Utc.timestamp_millis_opt(42).unwrap(), uuid: crate::uuid!("b1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"), ..Default::default() }) diff --git a/server/src/infra/jwt_sql_tables.rs b/server/src/infra/jwt_sql_tables.rs index 737b91e..b3443c6 100644 --- a/server/src/infra/jwt_sql_tables.rs +++ b/server/src/infra/jwt_sql_tables.rs @@ -1,6 +1,7 @@ -use sea_query::*; +use sea_orm::ConnectionTrait; +use sea_query::{ColumnDef, ForeignKey, ForeignKeyAction, Iden, Table}; -pub use crate::domain::sql_tables::*; +pub use crate::domain::{sql_migrations::Users, sql_tables::DbConnection}; /// Contains the refresh tokens for a given user. #[derive(Iden)] @@ -31,110 +32,112 @@ pub enum PasswordResetTokens { } /// This needs to be initialized after the domain tables are. -pub async fn init_table(pool: &Pool) -> sqlx::Result<()> { - sqlx::query( - &Table::create() - .table(JwtRefreshStorage::Table) - .if_not_exists() - .col( - ColumnDef::new(JwtRefreshStorage::RefreshTokenHash) - .big_integer() - .not_null() - .primary_key(), - ) - .col( - ColumnDef::new(JwtRefreshStorage::UserId) - .string_len(255) - .not_null(), - ) - .col( - ColumnDef::new(JwtRefreshStorage::ExpiryDate) - .date_time() - .not_null(), - ) - .foreign_key( - ForeignKey::create() - .name("JwtRefreshStorageUserForeignKey") - .from(JwtRefreshStorage::Table, JwtRefreshStorage::UserId) - .to(Users::Table, Users::UserId) - .on_delete(ForeignKeyAction::Cascade) - .on_update(ForeignKeyAction::Cascade), - ) - .to_string(DbQueryBuilder {}), +pub async fn init_table(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> { + let builder = pool.get_database_backend(); + + pool.execute( + builder.build( + Table::create() + .table(JwtRefreshStorage::Table) + .if_not_exists() + .col( + ColumnDef::new(JwtRefreshStorage::RefreshTokenHash) + .big_integer() + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(JwtRefreshStorage::UserId) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(JwtRefreshStorage::ExpiryDate) + .date_time() + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("JwtRefreshStorageUserForeignKey") + .from(JwtRefreshStorage::Table, JwtRefreshStorage::UserId) + .to(Users::Table, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ), + ), ) - .execute(pool) .await?; - sqlx::query( - &Table::create() - .table(JwtStorage::Table) - .if_not_exists() - .col( - ColumnDef::new(JwtStorage::JwtHash) - .big_integer() - .not_null() - .primary_key(), - ) - .col( - ColumnDef::new(JwtStorage::UserId) - .string_len(255) - .not_null(), - ) - .col( - ColumnDef::new(JwtStorage::ExpiryDate) - .date_time() - .not_null(), - ) - .col( - ColumnDef::new(JwtStorage::Blacklisted) - .boolean() - .default(false) - .not_null(), - ) - .foreign_key( - ForeignKey::create() - .name("JwtStorageUserForeignKey") - .from(JwtStorage::Table, JwtStorage::UserId) - .to(Users::Table, Users::UserId) - .on_delete(ForeignKeyAction::Cascade) - .on_update(ForeignKeyAction::Cascade), - ) - .to_string(DbQueryBuilder {}), + pool.execute( + builder.build( + Table::create() + .table(JwtStorage::Table) + .if_not_exists() + .col( + ColumnDef::new(JwtStorage::JwtHash) + .big_integer() + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(JwtStorage::UserId) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(JwtStorage::ExpiryDate) + .date_time() + .not_null(), + ) + .col( + ColumnDef::new(JwtStorage::Blacklisted) + .boolean() + .default(false) + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("JwtStorageUserForeignKey") + .from(JwtStorage::Table, JwtStorage::UserId) + .to(Users::Table, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ), + ), ) - .execute(pool) .await?; - sqlx::query( - &Table::create() - .table(PasswordResetTokens::Table) - .if_not_exists() - .col( - ColumnDef::new(PasswordResetTokens::Token) - .string_len(255) - .not_null() - .primary_key(), - ) - .col( - ColumnDef::new(PasswordResetTokens::UserId) - .string_len(255) - .not_null(), - ) - .col( - ColumnDef::new(PasswordResetTokens::ExpiryDate) - .date_time() - .not_null(), - ) - .foreign_key( - ForeignKey::create() - .name("PasswordResetTokensUserForeignKey") - .from(PasswordResetTokens::Table, PasswordResetTokens::UserId) - .to(Users::Table, Users::UserId) - .on_delete(ForeignKeyAction::Cascade) - .on_update(ForeignKeyAction::Cascade), - ) - .to_string(DbQueryBuilder {}), + pool.execute( + builder.build( + Table::create() + .table(PasswordResetTokens::Table) + .if_not_exists() + .col( + ColumnDef::new(PasswordResetTokens::Token) + .string_len(255) + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(PasswordResetTokens::UserId) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(PasswordResetTokens::ExpiryDate) + .date_time() + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("PasswordResetTokensUserForeignKey") + .from(PasswordResetTokens::Table, PasswordResetTokens::UserId) + .to(Users::Table, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ), + ), ) - .execute(pool) .await?; Ok(()) diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index 382e5f1..94b7e93 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -569,7 +569,7 @@ impl LdapHandler anyhow::Result<()> { .init(); Ok(()) } + +#[cfg(test)] +pub fn init_for_tests() { + if let Err(e) = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::DEBUG) + .with_test_writer() + .try_init() + { + log::warn!("Could not set up test logging: {:#}", e); + } +} diff --git a/server/src/infra/sql_backend_handler.rs b/server/src/infra/sql_backend_handler.rs index 64c43ba..c87d865 100644 --- a/server/src/infra/sql_backend_handler.rs +++ b/server/src/infra/sql_backend_handler.rs @@ -1,10 +1,16 @@ -use super::{jwt_sql_tables::*, tcp_backend_handler::*}; -use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler}; +use super::tcp_backend_handler::TcpBackendHandler; +use crate::domain::{ + error::*, + handler::UserId, + model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn}, + sql_backend_handler::SqlBackendHandler, +}; use async_trait::async_trait; -use futures_util::StreamExt; -use sea_query::{Expr, Iden, Query, SimpleExpr}; -use sea_query_binder::SqlxBinder; -use sqlx::{query_as_with, query_with, Row}; +use sea_orm::{ + sea_query::Cond, ActiveModelTrait, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel, + QueryFilter, QuerySelect, +}; +use sea_query::Expr; use std::collections::HashSet; use tracing::{debug, instrument}; @@ -18,126 +24,102 @@ fn gen_random_string(len: usize) -> String { .collect() } +#[derive(FromQueryResult)] +struct OnlyJwtHash { + jwt_hash: i64, +} + #[async_trait] impl TcpBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug")] async fn get_jwt_blacklist(&self) -> anyhow::Result> { - let (query, values) = Query::select() - .column(JwtStorage::JwtHash) - .from(JwtStorage::Table) - .build_sqlx(DbQueryBuilder {}); - - debug!(%query); - query_with(&query, values) - .map(|row: DbRow| row.get::(&*JwtStorage::JwtHash.to_string()) as u64) - .fetch(&self.sql_pool) - .collect::>>() - .await + Ok(model::JwtStorage::find() + .select_only() + .column(JwtStorageColumn::JwtHash) + .filter(JwtStorageColumn::Blacklisted.eq(true)) + .into_model::() + .all(&self.sql_pool) + .await? .into_iter() - .collect::>>() - .map_err(|e| anyhow::anyhow!(e)) + .map(|m| m.jwt_hash as u64) + .collect::>()) } #[instrument(skip_all, level = "debug")] async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> { debug!(?user); - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; // TODO: Initialize the rng only once. Maybe Arc? let refresh_token = gen_random_string(100); let refresh_token_hash = { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; let mut s = DefaultHasher::new(); refresh_token.hash(&mut s); s.finish() }; let duration = chrono::Duration::days(30); - let (query, values) = Query::insert() - .into_table(JwtRefreshStorage::Table) - .columns(vec![ - JwtRefreshStorage::RefreshTokenHash, - JwtRefreshStorage::UserId, - JwtRefreshStorage::ExpiryDate, - ]) - .values_panic(vec![ - (refresh_token_hash as i64).into(), - user.into(), - (chrono::Utc::now() + duration).naive_utc().into(), - ]) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(&query, values).execute(&self.sql_pool).await?; + let new_token = model::jwt_refresh_storage::Model { + refresh_token_hash: refresh_token_hash as i64, + user_id: user.clone(), + expiry_date: chrono::Utc::now() + duration, + } + .into_active_model(); + new_token.insert(&self.sql_pool).await?; Ok((refresh_token, duration)) } #[instrument(skip_all, level = "debug")] async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result { debug!(?user); - let (query, values) = Query::select() - .expr(SimpleExpr::Value(1.into())) - .from(JwtRefreshStorage::Table) - .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) - .and_where(Expr::col(JwtRefreshStorage::UserId).eq(user)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - Ok(query_with(&query, values) - .fetch_optional(&self.sql_pool) - .await? - .is_some()) + Ok( + model::JwtRefreshStorage::find_by_id(refresh_token_hash as i64) + .filter(JwtRefreshStorageColumn::UserId.eq(user)) + .one(&self.sql_pool) + .await? + .is_some(), + ) } #[instrument(skip_all, level = "debug")] async fn blacklist_jwts(&self, user: &UserId) -> Result> { debug!(?user); - use sqlx::Result; - let (query, values) = Query::select() - .column(JwtStorage::JwtHash) - .from(JwtStorage::Table) - .and_where(Expr::col(JwtStorage::UserId).eq(user)) - .and_where(Expr::col(JwtStorage::Blacklisted).eq(true)) - .build_sqlx(DbQueryBuilder {}); - let result = query_with(&query, values) - .map(|row: DbRow| row.get::(&*JwtStorage::JwtHash.to_string()) as u64) - .fetch(&self.sql_pool) - .collect::>>() - .await + let valid_tokens = model::JwtStorage::find() + .select_only() + .column(JwtStorageColumn::JwtHash) + .filter( + Cond::all() + .add(JwtStorageColumn::UserId.eq(user)) + .add(JwtStorageColumn::Blacklisted.eq(false)), + ) + .into_model::() + .all(&self.sql_pool) + .await? .into_iter() - .collect::>>(); - let (query, values) = Query::update() - .table(JwtStorage::Table) - .values(vec![(JwtStorage::Blacklisted, true.into())]) - .and_where(Expr::col(JwtStorage::UserId).eq(user)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(&query, values).execute(&self.sql_pool).await?; - Ok(result?) + .map(|t| t.jwt_hash as u64) + .collect::>(); + model::JwtStorage::update_many() + .col_expr(JwtStorageColumn::Blacklisted, Expr::value(true)) + .filter(JwtStorageColumn::UserId.eq(user)) + .exec(&self.sql_pool) + .await?; + Ok(valid_tokens) } #[instrument(skip_all, level = "debug")] async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> { - let (query, values) = Query::delete() - .from_table(JwtRefreshStorage::Table) - .and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(&query, values).execute(&self.sql_pool).await?; + model::JwtRefreshStorage::delete_by_id(refresh_token_hash as i64) + .exec(&self.sql_pool) + .await?; Ok(()) } #[instrument(skip_all, level = "debug")] async fn start_password_reset(&self, user: &UserId) -> Result> { debug!(?user); - let (query, values) = Query::select() - .column(Users::UserId) - .from(Users::Table) - .and_where(Expr::col(Users::UserId).eq(user)) - .build_sqlx(DbQueryBuilder {}); - - debug!(%query); - // Check that the user exists. - if query_with(&query, values) - .fetch_one(&self.sql_pool) - .await - .is_err() + if model::User::find_by_id(user.clone()) + .one(&self.sql_pool) + .await? + .is_none() { debug!("User not found"); return Ok(None); @@ -146,50 +128,37 @@ impl TcpBackendHandler for SqlBackendHandler { let token = gen_random_string(100); let duration = chrono::Duration::minutes(10); - let (query, values) = Query::insert() - .into_table(PasswordResetTokens::Table) - .columns(vec![ - PasswordResetTokens::Token, - PasswordResetTokens::UserId, - PasswordResetTokens::ExpiryDate, - ]) - .values_panic(vec![ - token.clone().into(), - user.into(), - (chrono::Utc::now() + duration).naive_utc().into(), - ]) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(&query, values).execute(&self.sql_pool).await?; + let new_token = model::password_reset_tokens::Model { + token: token.clone(), + user_id: user.clone(), + expiry_date: chrono::Utc::now() + duration, + } + .into_active_model(); + new_token.insert(&self.sql_pool).await?; Ok(Some(token)) } #[instrument(skip_all, level = "debug", ret)] async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result { - let (query, values) = Query::select() - .column(PasswordResetTokens::UserId) - .from(PasswordResetTokens::Table) - .and_where(Expr::col(PasswordResetTokens::Token).eq(token)) - .and_where( - Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()), - ) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - - let (user_id,) = query_as_with(&query, values) - .fetch_one(&self.sql_pool) - .await?; - Ok(user_id) + Ok(model::PasswordResetTokens::find_by_id(token.to_owned()) + .filter(PasswordResetTokensColumn::ExpiryDate.gt(chrono::Utc::now().naive_utc())) + .one(&self.sql_pool) + .await? + .ok_or_else(|| DomainError::EntityNotFound("Invalid reset token".to_owned()))? + .user_id) } #[instrument(skip_all, level = "debug")] async fn delete_password_reset_token(&self, token: &str) -> Result<()> { - let (query, values) = Query::delete() - .from_table(PasswordResetTokens::Table) - .and_where(Expr::col(PasswordResetTokens::Token).eq(token)) - .build_sqlx(DbQueryBuilder {}); - debug!(%query); - query_with(&query, values).execute(&self.sql_pool).await?; + let result = model::PasswordResetTokens::delete_by_id(token.to_owned()) + .exec(&self.sql_pool) + .await?; + if result.rows_affected == 0 { + return Err(DomainError::EntityNotFound(format!( + "No such password reset token: '{}'", + token + ))); + } Ok(()) } } diff --git a/server/src/infra/tcp_server.rs b/server/src/infra/tcp_server.rs index dd808f6..27a751e 100644 --- a/server/src/infra/tcp_server.rs +++ b/server/src/infra/tcp_server.rs @@ -52,9 +52,9 @@ pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse { DomainError::DatabaseError(_) | DomainError::InternalError(_) | DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(), - DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => { - HttpResponse::BadRequest() - } + DomainError::Base64DecodeError(_) + | DomainError::BinarySerializationError(_) + | DomainError::EntityNotFound(_) => HttpResponse::BadRequest(), }, TcpError::BadRequest(_) => HttpResponse::BadRequest(), TcpError::InternalServerError(_) => HttpResponse::InternalServerError(), diff --git a/server/src/main.rs b/server/src/main.rs index 9e133f4..005e0ce 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -9,7 +9,6 @@ use crate::{ handler::{CreateUserRequest, GroupBackendHandler, GroupRequestFilter, UserBackendHandler}, sql_backend_handler::SqlBackendHandler, sql_opaque_handler::register_password, - sql_tables::PoolOptions, }, infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler, healthcheck, mail}, }; @@ -17,6 +16,7 @@ use actix::Actor; use actix_server::ServerBuilder; use anyhow::{anyhow, Context, Result}; use futures_util::TryFutureExt; +use sea_orm::Database; use tracing::*; mod domain; @@ -39,29 +39,52 @@ async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration) .and_then(|_| register_password(handler, &config.ldap_user_dn, &config.ldap_user_pass)) .await .context("Error creating admin user")?; - let admin_group_id = handler - .create_group("lldap_admin") - .await - .context("Error creating admin group")?; + let groups = handler + .list_groups(Some(GroupRequestFilter::DisplayName( + "lldap_admin".to_owned(), + ))) + .await?; + assert_eq!(groups.len(), 1); handler - .add_user_to_group(&config.ldap_user_dn, admin_group_id) + .add_user_to_group(&config.ldap_user_dn, groups[0].id) .await .context("Error adding admin user to group") } +async fn ensure_group_exists(handler: &SqlBackendHandler, group_name: &str) -> Result<()> { + if handler + .list_groups(Some(GroupRequestFilter::DisplayName(group_name.to_owned()))) + .await? + .is_empty() + { + warn!("Could not find {} group, trying to create it", group_name); + handler + .create_group(group_name) + .await + .context(format!("while creating {} group", group_name))?; + } + Ok(()) +} + #[instrument(skip_all)] async fn set_up_server(config: Configuration) -> Result { info!("Starting LLDAP version {}", env!("CARGO_PKG_VERSION")); - let sql_pool = PoolOptions::new() - .max_connections(5) - .connect(&config.database_url) - .await - .context("while connecting to the DB")?; + let sql_pool = { + let mut sql_opt = sea_orm::ConnectOptions::new(config.database_url.clone()); + sql_opt + .max_connections(5) + .sqlx_logging(true) + .sqlx_logging_level(log::LevelFilter::Debug); + Database::connect(sql_opt).await? + }; domain::sql_tables::init_table(&sql_pool) .await .context("while creating the tables")?; let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone()); + ensure_group_exists(&backend_handler, "lldap_admin").await?; + ensure_group_exists(&backend_handler, "lldap_password_manager").await?; + ensure_group_exists(&backend_handler, "lldap_strict_readonly").await?; if let Err(e) = backend_handler.get_user_details(&config.ldap_user_dn).await { warn!("Could not get admin user, trying to create it: {:#}", e); create_admin_user(&backend_handler, &config) @@ -69,23 +92,6 @@ async fn set_up_server(config: Configuration) -> Result { .map_err(|e| anyhow!("Error setting up admin login/account: {:#}", e)) .context("while creating the admin user")?; } - if backend_handler - .list_groups(Some(GroupRequestFilter::DisplayName( - "lldap_password_manager".to_string(), - ))) - .await? - .is_empty() - { - warn!("Could not find password_manager group, trying to create it"); - backend_handler - .create_group("lldap_password_manager") - .await - .context("while creating password_manager group")?; - backend_handler - .create_group("lldap_strict_readonly") - .await - .context("while creating readonly group")?; - } let server_builder = infra::ldap_server::build_ldap_server( &config, backend_handler.clone(),