flowCreate.solutions

Local Database Sync Script

This document provides an example sync_local_schema.py tool. It enables a fast, local-only workflow by synchronizing your local database schema with your SQLAlchemy models without managing Alembic revision history.

This is intentionally best-effort and non-destructive by default:

  • creates missing tables
  • adds missing columns
  • never drops tables or columns automatically, and does not attempt renames (a rename cannot be reliably distinguished from a drop-plus-add)

Usage

# Sync schema (create missing tables, add missing columns)
PYTHONPATH=. python migrations/sync_local_schema.py

# Check for differences only (dry run)
PYTHONPATH=. python migrations/sync_local_schema.py --check

# Create a backup before syncing
PYTHONPATH=. python migrations/sync_local_schema.py --backup

# Apply without an interactive prompt
PYTHONPATH=. python migrations/sync_local_schema.py --force

Script Implementation

Save this file as migrations/sync_local_schema.py.

#!/usr/bin/env python3
"""
Local Database Schema Sync Tool
This script keeps your local database schema in sync with the latest SQLAlchemy models.

Usage:
    python sync_local_schema.py          # Full sync
    python sync_local_schema.py --check  # Check for differences only
    python sync_local_schema.py --backup # Create backup before sync
"""

import argparse
import logging
import os
import subprocess
import sys
from datetime import datetime
from urllib.parse import urlsplit

from dotenv import load_dotenv
from sqlalchemy import create_engine, text, inspect
from sqlalchemy import Enum as SAEnum

# Load environment variables (e.g. DATABASE_URL_LOCAL) from a .env file, if present.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("migrations.sync_local_schema")

# Mark this process as "local" before any project modules are imported —
# presumably project configuration reads ENVIRONMENT; confirm against
# database.database. Then read the local-only database URL.
os.environ["ENVIRONMENT"] = "local"
DATABASE_URL = os.getenv("DATABASE_URL_LOCAL")
if not DATABASE_URL:
    # Fail fast: nothing below can work without a target database URL.
    logger.error("DATABASE_URL_LOCAL is not set. Cannot continue.")
    sys.exit(1)

def import_all_models():
    """Load every model module so Base.metadata describes the full schema.

    Adjust the imports below for your project layout; the only requirement
    is that Base.metadata ends up containing all application tables.

    Returns:
        The declarative Base. Exits the process (code 1) on import failure.
    """
    try:
        from database.database import Base
    except ImportError as e:
        logger.error("Error importing models: %s", e)
        sys.exit(1)

    # Register entity models with SQLAlchemy by importing their modules, e.g.:
    # from database.users import models as users_models
    # from database.items import models as items_models

    logger.info("Imported models for schema sync")
    return Base

def get_current_schema_info(engine):
    """Inspect the connected database and summarize its current schema.

    Args:
        engine: SQLAlchemy engine bound to the target database.

    Returns:
        A dict of the shape:
          {'total_tables': int,
           'tables': {name: {'columns': {col_name: type_class_name},
                             'column_count': int,
                             'indexes': [index_name, ...]}}}
        Inspection failures are logged and yield a (possibly partial)
        summary rather than raising, so callers can proceed best-effort.
    """
    inspector = inspect(engine)

    schema_info = {
        'tables': {},
        'total_tables': 0
    }

    try:
        table_names = inspector.get_table_names()
        schema_info['total_tables'] = len(table_names)

        for table_name in table_names:
            columns = inspector.get_columns(table_name)
            indexes = inspector.get_indexes(table_name)

            schema_info['tables'][table_name] = {
                # Type is recorded by class name only — enough for the
                # name-level diff done in compare_schemas().
                'columns': {col['name']: col['type'].__class__.__name__ for col in columns},
                'column_count': len(columns),
                'indexes': [idx['name'] for idx in indexes]
            }

    except Exception as e:
        # Lazy %-args match the module's other log calls (the original
        # f-string here was the lone inconsistency).
        logger.warning("Could not inspect current schema: %s", e)

    return schema_info

def get_model_schema_info(Base):
    """Summarize the schema declared on the SQLAlchemy models.

    Args:
        Base: declarative base whose ``metadata.tables`` holds the model tables.

    Returns:
        The same shape as get_current_schema_info(): a dict with
        'total_tables' and, per table, column-name→type-class-name mapping,
        a column count, and the names of any named indexes.
    """
    tables = Base.metadata.tables

    per_table = {
        name: {
            'columns': {c.name: type(c.type).__name__ for c in tbl.columns},
            'column_count': len(tbl.columns),
            # Unnamed (auto-generated) indexes are omitted.
            'indexes': [ix.name for ix in tbl.indexes if ix.name],
        }
        for name, tbl in tables.items()
    }

    return {'tables': per_table, 'total_tables': len(tables)}

