diff --git a/server/src/domain/sql_migrations.rs b/server/src/domain/sql_migrations.rs index 7efb152..e4ed092 100644 --- a/server/src/domain/sql_migrations.rs +++ b/server/src/domain/sql_migrations.rs @@ -2,10 +2,12 @@ use crate::domain::{ sql_tables::{DbConnection, SchemaVersion}, types::{GroupId, UserId, Uuid}, }; -use sea_orm::{ConnectionTrait, FromQueryResult, Statement}; +use sea_orm::{ConnectionTrait, FromQueryResult, Statement, TransactionTrait}; use sea_query::{ColumnDef, Expr, ForeignKey, ForeignKeyAction, Iden, Query, Table, Value}; use serde::{Deserialize, Serialize}; -use tracing::{instrument, warn}; +use tracing::{info, instrument, warn}; + +use super::sql_tables::LAST_SCHEMA_VERSION; #[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)] pub enum Users { @@ -331,11 +333,87 @@ pub async fn upgrade_to_v1(pool: &DbConnection) -> std::result::Result<(), sea_o } pub async fn migrate_from_version( - _pool: &DbConnection, + pool: &DbConnection, version: SchemaVersion, ) -> anyhow::Result<()> { - if version.0 > 1 { + if version > LAST_SCHEMA_VERSION { anyhow::bail!("DB version downgrading is not supported"); + } else if version == LAST_SCHEMA_VERSION { + return Ok(()); } + info!( + "Upgrading DB schema from {} to {}", + version.0, LAST_SCHEMA_VERSION.0 + ); + let builder = pool.get_database_backend(); + if version < SchemaVersion(2) { + // Drop the not_null constraint on display_name. Due to Sqlite, this is more complicated: + // - rename the display_name column to a temporary name + // - create the display_name column without the constraint + // - copy the data from the temp column to the new one + // - update the new one to replace empty strings with null + // - drop the old one + pool.transaction::<_, (), sea_orm::DbErr>(|transaction| { + Box::pin(async move { + #[derive(Iden)] + enum TempUsers { + TempDisplayName, + } + transaction + .execute( + builder.build( + Table::alter() + .table(Users::Table) + .rename_column(Users::DisplayName, TempUsers::TempDisplayName), + ), + ) + .await?; + transaction + .execute( + builder.build( + Table::alter() + .table(Users::Table) + .add_column(ColumnDef::new(Users::DisplayName).string_len(255)), + ), + ) + .await?; + transaction + .execute(builder.build(Query::update().table(Users::Table).value( + Users::DisplayName, + Expr::col((Users::Table, TempUsers::TempDisplayName)), + ))) + .await?; + transaction + .execute( + builder.build( + Query::update() + .table(Users::Table) + .value(Users::DisplayName, Option::::None) + .cond_where(Expr::col(Users::DisplayName).eq("")), + ), + ) + .await?; + transaction + .execute( + builder.build( + Table::alter() + .table(Users::Table) + .drop_column(TempUsers::TempDisplayName), + ), + ) + .await?; + Ok(()) + }) + }) + .await?; + } + pool.execute( + builder.build( + Query::update() + .table(Metadata::Table) + .value(Metadata::Version, Value::from(LAST_SCHEMA_VERSION)), + ), + ) + .await?; Ok(()) } diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs index 0f202b0..0a81363 100644 --- a/server/src/domain/sql_tables.rs +++ b/server/src/domain/sql_tables.rs @@ -3,7 +3,7 @@ use sea_orm::Value; pub type DbConnection = sea_orm::DatabaseConnection; -#[derive(Copy, PartialEq, Eq, Debug, Clone)] +#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord)] pub struct SchemaVersion(pub i16); impl sea_orm::TryGetable for SchemaVersion { @@ -22,6 +22,8 @@ impl From for Value { } } +pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(2); + pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> { let version = { if let Some(version) = get_schema_version(pool).await { @@ -99,14 +101,21 @@ mod tests { let sql_pool = get_in_memory_db().await; sql_pool .execute(raw_statement( - r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#, + r#"CREATE TABLE users ( user_id TEXT, display_name TEXT, creation_date TEXT);"#, )) .await .unwrap(); sql_pool .execute(raw_statement( - r#"INSERT INTO users (user_id, creation_date) - VALUES ("bôb", "1970-01-01 00:00:00")"#, + r#"INSERT INTO users (user_id, display_name, creation_date) + VALUES ("bôb", "", "1970-01-01 00:00:00")"#, + )) + .await + .unwrap(); + sql_pool + .execute(raw_statement( + r#"INSERT INTO users (user_id, display_name, creation_date) + VALUES ("john", "John Doe", "1971-01-01 00:00:00")"#, )) .await .unwrap(); @@ -132,17 +141,27 @@ mod tests { .await .unwrap(); #[derive(FromQueryResult, PartialEq, Eq, Debug)] - struct JustUuid { + struct SimpleUser { + display_name: Option, uuid: Uuid, } assert_eq!( - JustUuid::find_by_statement(raw_statement(r#"SELECT uuid FROM users"#)) - .all(&sql_pool) - .await - .unwrap(), - vec![JustUuid { - uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04") - }] + SimpleUser::find_by_statement(raw_statement( + r#"SELECT display_name, uuid FROM users ORDER BY display_name"# + )) + .all(&sql_pool) + .await + .unwrap(), + vec![ + SimpleUser { + display_name: None, + uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04") + }, + SimpleUser { + display_name: Some("John Doe".to_owned()), + uuid: crate::uuid!("986765a5-3f03-389e-b47b-536b2d6e1bec") + } + ] ); #[derive(FromQueryResult, PartialEq, Eq, Debug)] struct ShortGroupDetails { @@ -180,7 +199,7 @@ mod tests { .unwrap() .unwrap(), sql_migrations::JustSchemaVersion { - version: SchemaVersion(1) + version: LAST_SCHEMA_VERSION } ); }