Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
style: enforce E501 line lengths in utility scripts
  • Loading branch information
ParticularlyPythonicBS committed Apr 10, 2026
commit 087ed158e11d4995d828a4e33e99afcfdf15494d
25 changes: 14 additions & 11 deletions temoa/utilities/master_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import re
import sqlite3

import tempfile
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -251,7 +250,8 @@ def _migrate_loan_lifetime(con_old: sqlite3.Connection, con_new: sqlite3.Connect
for v in vints:
new_data.append((row[0], row[1], v, row[2], row[3]))
con_new.executemany(
'INSERT OR REPLACE INTO loan_lifetime_process (region, tech, vintage, lifetime, notes) VALUES (?,?,?,?,?)',
'INSERT OR REPLACE INTO loan_lifetime_process '
'(region, tech, vintage, lifetime, notes) VALUES (?,?,?,?,?)',
new_data,
)
print(f'Migrated {len(new_data)} rows: LoanLifetimeTech -> loan_lifetime_process')
Expand All @@ -270,7 +270,8 @@ def _migrate_time_tables(con_old: sqlite3.Connection, con_new: sqlite3.Connectio
cols = [c[1] for c in get_table_info(con_old, 'TimeSegmentFraction')]
if 'period' in cols:
old_data = con_old.execute(
'SELECT season, SUM(segfrac) / COUNT(DISTINCT period) FROM TimeSegmentFraction GROUP BY season'
'SELECT season, SUM(segfrac) / COUNT(DISTINCT period) '
'FROM TimeSegmentFraction GROUP BY season'
).fetchall()
else:
old_data = con_old.execute(
Expand Down Expand Up @@ -332,7 +333,8 @@ def _migrate_time_tables(con_old: sqlite3.Connection, con_new: sqlite3.Connectio
).fetchone()[0]
if first_period:
old_data = con_old.execute(
'SELECT seas_seq, season, (num_days / 365.25) FROM TimeSeasonSequential WHERE period = ?',
'SELECT seas_seq, season, (num_days / 365.25) '
'FROM TimeSeasonSequential WHERE period = ?',
(first_period,),
).fetchall()
else:
Expand Down Expand Up @@ -362,16 +364,17 @@ def _migrate_capacity_factor(con_old: sqlite3.Connection, con_new: sqlite3.Conne
if cols:
if 'period' in cols:
old_data = con_old.execute(
'SELECT region, season, tod, tech, vintage, AVG(factor) FROM CapacityFactorProcess '
'GROUP BY region, season, tod, tech, vintage'
'SELECT region, season, tod, tech, vintage, AVG(factor) '
'FROM CapacityFactorProcess GROUP BY region, season, tod, tech, vintage'
).fetchall()
else:
old_data = con_old.execute(
'SELECT region, season, tod, tech, vintage, factor FROM CapacityFactorProcess'
).fetchall()
if old_data:
con_new.executemany(
'INSERT OR REPLACE INTO capacity_factor_process (region, season, tod, tech, vintage, factor) VALUES (?,?,?,?,?,?)',
'INSERT OR REPLACE INTO capacity_factor_process '
'(region, season, tod, tech, vintage, factor) VALUES (?,?,?,?,?,?)',
old_data,
)
print(
Expand Down Expand Up @@ -403,9 +406,9 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con

def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> None:
if not source_path.is_file():
raise FileNotFoundError(f"Input database not found: {source_path}")
raise FileNotFoundError(f'Input database not found: {source_path}')
if not schema_path.is_file():
raise FileNotFoundError(f"Schema file not found: {schema_path}")
raise FileNotFoundError(f'Schema file not found: {schema_path}')

fd, temp_path_str = tempfile.mkstemp(
suffix='.sqlite', prefix='temp_migration_', dir=output_path.parent
Expand Down Expand Up @@ -441,9 +444,9 @@ def migrate_database(source_path: Path, schema_path: Path, output_path: Path) ->

def migrate_sql_dump(source_path: Path, schema_path: Path, output_path: Path) -> None:
if not source_path.is_file():
raise FileNotFoundError(f"Input SQL dump not found: {source_path}")
raise FileNotFoundError(f'Input SQL dump not found: {source_path}')
if not schema_path.is_file():
raise FileNotFoundError(f"Schema file not found: {schema_path}")
raise FileNotFoundError(f'Schema file not found: {schema_path}')

con_old_in_memory = sqlite3.connect(':memory:')
with open(source_path, encoding='utf-8') as f:
Expand Down
17 changes: 13 additions & 4 deletions temoa/utilities/run_all_v4_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def run_command(


def run_migrations(
input_dir: Path, migration_script: Path, schema_path: Path, dry_run: bool = False, silent: bool = False
input_dir: Path,
migration_script: Path,
schema_path: Path,
dry_run: bool = False,
silent: bool = False,
) -> None:
if not input_dir.is_dir():
raise FileNotFoundError(f'Error: Input directory not found at {input_dir}')
Expand Down Expand Up @@ -115,13 +119,17 @@ def run_migrations(
else:
# 4. On failure, restore original file
if not silent:
print(f'FAILED: Migration for {target_file.name} failed. Restoring original file.')
print(
f'FAILED: Migration for {target_file.name} failed. Restoring original file.'
)
shutil.copy2(original_backup_file, target_file)
failed_files.append(target_file.name)

except Exception as e:
if not silent:
print(f'CRITICAL ERROR processing {target_file.name}: {e}. Restoring original file.')
print(
f'CRITICAL ERROR processing {target_file.name}: {e}. Restoring original file.'
)
if original_backup_file.exists():
shutil.copy2(original_backup_file, target_file)
failed_files.append(target_file.name)
Expand All @@ -143,7 +151,8 @@ def run_migrations(

def main() -> None:
parser = argparse.ArgumentParser(
description='Run script migration on all .sql/.sqlite/.db files in a directory, overwriting originals.'
description='Run script migration on all .sql/.sqlite/.db files in a directory, '
'overwriting originals.'
)
parser.add_argument(
'--input_dir',
Expand Down