def compare_schemas(current_schema, model_schema):
    """Diff the live database schema against the model schema.

    Both arguments are the summary dicts produced by
    get_current_schema_info() / get_model_schema_info().

    Returns:
        {'missing_tables':     tables in the models but not the DB,
         'extra_tables':       tables in the DB but not the models,
         'table_differences':  {table: {'missing_columns': [...],
                                        'extra_columns': [...]}}
        (only tables with at least one column difference appear)}
    """
    db_tables = set(current_schema['tables'])
    wanted_tables = set(model_schema['tables'])

    table_diffs = {}
    for name in db_tables & wanted_tables:
        have = set(current_schema['tables'][name]['columns'])
        want = set(model_schema['tables'][name]['columns'])
        to_add = want - have
        leftover = have - want
        if to_add or leftover:
            table_diffs[name] = {
                'missing_columns': list(to_add),
                'extra_columns': list(leftover),
            }

    return {
        'missing_tables': list(wanted_tables - db_tables),
        'extra_tables': list(db_tables - wanted_tables),
        'table_differences': table_diffs,
    }

def create_backup(engine):
    """Dump the local database to a timestamped SQL file via pg_dump.

    Best-effort: returns the backup filename on success, or None when the
    URL cannot be parsed, pg_dump is missing, or pg_dump exits non-zero.
    Never raises.

    NOTE(review): any password in DATABASE_URL is intentionally not passed
    to pg_dump; local setups are expected to use ~/.pgpass or trust auth.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_file = f"local_db_backup_{timestamp}.sql"

    try:
        # urlsplit handles plain postgresql:// as well as dialect URLs such
        # as postgresql+psycopg2://, and copes with an embedded password —
        # the hand-rolled split() it replaces did neither.
        parts = urlsplit(DATABASE_URL)
        db_name = parts.path.lstrip("/")
        user = parts.username or os.getenv("USER", "postgres")
        host = parts.hostname or "localhost"

        cmd = ["pg_dump", "-h", host, "-U", user, "-d", db_name, "-f", backup_file]
        if parts.port:
            cmd += ["-p", str(parts.port)]

        # List argv with shell=False avoids shell quoting/injection issues;
        # checking the exit code means we no longer claim success blindly.
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.warning("pg_dump failed (exit %s): %s", result.returncode, result.stderr.strip())
            return None

        logger.info("Backup created: %s", backup_file)
        return backup_file

    except Exception as e:
        logger.warning("Could not create backup: %s", e)
        return None

def _add_missing_columns(engine, Base, table_name, missing_columns):
    """Issue ALTER TABLE ... ADD COLUMN for each model column missing from the DB.

    Best-effort DDL generation:
      - PostgreSQL enum types referenced by a column are created first if
        absent (existence checked via pg_type/pg_namespace).
      - Scalar Python-side defaults (bool/str/numeric) are rendered as a
        DEFAULT clause; callable defaults (e.g. ``datetime.utcnow``) have no
        SQL literal form and are skipped with a warning instead of being
        interpolated as ``<function ...>`` (which produced invalid SQL).

    Raises:
        Re-raises any database error after logging it.
    """
    try:
        logger.info("Adding missing columns to %s: %s", table_name, ", ".join(missing_columns))

        table = Base.metadata.tables[table_name]

        with engine.connect() as conn:
            for column_name in missing_columns:
                column = table.columns[column_name]
                column_type = column.type.compile(engine.dialect)
                nullable = "NULL" if column.nullable else "NOT NULL"

                # Ensure the backing enum type exists before referencing it.
                if isinstance(column.type, SAEnum):
                    enum_type_name = column.type.name or column_type
                    exists_sql = text(
                        "SELECT 1 FROM pg_type t JOIN pg_namespace n ON n.oid = t.typnamespace WHERE t.typname = :enum_name"
                    )
                    result = conn.execute(exists_sql, {"enum_name": enum_type_name}).fetchone()
                    if not result:
                        labels = [f"'{lbl}'" for lbl in column.type.enums]
                        conn.execute(text(f"CREATE TYPE {enum_type_name} AS ENUM ({', '.join(labels)})"))
                        conn.commit()

                # Render Python-side scalar defaults into the DDL.
                default_clause = ""
                if column.default is not None and hasattr(column.default, 'arg'):
                    arg = column.default.arg
                    if callable(arg):
                        logger.warning(
                            "Skipping callable default for %s.%s; backfill manually or rely on application code",
                            table_name, column_name,
                        )
                    elif isinstance(arg, bool):
                        default_clause = f" DEFAULT {str(arg).lower()}"
                    elif isinstance(arg, str):
                        # Escape embedded single quotes for a valid SQL literal.
                        default_clause = f" DEFAULT '{arg.replace(chr(39), chr(39) * 2)}'"
                    else:
                        default_clause = f" DEFAULT {arg}"

                alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}{default_clause} {nullable}"
                conn.execute(text(alter_sql))
                conn.commit()

        logger.info("Successfully added columns to %s", table_name)

    except Exception as e:
        logger.error("Error adding columns to %s: %s", table_name, e)
        raise

def sync_schema(engine, Base, force=False):
    """Bring the local database schema in line with the SQLAlchemy models.

    Creates missing tables (Base.metadata.create_all) and adds missing
    columns to existing tables. Never drops tables or columns.

    Args:
        engine: SQLAlchemy engine bound to the local database.
        Base: declarative base whose metadata describes the desired schema.
        force: when True, apply changes without the interactive prompt.

    Returns:
        True when the schema is (now) up to date; False when the user
        declined the prompt or an error occurred.
    """
    try:
        logger.info("Starting schema sync...")

        current_schema = get_current_schema_info(engine)
        model_schema = get_model_schema_info(Base)

        differences = compare_schemas(current_schema, model_schema)

        if not any([differences["missing_tables"], differences["extra_tables"], differences["table_differences"]]):
            logger.info("Local database schema is already up to date")
            return True

        # Show everything that will (and won't) change BEFORE prompting —
        # previously column-level diffs were applied without being shown.
        if differences["missing_tables"]:
            logger.info("Tables to create: %s", ", ".join(differences["missing_tables"]))
        if differences["extra_tables"]:
            logger.warning("Extra tables exist (not removed automatically): %s", ", ".join(differences["extra_tables"]))
        for table_name, diffs in differences["table_differences"].items():
            if diffs["missing_columns"]:
                logger.info("Columns to add to %s: %s", table_name, ", ".join(diffs["missing_columns"]))
            if diffs["extra_columns"]:
                logger.warning("Extra columns on %s (not removed automatically): %s", table_name, ", ".join(diffs["extra_columns"]))

        if not force:
            response = input("\nApply these changes? (y/N): ").strip().lower()
            if response != "y":
                return False

        logger.info("Applying schema changes...")
        # checkfirst=True makes create_all skip tables that already exist.
        Base.metadata.create_all(engine, checkfirst=True)

        for table_name, diffs in differences["table_differences"].items():
            if diffs["missing_columns"]:
                _add_missing_columns(engine, Base, table_name, diffs["missing_columns"])

        return True

    except Exception as e:
        logger.error("Error during schema sync: %s", e)
        return False

def main():
    """CLI entry point: parse flags, then check or sync the local schema.

    Exit codes:
        0 — sync succeeded, or --check found no differences
        1 — sync failed or was declined at the prompt
        2 — --check found differences
    """
    parser = argparse.ArgumentParser(description="Sync local database schema")
    parser.add_argument("--check", action="store_true", help="Check for differences only")
    parser.add_argument("--backup", action="store_true", help="Create backup before sync")
    parser.add_argument("--force", action="store_true", help="Apply without confirmation")

    args = parser.parse_args()

    Base = import_all_models()
    engine = create_engine(DATABASE_URL)

    if args.backup:
        create_backup(engine)

    if args.check:
        current_schema = get_current_schema_info(engine)
        model_schema = get_model_schema_info(Base)
        diffs = compare_schemas(current_schema, model_schema)
        has_diffs = any([diffs["missing_tables"], diffs["extra_tables"], diffs["table_differences"]])
        if has_diffs:
            logger.warning("Schema differences detected: %s", diffs)
            sys.exit(2)
        logger.info("No schema differences detected")
        return

    # Propagate failure/decline to the shell (previously the result was
    # discarded and the script always exited 0).
    if not sync_schema(engine, Base, force=args.force):
        sys.exit(1)

# Script entry point: run as `PYTHONPATH=. python migrations/sync_local_schema.py`.
if __name__ == "__main__":
    main()