mirror of
				https://github.com/nitnelave/lldap.git
				synced 2023-04-12 14:25:13 +00:00 
			
		
		
		
	
							parent
							
								
									648848c816
								
							
						
					
					
						commit
						bfce7361df
					
				@ -2,10 +2,12 @@ use crate::domain::{
 | 
			
		||||
    sql_tables::{DbConnection, SchemaVersion},
 | 
			
		||||
    types::{GroupId, UserId, Uuid},
 | 
			
		||||
};
 | 
			
		||||
use sea_orm::{ConnectionTrait, FromQueryResult, Statement};
 | 
			
		||||
use sea_orm::{ConnectionTrait, FromQueryResult, Statement, TransactionTrait};
 | 
			
		||||
use sea_query::{ColumnDef, Expr, ForeignKey, ForeignKeyAction, Iden, Query, Table, Value};
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use tracing::{instrument, warn};
 | 
			
		||||
use tracing::{info, instrument, warn};
 | 
			
		||||
 | 
			
		||||
use super::sql_tables::LAST_SCHEMA_VERSION;
 | 
			
		||||
 | 
			
		||||
#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
 | 
			
		||||
pub enum Users {
 | 
			
		||||
@ -331,11 +333,87 @@ pub async fn upgrade_to_v1(pool: &DbConnection) -> std::result::Result<(), sea_o
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn migrate_from_version(
 | 
			
		||||
    _pool: &DbConnection,
 | 
			
		||||
    pool: &DbConnection,
 | 
			
		||||
    version: SchemaVersion,
 | 
			
		||||
) -> anyhow::Result<()> {
 | 
			
		||||
    if version.0 > 1 {
 | 
			
		||||
    if version > LAST_SCHEMA_VERSION {
 | 
			
		||||
        anyhow::bail!("DB version downgrading is not supported");
 | 
			
		||||
    } else if version == LAST_SCHEMA_VERSION {
 | 
			
		||||
        return Ok(());
 | 
			
		||||
    }
 | 
			
		||||
    info!(
 | 
			
		||||
        "Upgrading DB schema from {} to {}",
 | 
			
		||||
        version.0, LAST_SCHEMA_VERSION.0
 | 
			
		||||
    );
 | 
			
		||||
    let builder = pool.get_database_backend();
 | 
			
		||||
    if version < SchemaVersion(2) {
 | 
			
		||||
        // Drop the not_null constraint on display_name. Due to Sqlite, this is more complicated:
 | 
			
		||||
        //  - rename the display_name column to a temporary name
 | 
			
		||||
        //  - create the display_name column without the constraint
 | 
			
		||||
        //  - copy the data from the temp column to the new one
 | 
			
		||||
        //  - update the new one to replace empty strings with null
 | 
			
		||||
        //  - drop the old one
 | 
			
		||||
        pool.transaction::<_, (), sea_orm::DbErr>(|transaction| {
 | 
			
		||||
            Box::pin(async move {
 | 
			
		||||
                #[derive(Iden)]
 | 
			
		||||
                enum TempUsers {
 | 
			
		||||
                    TempDisplayName,
 | 
			
		||||
                }
 | 
			
		||||
                transaction
 | 
			
		||||
                    .execute(
 | 
			
		||||
                        builder.build(
 | 
			
		||||
                            Table::alter()
 | 
			
		||||
                                .table(Users::Table)
 | 
			
		||||
                                .rename_column(Users::DisplayName, TempUsers::TempDisplayName),
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
                    .await?;
 | 
			
		||||
                transaction
 | 
			
		||||
                    .execute(
 | 
			
		||||
                        builder.build(
 | 
			
		||||
                            Table::alter()
 | 
			
		||||
                                .table(Users::Table)
 | 
			
		||||
                                .add_column(ColumnDef::new(Users::DisplayName).string_len(255)),
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
                    .await?;
 | 
			
		||||
                transaction
 | 
			
		||||
                    .execute(builder.build(Query::update().table(Users::Table).value(
 | 
			
		||||
                        Users::DisplayName,
 | 
			
		||||
                        Expr::col((Users::Table, TempUsers::TempDisplayName)),
 | 
			
		||||
                    )))
 | 
			
		||||
                    .await?;
 | 
			
		||||
                transaction
 | 
			
		||||
                    .execute(
 | 
			
		||||
                        builder.build(
 | 
			
		||||
                            Query::update()
 | 
			
		||||
                                .table(Users::Table)
 | 
			
		||||
                                .value(Users::DisplayName, Option::<String>::None)
 | 
			
		||||
                                .cond_where(Expr::col(Users::DisplayName).eq("")),
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
                    .await?;
 | 
			
		||||
                transaction
 | 
			
		||||
                    .execute(
 | 
			
		||||
                        builder.build(
 | 
			
		||||
                            Table::alter()
 | 
			
		||||
                                .table(Users::Table)
 | 
			
		||||
                                .drop_column(TempUsers::TempDisplayName),
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
                    .await?;
 | 
			
		||||
                Ok(())
 | 
			
		||||
            })
 | 
			
		||||
        })
 | 
			
		||||
        .await?;
 | 
			
		||||
    }
 | 
			
		||||
    pool.execute(
 | 
			
		||||
        builder.build(
 | 
			
		||||
            Query::update()
 | 
			
		||||
                .table(Metadata::Table)
 | 
			
		||||
                .value(Metadata::Version, Value::from(LAST_SCHEMA_VERSION)),
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    .await?;
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,7 @@ use sea_orm::Value;
 | 
			
		||||
 | 
			
		||||
pub type DbConnection = sea_orm::DatabaseConnection;
 | 
			
		||||
 | 
			
		||||
#[derive(Copy, PartialEq, Eq, Debug, Clone)]
 | 
			
		||||
#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord)]
 | 
			
		||||
pub struct SchemaVersion(pub i16);
 | 
			
		||||
 | 
			
		||||
impl sea_orm::TryGetable for SchemaVersion {
 | 
			
		||||
@ -22,6 +22,8 @@ impl From<SchemaVersion> for Value {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(2);
 | 
			
		||||
 | 
			
		||||
pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> {
 | 
			
		||||
    let version = {
 | 
			
		||||
        if let Some(version) = get_schema_version(pool).await {
 | 
			
		||||
@ -99,14 +101,21 @@ mod tests {
 | 
			
		||||
        let sql_pool = get_in_memory_db().await;
 | 
			
		||||
        sql_pool
 | 
			
		||||
            .execute(raw_statement(
 | 
			
		||||
                r#"CREATE TABLE users ( user_id TEXT , creation_date TEXT);"#,
 | 
			
		||||
                r#"CREATE TABLE users ( user_id TEXT, display_name TEXT, creation_date TEXT);"#,
 | 
			
		||||
            ))
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        sql_pool
 | 
			
		||||
            .execute(raw_statement(
 | 
			
		||||
                r#"INSERT INTO users (user_id, creation_date)
 | 
			
		||||
                       VALUES ("bôb", "1970-01-01 00:00:00")"#,
 | 
			
		||||
                r#"INSERT INTO users (user_id, display_name, creation_date)
 | 
			
		||||
                       VALUES ("bôb", "", "1970-01-01 00:00:00")"#,
 | 
			
		||||
            ))
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        sql_pool
 | 
			
		||||
            .execute(raw_statement(
 | 
			
		||||
                r#"INSERT INTO users (user_id, display_name, creation_date)
 | 
			
		||||
                       VALUES ("john", "John Doe", "1971-01-01 00:00:00")"#,
 | 
			
		||||
            ))
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
@ -132,17 +141,27 @@ mod tests {
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        #[derive(FromQueryResult, PartialEq, Eq, Debug)]
 | 
			
		||||
        struct JustUuid {
 | 
			
		||||
        struct SimpleUser {
 | 
			
		||||
            display_name: Option<String>,
 | 
			
		||||
            uuid: Uuid,
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            JustUuid::find_by_statement(raw_statement(r#"SELECT uuid FROM users"#))
 | 
			
		||||
                .all(&sql_pool)
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap(),
 | 
			
		||||
            vec![JustUuid {
 | 
			
		||||
                uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")
 | 
			
		||||
            }]
 | 
			
		||||
            SimpleUser::find_by_statement(raw_statement(
 | 
			
		||||
                r#"SELECT display_name, uuid FROM users ORDER BY display_name"#
 | 
			
		||||
            ))
 | 
			
		||||
            .all(&sql_pool)
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap(),
 | 
			
		||||
            vec![
 | 
			
		||||
                SimpleUser {
 | 
			
		||||
                    display_name: None,
 | 
			
		||||
                    uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")
 | 
			
		||||
                },
 | 
			
		||||
                SimpleUser {
 | 
			
		||||
                    display_name: Some("John Doe".to_owned()),
 | 
			
		||||
                    uuid: crate::uuid!("986765a5-3f03-389e-b47b-536b2d6e1bec")
 | 
			
		||||
                }
 | 
			
		||||
            ]
 | 
			
		||||
        );
 | 
			
		||||
        #[derive(FromQueryResult, PartialEq, Eq, Debug)]
 | 
			
		||||
        struct ShortGroupDetails {
 | 
			
		||||
@ -180,7 +199,7 @@ mod tests {
 | 
			
		||||
            .unwrap()
 | 
			
		||||
            .unwrap(),
 | 
			
		||||
            sql_migrations::JustSchemaVersion {
 | 
			
		||||
                version: SchemaVersion(1)
 | 
			
		||||
                version: LAST_SCHEMA_VERSION
 | 
			
		||||
            }
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user