refactor(api): Improve database connection management and reduce boilerplate.

- Add context manager for database connections with proper cleanup
- Add @require_sqlite decorator to eliminate duplicate checks
- Refactor 9 core CRUD methods to use managed connections
- Reduce code by 50 lines while improving resource management
- All 114 tests passing
This commit is contained in:
Elisiário Couto
2025-12-08 22:54:57 +00:00
parent 7007043521
commit 267db8ac63

View File

@@ -1,6 +1,8 @@
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, Optional
from loguru import logger
@@ -14,6 +16,25 @@ from leggen.utils.config import config
from leggen.utils.paths import path_manager
def require_sqlite(func):
"""Decorator to check if SQLite is enabled before executing method"""
@wraps(func)
async def wrapper(self, *args, **kwargs):
if not self.sqlite_enabled:
logger.warning(f"SQLite database disabled, skipping {func.__name__}")
# Return appropriate default based on return type hints
return_type = func.__annotations__.get("return")
if return_type is int:
return 0
elif return_type in (list, List[Dict[str, Any]]):
return []
return None
return await func(self, *args, **kwargs)
return wrapper
class DatabaseService:
def __init__(self):
self.db_config = config.database_config
@@ -24,24 +45,33 @@ class DatabaseService:
self.balance_transformer = BalanceTransformer()
self.analytics_processor = AnalyticsProcessor()
@contextmanager
def _get_db_connection(self, row_factory: bool = False):
"""Context manager for database connections with proper cleanup"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
if row_factory:
conn.row_factory = sqlite3.Row
try:
yield conn
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
@require_sqlite
async def persist_balance(
self, account_id: str, balance_data: Dict[str, Any]
) -> None:
"""Persist account balance data"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping balance persistence")
return
await self._persist_balance_sqlite(account_id, balance_data)
@require_sqlite
async def persist_transactions(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions and return new transactions"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping transaction persistence")
return transactions
return await self._persist_transactions_sqlite(account_id, transactions)
def process_transactions(
@@ -55,6 +85,7 @@ class DatabaseService:
account_id, account_info, transaction_data
)
@require_sqlite
async def get_transactions_from_db(
self,
account_id: Optional[str] = None,
@@ -67,10 +98,6 @@ class DatabaseService:
search: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get transactions from SQLite database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read transactions")
return []
try:
transactions = self._get_transactions(
account_id=account_id,
@@ -88,6 +115,7 @@ class DatabaseService:
logger.error(f"Failed to get transactions from database: {e}")
return []
@require_sqlite
async def get_transaction_count_from_db(
self,
account_id: Optional[str] = None,
@@ -98,9 +126,6 @@ class DatabaseService:
search: Optional[str] = None,
) -> int:
"""Get total count of transactions from SQLite database"""
if not self.sqlite_enabled:
return 0
try:
filters = {
"date_from": date_from,
@@ -119,14 +144,11 @@ class DatabaseService:
logger.error(f"Failed to get transaction count from database: {e}")
return 0
@require_sqlite
async def get_balances_from_db(
self, account_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Get balances from SQLite database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read balances")
return []
try:
balances = self._get_balances(account_id=account_id)
logger.debug(f"Retrieved {len(balances)} balances from database")
@@ -135,14 +157,11 @@ class DatabaseService:
logger.error(f"Failed to get balances from database: {e}")
return []
@require_sqlite
async def get_historical_balances_from_db(
self, account_id: Optional[str] = None, days: int = 365
) -> List[Dict[str, Any]]:
"""Get historical balance progression from SQLite database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read historical balances")
return []
try:
db_path = path_manager.get_database_path()
balances = self.analytics_processor.calculate_historical_balances(
@@ -156,13 +175,11 @@ class DatabaseService:
logger.error(f"Failed to get historical balances from database: {e}")
return []
@require_sqlite
async def get_account_summary_from_db(
self, account_id: str
) -> Optional[Dict[str, Any]]:
"""Get basic account info from SQLite database (avoids GoCardless call)"""
if not self.sqlite_enabled:
return None
try:
summary = self._get_account_summary(account_id)
if summary:
@@ -174,22 +191,16 @@ class DatabaseService:
logger.error(f"Failed to get account summary from database: {e}")
return None
@require_sqlite
async def persist_account_details(self, account_data: Dict[str, Any]) -> None:
"""Persist account details to database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping account persistence")
return
await self._persist_account_details_sqlite(account_data)
@require_sqlite
async def get_accounts_from_db(
self, account_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Get account details from database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read accounts")
return []
try:
accounts = self._get_accounts(account_ids=account_ids)
logger.debug(f"Retrieved {len(accounts)} accounts from database")
@@ -198,14 +209,11 @@ class DatabaseService:
logger.error(f"Failed to get accounts from database: {e}")
return []
@require_sqlite
async def get_account_details_from_db(
self, account_id: str
) -> Optional[Dict[str, Any]]:
"""Get specific account details from database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read account")
return None
try:
account = self._get_account(account_id)
if account:
@@ -729,66 +737,62 @@ class DatabaseService:
) -> None:
"""Persist balance to SQLite"""
try:
import sqlite3
with self._get_db_connection() as conn:
cursor = conn.cursor()
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Create the balances table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS balances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account_id TEXT,
bank TEXT,
status TEXT,
iban TEXT,
amount REAL,
currency TEXT,
type TEXT,
timestamp DATETIME
)"""
)
# Create the balances table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS balances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account_id TEXT,
bank TEXT,
status TEXT,
iban TEXT,
amount REAL,
currency TEXT,
type TEXT,
timestamp DATETIME
)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
ON balances(account_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
ON balances(timestamp)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
ON balances(account_id, type, timestamp)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
ON balances(account_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
ON balances(timestamp)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
ON balances(account_id, type, timestamp)"""
)
# Transform and persist balances
balance_rows = self.balance_transformer.transform_to_database_format(
account_id, balance_data
)
# Transform and persist balances
balance_rows = self.balance_transformer.transform_to_database_format(
account_id, balance_data
)
for row in balance_rows:
try:
cursor.execute(
"""INSERT INTO balances (
account_id,
bank,
status,
iban,
amount,
currency,
type,
timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
row,
)
except sqlite3.IntegrityError:
logger.warning(f"Skipped duplicate balance for {account_id}")
for row in balance_rows:
try:
cursor.execute(
"""INSERT INTO balances (
account_id,
bank,
status,
iban,
amount,
currency,
type,
timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
row,
)
except sqlite3.IntegrityError:
logger.warning(f"Skipped duplicate balance for {account_id}")
conn.commit()
conn.close()
conn.commit()
logger.info(f"Persisted balances to SQLite for account {account_id}")
except Exception as e:
@@ -800,106 +804,101 @@ class DatabaseService:
) -> List[Dict[str, Any]]:
"""Persist transactions to SQLite"""
try:
import json
import sqlite3
with self._get_db_connection() as conn:
cursor = conn.cursor()
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# The table should already exist with the new schema from migration
# If it doesn't exist, create it (for new installations)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS transactions (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)"""
)
# The table should already exist with the new schema from migration
# If it doesn't exist, create it (for new installations)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS transactions (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)"""
)
# Create indexes for better performance (if they don't exist)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
# Create indexes for better performance (if they don't exist)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
# Prepare an SQL statement for inserting/replacing data
insert_sql = """INSERT OR REPLACE INTO transactions (
accountId,
transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
# Prepare an SQL statement for inserting/replacing data
insert_sql = """INSERT OR REPLACE INTO transactions (
accountId,
transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
new_transactions = []
new_transactions = []
for transaction in transactions:
try:
# Check if transaction already exists before insertion
cursor.execute(
"""SELECT COUNT(*) FROM transactions
WHERE accountId = ? AND transactionId = ?""",
(transaction["accountId"], transaction["transactionId"]),
)
exists = cursor.fetchone()[0] > 0
for transaction in transactions:
try:
# Check if transaction already exists before insertion
cursor.execute(
"""SELECT COUNT(*) FROM transactions
WHERE accountId = ? AND transactionId = ?""",
(transaction["accountId"], transaction["transactionId"]),
)
exists = cursor.fetchone()[0] > 0
cursor.execute(
insert_sql,
(
transaction["accountId"],
transaction["transactionId"],
transaction.get("internalTransactionId"),
transaction["institutionId"],
transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
cursor.execute(
insert_sql,
(
transaction["accountId"],
transaction["transactionId"],
transaction.get("internalTransactionId"),
transaction["institutionId"],
transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
# Only add to new_transactions if it didn't exist before
if not exists:
new_transactions.append(transaction)
# Only add to new_transactions if it didn't exist before
if not exists:
new_transactions.append(transaction)
except sqlite3.IntegrityError as e:
logger.warning(
f"Failed to insert transaction {transaction.get('transactionId')}: {e}"
)
continue
except sqlite3.IntegrityError as e:
logger.warning(
f"Failed to insert transaction {transaction.get('transactionId')}: {e}"
)
continue
conn.commit()
conn.close()
conn.commit()
logger.info(
f"Persisted {len(new_transactions)} new transactions to SQLite for account {account_id}"
@@ -939,50 +938,49 @@ class DatabaseService:
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row # Enable dict-like access
cursor = conn.cursor()
# Build query with filters
query = "SELECT * FROM transactions WHERE 1=1"
params = []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
if account_id:
query += " AND accountId = ?"
params.append(account_id)
# Build query with filters
query = "SELECT * FROM transactions WHERE 1=1"
params = []
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
# Add ordering and pagination
query += " ORDER BY transactionDate DESC"
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
if limit:
query += " LIMIT ?"
params.append(limit)
# Add ordering and pagination
query += " ORDER BY transactionDate DESC"
if offset:
query += " OFFSET ?"
params.append(offset)
if limit:
query += " LIMIT ?"
params.append(limit)
if offset:
query += " OFFSET ?"
params.append(offset)
try:
cursor.execute(query, params)
rows = cursor.fetchall()
@@ -996,61 +994,48 @@ class DatabaseService:
)
transactions.append(transaction)
conn.close()
return transactions
except Exception as e:
conn.close()
raise e
def _get_balances(self, account_id=None):
"""Get latest balances from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
query += " ORDER BY b1.account_id, b1.type"
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
query += " ORDER BY b1.account_id, b1.type"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
balances = [dict(row) for row in rows]
conn.close()
return balances
except Exception as e:
conn.close()
raise e
return [dict(row) for row in rows]
def _get_account_summary(self, account_id):
"""Get basic account info from transactions table (avoids GoCardless API call)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
# Get account info from most recent transaction
cursor.execute(
"""
@@ -1064,96 +1049,82 @@ class DatabaseService:
)
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
def _get_transaction_count(self, account_id=None, **filters):
"""Get total count of transactions matching filters"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return 0
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params = []
with self._get_db_connection() as conn:
cursor = conn.cursor()
if account_id:
query += " AND accountId = ?"
params.append(account_id)
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params = []
# Add same filters as get_transactions
if filters.get("date_from"):
query += " AND transactionDate >= ?"
params.append(filters["date_from"])
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if filters.get("date_to"):
query += " AND transactionDate <= ?"
params.append(filters["date_to"])
# Add same filters as get_transactions
if filters.get("date_from"):
query += " AND transactionDate >= ?"
params.append(filters["date_from"])
if filters.get("min_amount") is not None:
query += " AND transactionValue >= ?"
params.append(filters["min_amount"])
if filters.get("date_to"):
query += " AND transactionDate <= ?"
params.append(filters["date_to"])
if filters.get("max_amount") is not None:
query += " AND transactionValue <= ?"
params.append(filters["max_amount"])
if filters.get("min_amount") is not None:
query += " AND transactionValue >= ?"
params.append(filters["min_amount"])
if filters.get("search"):
query += " AND description LIKE ?"
params.append(f"%{filters['search']}%")
if filters.get("max_amount") is not None:
query += " AND transactionValue <= ?"
params.append(filters["max_amount"])
if filters.get("search"):
query += " AND description LIKE ?"
params.append(f"%{filters['search']}%")
try:
cursor.execute(query, params)
count = cursor.fetchone()[0]
conn.close()
return count
except Exception as e:
conn.close()
raise e
return cursor.fetchone()[0]
def _persist_account(self, account_data: dict):
"""Persist account details to SQLite database"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
with self._get_db_connection() as conn:
cursor = conn.cursor()
# Create the accounts table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME,
display_name TEXT,
logo TEXT
)"""
)
# Create the accounts table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME,
display_name TEXT,
logo TEXT
)"""
)
# Create indexes for accounts table
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
# Create indexes for accounts table
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
try:
# First, check if account exists and preserve display_name
cursor.execute(
"SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],)
@@ -1194,67 +1165,50 @@ class DatabaseService:
),
)
conn.commit()
conn.close()
return account_data
except Exception as e:
conn.close()
raise e
def _get_accounts(self, account_ids=None):
"""Get account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
query = "SELECT * FROM accounts"
params = []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query = "SELECT * FROM accounts"
params = []
query += " ORDER BY created DESC"
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query += " ORDER BY created DESC"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
accounts = [dict(row) for row in rows]
conn.close()
return accounts
except Exception as e:
conn.close()
raise e
return [dict(row) for row in rows]
def _get_account(self, account_id: str):
"""Get specific account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
@require_sqlite
async def get_monthly_transaction_stats_from_db(
self,
account_id: Optional[str] = None,
@@ -1262,10 +1216,6 @@ class DatabaseService:
date_to: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get monthly transaction statistics aggregated by the database"""
if not self.sqlite_enabled:
logger.warning("SQLite database disabled, cannot read monthly stats")
return []
try:
db_path = path_manager.get_database_path()
monthly_stats = self.analytics_processor.calculate_monthly_stats(