Local Database Sync Script
This document provides an example `sync_local_schema.py` tool. It enables a fast, local-only workflow by synchronizing your local database schema directly with your SQLAlchemy models, without the need to manage Alembic revision history.
This is intentionally best-effort and non-destructive by default:
- creates missing tables
- adds missing columns
- never drops tables or columns, and does not attempt renames (these cannot be inferred safely and must be done manually)
Usage
# Sync schema (create missing tables, add missing columns)
PYTHONPATH=. python migrations/sync_local_schema.py
# Check for differences only (dry run)
PYTHONPATH=. python migrations/sync_local_schema.py --check
# Create a backup before syncing
PYTHONPATH=. python migrations/sync_local_schema.py --backup
# Apply without an interactive prompt
PYTHONPATH=. python migrations/sync_local_schema.py --force
Script Implementation
Save this file as migrations/sync_local_schema.py.
#!/usr/bin/env python3
"""
Local Database Schema Sync Tool

Keeps the local database schema in sync with the latest SQLAlchemy models:
creates missing tables and adds missing columns. It never drops tables or
columns and never performs renames.

Usage:
    python sync_local_schema.py            # Full sync
    python sync_local_schema.py --check    # Check for differences only
    python sync_local_schema.py --backup   # Create backup before sync
"""
import argparse
import logging
import os
import subprocess
import sys
from datetime import datetime
from urllib.parse import urlparse

from dotenv import load_dotenv
from sqlalchemy import Enum as SAEnum
from sqlalchemy import create_engine, inspect, text
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("migrations.sync_local_schema")

# Force the local environment and read the local-only DB URL; this script is
# intended to run against a developer's local database, never a shared one.
os.environ["ENVIRONMENT"] = "local"
DATABASE_URL = os.getenv("DATABASE_URL_LOCAL")
if not DATABASE_URL:
    logger.error("DATABASE_URL_LOCAL is not set. Cannot continue.")
    sys.exit(1)
def import_all_models():
    """
    Register every SQLAlchemy model on Base.metadata and return Base.

    Adjust the imports below to match your project layout; the only hard
    requirement is that `Base.metadata` ends up containing every application
    table. Exits the process with status 1 if the models cannot be imported.
    """
    try:
        from database.database import Base
    except ImportError as e:
        logger.error("Error importing models: %s", e)
        sys.exit(1)

    # Import your entity model modules here so SQLAlchemy registers them:
    # from database.users import models as users_models
    # from database.items import models as items_models
    logger.info("Imported models for schema sync")
    return Base
def get_current_schema_info(engine):
    """
    Inspect the live database and summarize its schema.

    Returns a dict of the shape:
        {'tables': {name: {'columns': {col_name: type_class_name},
                           'column_count': int,
                           'indexes': [index_name, ...]}},
         'total_tables': int}

    Best-effort by design: inspection failures are logged as a warning and
    an empty/partial summary is returned instead of raising.
    """
    inspector = inspect(engine)
    schema_info = {
        'tables': {},
        'total_tables': 0
    }
    try:
        table_names = inspector.get_table_names()
        schema_info['total_tables'] = len(table_names)
        for table_name in table_names:
            columns = inspector.get_columns(table_name)
            indexes = inspector.get_indexes(table_name)
            schema_info['tables'][table_name] = {
                'columns': {col['name']: col['type'].__class__.__name__ for col in columns},
                'column_count': len(columns),
                'indexes': [idx['name'] for idx in indexes]
            }
    except Exception as e:
        # Lazy %-style args to match the module's logging convention
        # (this line originally used an eagerly-formatted f-string).
        logger.warning("Could not inspect current schema: %s", e)
    return schema_info
def get_model_schema_info(Base):
    """
    Summarize the schema declared on the SQLAlchemy models.

    Mirrors the shape produced by get_current_schema_info() so the two
    summaries can be diffed directly:
        {'tables': {...}, 'total_tables': int}
    Unnamed indexes are skipped (they have no name to compare against).
    """
    tables = Base.metadata.tables
    return {
        'total_tables': len(tables),
        'tables': {
            name: {
                'columns': {c.name: c.type.__class__.__name__ for c in tbl.columns},
                'column_count': len(tbl.columns),
                'indexes': [ix.name for ix in tbl.indexes if ix.name],
            }
            for name, tbl in tables.items()
        },
    }
