refactor(api): Improve database connection management and reduce boilerplate.

- Add context manager for database connections with proper cleanup
- Add @require_sqlite decorator to eliminate duplicate checks
- Refactor 9 core CRUD methods to use managed connections
- Reduce code by 50 lines while improving resource management
- All 114 tests passing
This commit is contained in:
Elisiário Couto
2025-12-08 22:54:57 +00:00
parent 7007043521
commit 267db8ac63

View File

@@ -1,6 +1,8 @@
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, Optional
from loguru import logger
@@ -14,6 +16,25 @@ from leggen.utils.config import config
from leggen.utils.paths import path_manager
def require_sqlite(func):
"""Decorator to check if SQLite is enabled before executing method"""
@wraps(func)
async def wrapper(self, *args, **kwargs):
if not self.sqlite_enabled:
logger.warning(f"SQLite database disabled, skipping {func.__name__}")
# Return appropriate default based on return type hints
return_type = func.__annotations__.get("return")
if return_type is int:
return 0
elif return_type in (list, List[Dict[str, Any]]):
return []
return None
return await func(self, *args, **kwargs)
return wrapper
class DatabaseService:
def __init__(self):
self.db_config = config.database_config
@@ -24,24 +45,33 @@ class DatabaseService:
self.balance_transformer = BalanceTransformer()
self.analytics_processor = AnalyticsProcessor()
@contextmanager
def _get_db_connection(self, row_factory: bool = False):
"""Context manager for database connections with proper cleanup"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
if row_factory:
conn.row_factory = sqlite3.Row
try:
yield conn
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
@require_sqlite
async def persist_balance(
self, account_id: str, balance_data: Dict[str, Any]
) -> None:
"""Persist account balance data"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping balance persistence")
return
await self._persist_balance_sqlite(account_id, balance_data)
@require_sqlite
async def persist_transactions(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions and return new transactions"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping transaction persistence")
return transactions
return await self._persist_transactions_sqlite(account_id, transactions)
def process_transactions(
@@ -55,6 +85,7 @@ class DatabaseService:
account_id, account_info, transaction_data
)
@require_sqlite
async def get_transactions_from_db(
self,
account_id: Optional[str] = None,
@@ -67,10 +98,6 @@ class DatabaseService:
search: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get transactions from SQLite database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read transactions")
return []
try:
transactions = self._get_transactions(
account_id=account_id,
@@ -88,6 +115,7 @@ class DatabaseService:
logger.error(f"Failed to get transactions from database: {e}")
return []
@require_sqlite
async def get_transaction_count_from_db(
self,
account_id: Optional[str] = None,
@@ -98,9 +126,6 @@ class DatabaseService:
search: Optional[str] = None,
) -> int:
"""Get total count of transactions from SQLite database"""
if not self.sqlite_enabled:
return 0
try:
filters = {
"date_from": date_from,
@@ -119,14 +144,11 @@ class DatabaseService:
logger.error(f"Failed to get transaction count from database: {e}")
return 0
@require_sqlite
async def get_balances_from_db(
self, account_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Get balances from SQLite database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read balances")
return []
try:
balances = self._get_balances(account_id=account_id)
logger.debug(f"Retrieved {len(balances)} balances from database")
@@ -135,14 +157,11 @@ class DatabaseService:
logger.error(f"Failed to get balances from database: {e}")
return []
@require_sqlite
async def get_historical_balances_from_db(
self, account_id: Optional[str] = None, days: int = 365
) -> List[Dict[str, Any]]:
"""Get historical balance progression from SQLite database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read historical balances")
return []
try:
db_path = path_manager.get_database_path()
balances = self.analytics_processor.calculate_historical_balances(
@@ -156,13 +175,11 @@ class DatabaseService:
logger.error(f"Failed to get historical balances from database: {e}")
return []
@require_sqlite
async def get_account_summary_from_db(
self, account_id: str
) -> Optional[Dict[str, Any]]:
"""Get basic account info from SQLite database (avoids GoCardless call)"""
if not self.sqlite_enabled:
return None
try:
summary = self._get_account_summary(account_id)
if summary:
@@ -174,22 +191,16 @@ class DatabaseService:
logger.error(f"Failed to get account summary from database: {e}")
return None
@require_sqlite
async def persist_account_details(self, account_data: Dict[str, Any]) -> None:
"""Persist account details to database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping account persistence")
return
await self._persist_account_details_sqlite(account_data)
@require_sqlite
async def get_accounts_from_db(
self, account_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Get account details from database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read accounts")
return []
try:
accounts = self._get_accounts(account_ids=account_ids)
logger.debug(f"Retrieved {len(accounts)} accounts from database")
@@ -198,14 +209,11 @@ class DatabaseService:
logger.error(f"Failed to get accounts from database: {e}")
return []
@require_sqlite
async def get_account_details_from_db(
self, account_id: str
) -> Optional[Dict[str, Any]]:
"""Get specific account details from database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read account")
return None
try:
account = self._get_account(account_id)
if account:
@@ -729,10 +737,7 @@ class DatabaseService:
) -> None:
"""Persist balance to SQLite"""
try:
import sqlite3
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
with self._get_db_connection() as conn:
cursor = conn.cursor()
# Create the balances table if it doesn't exist
@@ -788,7 +793,6 @@ class DatabaseService:
logger.warning(f"Skipped duplicate balance for {account_id}")
conn.commit()
conn.close()
logger.info(f"Persisted balances to SQLite for account {account_id}")
except Exception as e:
@@ -800,11 +804,7 @@ class DatabaseService:
) -> List[Dict[str, Any]]:
"""Persist transactions to SQLite"""
try:
import json
import sqlite3
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
with self._get_db_connection() as conn:
cursor = conn.cursor()
# The table should already exist with the new schema from migration
@@ -899,7 +899,6 @@ class DatabaseService:
continue
conn.commit()
conn.close()
logger.info(
f"Persisted {len(new_transactions)} new transactions to SQLite for account {account_id}"
@@ -939,8 +938,8 @@ class DatabaseService:
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row # Enable dict-like access
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
# Build query with filters
@@ -982,7 +981,6 @@ class DatabaseService:
query += " OFFSET ?"
params.append(offset)
try:
cursor.execute(query, params)
rows = cursor.fetchall()
@@ -996,20 +994,15 @@ class DatabaseService:
)
transactions.append(transaction)
conn.close()
return transactions
except Exception as e:
conn.close()
raise e
def _get_balances(self, account_id=None):
"""Get latest balances from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
@@ -1029,28 +1022,20 @@ class DatabaseService:
query += " ORDER BY b1.account_id, b1.type"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
balances = [dict(row) for row in rows]
conn.close()
return balances
except Exception as e:
conn.close()
raise e
return [dict(row) for row in rows]
def _get_account_summary(self, account_id):
"""Get basic account info from transactions table (avoids GoCardless API call)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
try:
# Get account info from most recent transaction
cursor.execute(
"""
@@ -1064,22 +1049,17 @@ class DatabaseService:
)
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
def _get_transaction_count(self, account_id=None, **filters):
"""Get total count of transactions matching filters"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return 0
conn = sqlite3.connect(str(db_path))
with self._get_db_connection() as conn:
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
@@ -1110,20 +1090,12 @@ class DatabaseService:
query += " AND description LIKE ?"
params.append(f"%{filters['search']}%")
try:
cursor.execute(query, params)
count = cursor.fetchone()[0]
conn.close()
return count
except Exception as e:
conn.close()
raise e
return cursor.fetchone()[0]
def _persist_account(self, account_data: dict):
"""Persist account details to SQLite database"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
with self._get_db_connection() as conn:
cursor = conn.cursor()
# Create the accounts table if it doesn't exist
@@ -1153,7 +1125,6 @@ class DatabaseService:
ON accounts(status)"""
)
try:
# First, check if account exists and preserve display_name
cursor.execute(
"SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],)
@@ -1194,21 +1165,16 @@ class DatabaseService:
),
)
conn.commit()
conn.close()
return account_data
except Exception as e:
conn.close()
raise e
def _get_accounts(self, account_ids=None):
"""Get account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM accounts"
@@ -1221,40 +1187,28 @@ class DatabaseService:
query += " ORDER BY created DESC"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
accounts = [dict(row) for row in rows]
conn.close()
return accounts
except Exception as e:
conn.close()
raise e
return [dict(row) for row in rows]
def _get_account(self, account_id: str):
"""Get specific account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
try:
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
@require_sqlite
async def get_monthly_transaction_stats_from_db(
self,
account_id: Optional[str] = None,
@@ -1262,10 +1216,6 @@ class DatabaseService:
date_to: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get monthly transaction statistics aggregated by the database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read monthly stats")
return []
try:
db_path = path_manager.get_database_path()
monthly_stats = self.analytics_processor.calculate_monthly_stats(