diff --git a/leggen/services/database_service.py b/leggen/services/database_service.py index d1c68c6..4d1a4e6 100644 --- a/leggen/services/database_service.py +++ b/leggen/services/database_service.py @@ -1,6 +1,8 @@ import json import sqlite3 +from contextlib import contextmanager from datetime import datetime +from functools import wraps from typing import Any, Dict, List, Optional from loguru import logger @@ -14,6 +16,25 @@ from leggen.utils.config import config from leggen.utils.paths import path_manager +def require_sqlite(func): + """Decorator to check if SQLite is enabled before executing method""" + + @wraps(func) + async def wrapper(self, *args, **kwargs): + if not self.sqlite_enabled: + logger.warning(f"SQLite database disabled, skipping {func.__name__}") + # Return appropriate default based on return type hints + return_type = func.__annotations__.get("return") + if return_type is int: + return 0 + elif return_type in (list, List[Dict[str, Any]]): + return [] + return None + return await func(self, *args, **kwargs) + + return wrapper + + class DatabaseService: def __init__(self): self.db_config = config.database_config @@ -24,24 +45,33 @@ class DatabaseService: self.balance_transformer = BalanceTransformer() self.analytics_processor = AnalyticsProcessor() + @contextmanager + def _get_db_connection(self, row_factory: bool = False): + """Context manager for database connections with proper cleanup""" + db_path = path_manager.get_database_path() + conn = sqlite3.connect(str(db_path)) + if row_factory: + conn.row_factory = sqlite3.Row + try: + yield conn + except Exception as e: + conn.rollback() + raise e + finally: + conn.close() + + @require_sqlite async def persist_balance( self, account_id: str, balance_data: Dict[str, Any] ) -> None: """Persist account balance data""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, skipping balance persistence") - return - await self._persist_balance_sqlite(account_id, balance_data) + @require_sqlite async def persist_transactions( self, account_id: str, transactions: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """Persist transactions and return new transactions""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, skipping transaction persistence") - return transactions - return await self._persist_transactions_sqlite(account_id, transactions) def process_transactions( @@ -55,6 +85,7 @@ class DatabaseService: account_id, account_info, transaction_data ) + @require_sqlite async def get_transactions_from_db( self, account_id: Optional[str] = None, @@ -67,10 +98,6 @@ class DatabaseService: search: Optional[str] = None, ) -> List[Dict[str, Any]]: """Get transactions from SQLite database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, cannot read transactions") - return [] - try: transactions = self._get_transactions( account_id=account_id, @@ -88,6 +115,7 @@ class DatabaseService: logger.error(f"Failed to get transactions from database: {e}") return [] + @require_sqlite async def get_transaction_count_from_db( self, account_id: Optional[str] = None, @@ -98,9 +126,6 @@ class DatabaseService: search: Optional[str] = None, ) -> int: """Get total count of transactions from SQLite database""" - if not self.sqlite_enabled: - return 0 - try: filters = { "date_from": date_from, @@ -119,14 +144,11 @@ class DatabaseService: logger.error(f"Failed to get transaction count from database: {e}") return 0 + @require_sqlite async def get_balances_from_db( self, account_id: Optional[str] = None ) -> List[Dict[str, Any]]: """Get balances from SQLite database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, cannot read balances") - return [] - try: balances = self._get_balances(account_id=account_id) logger.debug(f"Retrieved {len(balances)} balances from database") @@ -135,14 +157,11 @@ class DatabaseService: logger.error(f"Failed to get balances from database: {e}") return [] + @require_sqlite async def get_historical_balances_from_db( self, account_id: Optional[str] = None, days: int = 365 ) -> List[Dict[str, Any]]: """Get historical balance progression from SQLite database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, cannot read historical balances") - return [] - try: db_path = path_manager.get_database_path() balances = self.analytics_processor.calculate_historical_balances( @@ -156,13 +175,11 @@ class DatabaseService: logger.error(f"Failed to get historical balances from database: {e}") return [] + @require_sqlite async def get_account_summary_from_db( self, account_id: str ) -> Optional[Dict[str, Any]]: """Get basic account info from SQLite database (avoids GoCardless call)""" - if not self.sqlite_enabled: - return None - try: summary = self._get_account_summary(account_id) if summary: @@ -174,22 +191,16 @@ class DatabaseService: logger.error(f"Failed to get account summary from database: {e}") return None + @require_sqlite async def persist_account_details(self, account_data: Dict[str, Any]) -> None: """Persist account details to database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, skipping account persistence") - return - await self._persist_account_details_sqlite(account_data) + @require_sqlite async def get_accounts_from_db( self, account_ids: Optional[List[str]] = None ) -> List[Dict[str, Any]]: """Get account details from database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, cannot read accounts") - return [] - try: accounts = self._get_accounts(account_ids=account_ids) logger.debug(f"Retrieved {len(accounts)} accounts from database") @@ -198,14 +209,11 @@ class DatabaseService: logger.error(f"Failed to get accounts from database: {e}") return [] + @require_sqlite async def get_account_details_from_db( self, account_id: str ) -> Optional[Dict[str, Any]]: """Get specific account details from database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, cannot read account") - return None - try: account = self._get_account(account_id) if account: @@ -729,66 +737,62 @@ class DatabaseService: ) -> None: """Persist balance to SQLite""" try: - import sqlite3 + with self._get_db_connection() as conn: + cursor = conn.cursor() - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() + # Create the balances table if it doesn't exist + cursor.execute( + """CREATE TABLE IF NOT EXISTS balances ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + account_id TEXT, + bank TEXT, + status TEXT, + iban TEXT, + amount REAL, + currency TEXT, + type TEXT, + timestamp DATETIME + )""" + ) - # Create the balances table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS balances ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - account_id TEXT, - bank TEXT, - status TEXT, - iban TEXT, - amount REAL, - currency TEXT, - type TEXT, - timestamp DATETIME - )""" - ) + # Create indexes for better performance + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_balances_account_id + ON balances(account_id)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_balances_timestamp + ON balances(timestamp)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp + ON balances(account_id, type, timestamp)""" + ) - # Create indexes for better performance - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_account_id - ON balances(account_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_timestamp - ON balances(timestamp)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp - ON balances(account_id, type, timestamp)""" - ) + # Transform and persist balances + balance_rows = self.balance_transformer.transform_to_database_format( + account_id, balance_data + ) - # Transform and persist balances - balance_rows = self.balance_transformer.transform_to_database_format( - account_id, balance_data - ) + for row in balance_rows: + try: + cursor.execute( + """INSERT INTO balances ( + account_id, + bank, + status, + iban, + amount, + currency, + type, + timestamp + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + row, + ) + except sqlite3.IntegrityError: + logger.warning(f"Skipped duplicate balance for {account_id}") - for row in balance_rows: - try: - cursor.execute( - """INSERT INTO balances ( - account_id, - bank, - status, - iban, - amount, - currency, - type, - timestamp - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - row, - ) - except sqlite3.IntegrityError: - logger.warning(f"Skipped duplicate balance for {account_id}") - - conn.commit() - conn.close() + conn.commit() logger.info(f"Persisted balances to SQLite for account {account_id}") except Exception as e: @@ -800,106 +804,101 @@ class DatabaseService: ) -> List[Dict[str, Any]]: """Persist transactions to SQLite""" try: - import json - import sqlite3 + with self._get_db_connection() as conn: + cursor = conn.cursor() - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() + # The table should already exist with the new schema from migration + # If it doesn't exist, create it (for new installations) + cursor.execute( + """CREATE TABLE IF NOT EXISTS transactions ( + accountId TEXT NOT NULL, + transactionId TEXT NOT NULL, + internalTransactionId TEXT, + institutionId TEXT, + iban TEXT, + transactionDate DATETIME, + description TEXT, + transactionValue REAL, + transactionCurrency TEXT, + transactionStatus TEXT, + rawTransaction JSON, + PRIMARY KEY (accountId, transactionId) + )""" + ) - # The table should already exist with the new schema from migration - # If it doesn't exist, create it (for new installations) - cursor.execute( - """CREATE TABLE IF NOT EXISTS transactions ( - accountId TEXT NOT NULL, - transactionId TEXT NOT NULL, - internalTransactionId TEXT, - institutionId TEXT, - iban TEXT, - transactionDate DATETIME, - description TEXT, - transactionValue REAL, - transactionCurrency TEXT, - transactionStatus TEXT, - rawTransaction JSON, - PRIMARY KEY (accountId, transactionId) - )""" - ) + # Create indexes for better performance (if they don't exist) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_transactions_internal_id + ON transactions(internalTransactionId)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_transactions_date + ON transactions(transactionDate)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_transactions_account_date + ON transactions(accountId, transactionDate)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_transactions_amount + ON transactions(transactionValue)""" + ) - # Create indexes for better performance (if they don't exist) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_internal_id - ON transactions(internalTransactionId)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_date - ON transactions(transactionDate)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_account_date - ON transactions(accountId, transactionDate)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_amount - ON transactions(transactionValue)""" - ) + # Prepare an SQL statement for inserting/replacing data + insert_sql = """INSERT OR REPLACE INTO transactions ( + accountId, + transactionId, + internalTransactionId, + institutionId, + iban, + transactionDate, + description, + transactionValue, + transactionCurrency, + transactionStatus, + rawTransaction + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""" - # Prepare an SQL statement for inserting/replacing data - insert_sql = """INSERT OR REPLACE INTO transactions ( - accountId, - transactionId, - internalTransactionId, - institutionId, - iban, - transactionDate, - description, - transactionValue, - transactionCurrency, - transactionStatus, - rawTransaction - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""" + new_transactions = [] - new_transactions = [] + for transaction in transactions: + try: + # Check if transaction already exists before insertion + cursor.execute( + """SELECT COUNT(*) FROM transactions + WHERE accountId = ? AND transactionId = ?""", + (transaction["accountId"], transaction["transactionId"]), + ) + exists = cursor.fetchone()[0] > 0 - for transaction in transactions: - try: - # Check if transaction already exists before insertion - cursor.execute( - """SELECT COUNT(*) FROM transactions - WHERE accountId = ? AND transactionId = ?""", - (transaction["accountId"], transaction["transactionId"]), - ) - exists = cursor.fetchone()[0] > 0 + cursor.execute( + insert_sql, + ( + transaction["accountId"], + transaction["transactionId"], + transaction.get("internalTransactionId"), + transaction["institutionId"], + transaction["iban"], + transaction["transactionDate"], + transaction["description"], + transaction["transactionValue"], + transaction["transactionCurrency"], + transaction["transactionStatus"], + json.dumps(transaction["rawTransaction"]), + ), + ) - cursor.execute( - insert_sql, - ( - transaction["accountId"], - transaction["transactionId"], - transaction.get("internalTransactionId"), - transaction["institutionId"], - transaction["iban"], - transaction["transactionDate"], - transaction["description"], - transaction["transactionValue"], - transaction["transactionCurrency"], - transaction["transactionStatus"], - json.dumps(transaction["rawTransaction"]), - ), - ) + # Only add to new_transactions if it didn't exist before + if not exists: + new_transactions.append(transaction) - # Only add to new_transactions if it didn't exist before - if not exists: - new_transactions.append(transaction) + except sqlite3.IntegrityError as e: + logger.warning( + f"Failed to insert transaction {transaction.get('transactionId')}: {e}" + ) + continue - except sqlite3.IntegrityError as e: - logger.warning( - f"Failed to insert transaction {transaction.get('transactionId')}: {e}" - ) - continue - - conn.commit() - conn.close() + conn.commit() logger.info( f"Persisted {len(new_transactions)} new transactions to SQLite for account {account_id}" @@ -939,50 +938,49 @@ class DatabaseService: db_path = path_manager.get_database_path() if not db_path.exists(): return [] - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row # Enable dict-like access - cursor = conn.cursor() - # Build query with filters - query = "SELECT * FROM transactions WHERE 1=1" - params = [] + with self._get_db_connection(row_factory=True) as conn: + cursor = conn.cursor() - if account_id: - query += " AND accountId = ?" - params.append(account_id) + # Build query with filters + query = "SELECT * FROM transactions WHERE 1=1" + params = [] - if date_from: - query += " AND transactionDate >= ?" - params.append(date_from) + if account_id: + query += " AND accountId = ?" + params.append(account_id) - if date_to: - query += " AND transactionDate <= ?" - params.append(date_to) + if date_from: + query += " AND transactionDate >= ?" + params.append(date_from) - if min_amount is not None: - query += " AND transactionValue >= ?" - params.append(min_amount) + if date_to: + query += " AND transactionDate <= ?" + params.append(date_to) - if max_amount is not None: - query += " AND transactionValue <= ?" - params.append(max_amount) + if min_amount is not None: + query += " AND transactionValue >= ?" + params.append(min_amount) - if search: - query += " AND description LIKE ?" - params.append(f"%{search}%") + if max_amount is not None: + query += " AND transactionValue <= ?" + params.append(max_amount) - # Add ordering and pagination - query += " ORDER BY transactionDate DESC" + if search: + query += " AND description LIKE ?" + params.append(f"%{search}%") - if limit: - query += " LIMIT ?" - params.append(limit) + # Add ordering and pagination + query += " ORDER BY transactionDate DESC" - if offset: - query += " OFFSET ?" - params.append(offset) + if limit: + query += " LIMIT ?" + params.append(limit) + + if offset: + query += " OFFSET ?" + params.append(offset) - try: cursor.execute(query, params) rows = cursor.fetchall() @@ -996,61 +994,48 @@ class DatabaseService: ) transactions.append(transaction) - conn.close() return transactions - except Exception as e: - conn.close() - raise e - def _get_balances(self, account_id=None): """Get latest balances from SQLite database""" db_path = path_manager.get_database_path() if not db_path.exists(): return [] - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - # Get latest balance for each account_id and type combination - query = """ - SELECT * FROM balances b1 - WHERE b1.timestamp = ( - SELECT MAX(b2.timestamp) - FROM balances b2 - WHERE b2.account_id = b1.account_id AND b2.type = b1.type - ) - """ - params = [] + with self._get_db_connection(row_factory=True) as conn: + cursor = conn.cursor() - if account_id: - query += " AND b1.account_id = ?" - params.append(account_id) + # Get latest balance for each account_id and type combination + query = """ + SELECT * FROM balances b1 + WHERE b1.timestamp = ( + SELECT MAX(b2.timestamp) + FROM balances b2 + WHERE b2.account_id = b1.account_id AND b2.type = b1.type + ) + """ + params = [] - query += " ORDER BY b1.account_id, b1.type" + if account_id: + query += " AND b1.account_id = ?" + params.append(account_id) + + query += " ORDER BY b1.account_id, b1.type" - try: cursor.execute(query, params) rows = cursor.fetchall() - balances = [dict(row) for row in rows] - conn.close() - return balances - - except Exception as e: - conn.close() - raise e + return [dict(row) for row in rows] def _get_account_summary(self, account_id): """Get basic account info from transactions table (avoids GoCardless API call)""" db_path = path_manager.get_database_path() if not db_path.exists(): return None - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - try: + with self._get_db_connection(row_factory=True) as conn: + cursor = conn.cursor() + # Get account info from most recent transaction cursor.execute( """ @@ -1064,96 +1049,82 @@ class DatabaseService: ) row = cursor.fetchone() - conn.close() - if row: return dict(row) return None - except Exception as e: - conn.close() - raise e - def _get_transaction_count(self, account_id=None, **filters): """Get total count of transactions matching filters""" db_path = path_manager.get_database_path() if not db_path.exists(): return 0 - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - query = "SELECT COUNT(*) FROM transactions WHERE 1=1" - params = [] + with self._get_db_connection() as conn: + cursor = conn.cursor() - if account_id: - query += " AND accountId = ?" - params.append(account_id) + query = "SELECT COUNT(*) FROM transactions WHERE 1=1" + params = [] - # Add same filters as get_transactions - if filters.get("date_from"): - query += " AND transactionDate >= ?" - params.append(filters["date_from"]) + if account_id: + query += " AND accountId = ?" + params.append(account_id) - if filters.get("date_to"): - query += " AND transactionDate <= ?" - params.append(filters["date_to"]) + # Add same filters as get_transactions + if filters.get("date_from"): + query += " AND transactionDate >= ?" + params.append(filters["date_from"]) - if filters.get("min_amount") is not None: - query += " AND transactionValue >= ?" - params.append(filters["min_amount"]) + if filters.get("date_to"): + query += " AND transactionDate <= ?" + params.append(filters["date_to"]) - if filters.get("max_amount") is not None: - query += " AND transactionValue <= ?" - params.append(filters["max_amount"]) + if filters.get("min_amount") is not None: + query += " AND transactionValue >= ?" + params.append(filters["min_amount"]) - if filters.get("search"): - query += " AND description LIKE ?" - params.append(f"%{filters['search']}%") + if filters.get("max_amount") is not None: + query += " AND transactionValue <= ?" + params.append(filters["max_amount"]) + + if filters.get("search"): + query += " AND description LIKE ?" + params.append(f"%{filters['search']}%") - try: cursor.execute(query, params) - count = cursor.fetchone()[0] - conn.close() - return count - - except Exception as e: - conn.close() - raise e + return cursor.fetchone()[0] def _persist_account(self, account_data: dict): """Persist account details to SQLite database""" - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() + with self._get_db_connection() as conn: + cursor = conn.cursor() - # Create the accounts table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS accounts ( - id TEXT PRIMARY KEY, - institution_id TEXT, - status TEXT, - iban TEXT, - name TEXT, - currency TEXT, - created DATETIME, - last_accessed DATETIME, - last_updated DATETIME, - display_name TEXT, - logo TEXT - )""" - ) + # Create the accounts table if it doesn't exist + cursor.execute( + """CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + institution_id TEXT, + status TEXT, + iban TEXT, + name TEXT, + currency TEXT, + created DATETIME, + last_accessed DATETIME, + last_updated DATETIME, + display_name TEXT, + logo TEXT + )""" + ) - # Create indexes for accounts table - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_institution_id - ON accounts(institution_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_status - ON accounts(status)""" - ) + # Create indexes for accounts table + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_accounts_institution_id + ON accounts(institution_id)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_accounts_status + ON accounts(status)""" + ) - try: # First, check if account exists and preserve display_name cursor.execute( "SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],) @@ -1194,67 +1165,50 @@ class DatabaseService: ), ) conn.commit() - conn.close() return account_data - except Exception as e: - conn.close() - raise e - def _get_accounts(self, account_ids=None): """Get account details from SQLite database""" db_path = path_manager.get_database_path() if not db_path.exists(): return [] - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - query = "SELECT * FROM accounts" - params = [] + with self._get_db_connection(row_factory=True) as conn: + cursor = conn.cursor() - if account_ids: - placeholders = ",".join("?" * len(account_ids)) - query += f" WHERE id IN ({placeholders})" - params.extend(account_ids) + query = "SELECT * FROM accounts" + params = [] - query += " ORDER BY created DESC" + if account_ids: + placeholders = ",".join("?" * len(account_ids)) + query += f" WHERE id IN ({placeholders})" + params.extend(account_ids) + + query += " ORDER BY created DESC" - try: cursor.execute(query, params) rows = cursor.fetchall() - accounts = [dict(row) for row in rows] - conn.close() - return accounts - - except Exception as e: - conn.close() - raise e + return [dict(row) for row in rows] def _get_account(self, account_id: str): """Get specific account details from SQLite database""" db_path = path_manager.get_database_path() if not db_path.exists(): return None - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - try: + with self._get_db_connection(row_factory=True) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,)) row = cursor.fetchone() - conn.close() if row: return dict(row) return None - except Exception as e: - conn.close() - raise e - + @require_sqlite async def get_monthly_transaction_stats_from_db( self, account_id: Optional[str] = None, @@ -1262,10 +1216,6 @@ class DatabaseService: date_to: Optional[str] = None, ) -> List[Dict[str, Any]]: """Get monthly transaction statistics aggregated by the database""" - if not self.sqlite_enabled: - logger.warning("SQLite database disabled, cannot read monthly stats") - return [] - try: db_path = path_manager.get_database_path() monthly_stats = self.analytics_processor.calculate_monthly_stats(