def compare_schemas(current_schema, model_schema):
    """
    Diff a live-schema summary against a model-schema summary.

    "Missing" means present in the models but absent from the database;
    "extra" means the reverse. Returns:
        {'missing_tables': [...],
         'extra_tables': [...],
         'table_differences': {table: {'missing_columns': [...],
                                       'extra_columns': [...]}}}
    """
    db_tables = set(current_schema['tables'])
    wanted_tables = set(model_schema['tables'])

    diff = {
        'missing_tables': list(wanted_tables - db_tables),
        'extra_tables': list(db_tables - wanted_tables),
        'table_differences': {},
    }

    # Column-level comparison for tables that exist on both sides.
    for name in db_tables & wanted_tables:
        have = set(current_schema['tables'][name]['columns'])
        want = set(model_schema['tables'][name]['columns'])
        if have != want:
            diff['table_differences'][name] = {
                'missing_columns': list(want - have),
                'extra_columns': list(have - want),
            }
    return diff
def create_backup(engine):
    """
    Dump the current local database to a timestamped SQL file via pg_dump.

    Returns the backup filename on success, or None on failure. Best-effort:
    failures are logged as warnings, never raised.

    Fixes over the original: the URL is parsed with urllib.parse instead of
    naive string splitting (which broke on query strings/ports), pg_dump is
    invoked through subprocess.run with an argument list (no shell-string
    interpolation), its exit status is actually checked, and a password in
    the URL is forwarded through the PGPASSWORD environment variable.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_file = f"local_db_backup_{timestamp}.sql"
    try:
        parsed = urlparse(DATABASE_URL)
        db_name = (parsed.path or "/").lstrip("/") or "postgres"
        host = parsed.hostname or "localhost"
        user = parsed.username or os.getenv("USER", "postgres")

        cmd = ["pg_dump", "-h", host, "-U", user, "-d", db_name]
        if parsed.port:
            cmd += ["-p", str(parsed.port)]

        env = os.environ.copy()
        if parsed.password:
            env["PGPASSWORD"] = parsed.password

        with open(backup_file, "w") as out:
            result = subprocess.run(cmd, stdout=out, env=env)
        if result.returncode != 0:
            logger.warning("Could not create backup: pg_dump exited with %s", result.returncode)
            return None
        logger.info("Backup created: %s", backup_file)
        return backup_file
    except Exception as e:
        logger.warning("Could not create backup: %s", e)
        return None
def _add_missing_columns(engine, Base, table_name, missing_columns):
    """
    ALTER an existing table to add the given model-declared columns.

    For Enum columns the backing Postgres enum type is created first if it
    does not already exist. Simple Python-side scalar defaults
    (Column(default=...)) are rendered as SQL DEFAULT clauses. A NOT NULL
    column with no default will still fail on a populated table (Postgres
    behavior) — presumably acceptable for a local-only tool.

    Fix over the original: single quotes embedded in enum labels or string
    defaults are now escaped by doubling, so they can neither break nor
    inject into the generated SQL.

    Raises (after logging) on any SQL error.
    """
    try:
        logger.info("Adding missing columns to %s: %s", table_name, ", ".join(missing_columns))
        table = Base.metadata.tables[table_name]
        with engine.connect() as conn:
            for column_name in missing_columns:
                column = table.columns[column_name]
                column_type = column.type.compile(engine.dialect)
                nullable = "NULL" if column.nullable else "NOT NULL"

                # Create the backing Postgres enum type on first use.
                if isinstance(column.type, SAEnum):
                    enum_type_name = column.type.name or column_type
                    exists_sql = text(
                        "SELECT 1 FROM pg_type t JOIN pg_namespace n ON n.oid = t.typnamespace WHERE t.typname = :enum_name"
                    )
                    result = conn.execute(exists_sql, {"enum_name": enum_type_name}).fetchone()
                    if not result:
                        # Escape embedded quotes ('' is the SQL escape).
                        labels = ["'" + lbl.replace("'", "''") + "'" for lbl in column.type.enums]
                        conn.execute(text(f"CREATE TYPE {enum_type_name} AS ENUM ({', '.join(labels)})"))
                        conn.commit()

                # Render simple Python-side defaults as a SQL DEFAULT clause.
                default_clause = ""
                if column.default is not None and hasattr(column.default, 'arg'):
                    arg = column.default.arg
                    # bool must be checked before other types (bool is an int subclass).
                    if isinstance(arg, bool):
                        default_clause = f" DEFAULT {str(arg).lower()}"
                    elif isinstance(arg, str):
                        default_clause = " DEFAULT '" + arg.replace("'", "''") + "'"
                    else:
                        default_clause = f" DEFAULT {arg}"

                # NOTE(review): table/column names come from trusted model
                # metadata and are interpolated unquoted; exotic identifiers
                # would need quoting via the dialect's identifier preparer.
                alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}{default_clause} {nullable}"
                conn.execute(text(alter_sql))
                conn.commit()
        logger.info("Successfully added columns to %s", table_name)
    except Exception as e:
        logger.error("Error adding columns to %s: %s", table_name, e)
        raise
def sync_schema(engine, Base, force=False):
    """
    Bring the local database schema in line with the models.

    Creates missing tables (idempotently, via checkfirst) and adds missing
    columns; never drops anything. Prompts for confirmation unless
    force=True. Returns True when the schema is (now) in sync, False on
    refusal or error.
    """
    try:
        logger.info("Starting schema sync...")
        diffs = compare_schemas(get_current_schema_info(engine),
                                get_model_schema_info(Base))

        has_changes = (
            diffs["missing_tables"]
            or diffs["extra_tables"]
            or diffs["table_differences"]
        )
        if not has_changes:
            logger.info("Local database schema is already up to date")
            return True

        # Report what would change before touching anything.
        if diffs["missing_tables"]:
            logger.info("Tables to create: %s", ", ".join(diffs["missing_tables"]))
        if diffs["extra_tables"]:
            logger.warning("Extra tables exist (not removed automatically): %s", ", ".join(diffs["extra_tables"]))

        if not force:
            answer = input("\nApply these changes? (y/N): ").strip().lower()
            if answer != "y":
                return False

        logger.info("Applying schema changes...")
        Base.metadata.create_all(engine, checkfirst=True)
        for table_name, table_diff in diffs["table_differences"].items():
            if table_diff["missing_columns"]:
                _add_missing_columns(engine, Base, table_name, table_diff["missing_columns"])
        return True
    except Exception as e:
        logger.error("Error during schema sync: %s", e)
        return False
def main():
    """CLI entry point: parse flags, then check or sync the local schema."""
    parser = argparse.ArgumentParser(description="Sync local database schema")
    parser.add_argument("--check", action="store_true", help="Check for differences only")
    parser.add_argument("--backup", action="store_true", help="Create backup before sync")
    parser.add_argument("--force", action="store_true", help="Apply without confirmation")
    args = parser.parse_args()

    Base = import_all_models()
    engine = create_engine(DATABASE_URL)

    if args.backup:
        create_backup(engine)

    if args.check:
        # Dry run: report differences and exit 2 when out of sync, so the
        # result is usable from shell scripts and CI.
        diffs = compare_schemas(get_current_schema_info(engine),
                                get_model_schema_info(Base))
        out_of_sync = bool(
            diffs["missing_tables"] or diffs["extra_tables"] or diffs["table_differences"]
        )
        if out_of_sync:
            logger.warning("Schema differences detected: %s", diffs)
            sys.exit(2)
        logger.info("No schema differences detected")
        return

    sync_schema(engine, Base, force=args.force)


if __name__ == "__main__":
    main()