diff --git a/leggen/services/database_service.py b/leggen/services/database_service.py index 2d35a1a..3774d2c 100644 --- a/leggen/services/database_service.py +++ b/leggen/services/database_service.py @@ -3,9 +3,11 @@ from datetime import datetime, timedelta from typing import Any, Dict, List, Optional from loguru import logger +from sqlalchemy import and_, desc, func +from sqlmodel import col, select -from leggen.services.database import init_database -from leggen.services.database_helpers import get_db_connection +from leggen.models.database import Account, Balance, SyncOperation, Transaction +from leggen.services.database import get_session, init_database from leggen.services.transaction_processor import TransactionProcessor from leggen.utils.config import config from leggen.utils.paths import path_manager @@ -279,208 +281,102 @@ class DatabaseService: async def _persist_balance_sqlite( self, account_id: str, balance_data: Dict[str, Any] ) -> None: - """Persist balance to SQLite""" + """Persist balance to database using SQLModel""" try: - import sqlite3 + with get_session() as session: + # Convert GoCardless balance format to our format and persist + for balance in balance_data.get("balances", []): + balance_amount = balance["balanceAmount"] - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # Create the balances table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS balances ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - account_id TEXT, - bank TEXT, - status TEXT, - iban TEXT, - amount REAL, - currency TEXT, - type TEXT, - timestamp DATETIME - )""" - ) - - # Create indexes for better performance - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_account_id - ON balances(account_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_timestamp - ON balances(timestamp)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp - ON balances(account_id, type, timestamp)""" - ) - - # Convert GoCardless balance format to our format and persist - for balance in balance_data.get("balances", []): - balance_amount = balance["balanceAmount"] - - try: - cursor.execute( - """INSERT INTO balances ( - account_id, - bank, - status, - iban, - amount, - currency, - type, - timestamp - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - ( - account_id, - balance_data.get("institution_id", "unknown"), - balance_data.get("account_status"), - balance_data.get("iban", "N/A"), - float(balance_amount["amount"]), - balance_amount["currency"], - balance["balanceType"], - datetime.now().isoformat(), - ), + db_balance = Balance( + account_id=account_id, + bank=balance_data.get("institution_id", "unknown"), + status=balance_data.get("account_status", ""), + iban=balance_data.get("iban", "N/A"), + amount=float(balance_amount["amount"]), + currency=balance_amount["currency"], + type=balance["balanceType"], + timestamp=datetime.now(), ) - except sqlite3.IntegrityError: - logger.warning(f"Skipped duplicate balance for {account_id}") + session.add(db_balance) - conn.commit() - conn.close() + session.commit() - logger.info(f"Persisted balances to SQLite for account {account_id}") + logger.info(f"Persisted balances for account {account_id}") except Exception as e: - logger.error(f"Failed to persist balances to SQLite: {e}") + logger.error(f"Failed to persist balances: {e}") raise async def _persist_transactions_sqlite( self, account_id: str, transactions: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: - """Persist transactions to SQLite""" + """Persist transactions to database using SQLModel""" try: - import json - import sqlite3 - - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # The table should already exist with the new schema from migration - # If it doesn't exist, create it (for new installations) - cursor.execute( - """CREATE TABLE IF NOT EXISTS transactions ( - accountId TEXT NOT NULL, - transactionId TEXT NOT NULL, - internalTransactionId TEXT, - institutionId TEXT, - iban TEXT, - transactionDate DATETIME, - description TEXT, - transactionValue REAL, - transactionCurrency TEXT, - transactionStatus TEXT, - rawTransaction JSON, - PRIMARY KEY (accountId, transactionId) - )""" - ) - - # Create indexes for better performance (if they don't exist) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_internal_id - ON transactions(internalTransactionId)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_date - ON transactions(transactionDate)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_account_date - ON transactions(accountId, transactionDate)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_amount - ON transactions(transactionValue)""" - ) - - # Prepare an SQL statement for inserting/replacing data - insert_sql = """INSERT OR REPLACE INTO transactions ( - accountId, - transactionId, - internalTransactionId, - institutionId, - iban, - transactionDate, - description, - transactionValue, - transactionCurrency, - transactionStatus, - rawTransaction - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""" - new_transactions = [] - for transaction in transactions: - try: - # Check if transaction already exists before insertion - cursor.execute( - """SELECT COUNT(*) FROM transactions - WHERE accountId = ? AND transactionId = ?""", - (transaction["accountId"], transaction["transactionId"]), + with get_session() as session: + for transaction in transactions: + # Check if transaction already exists + statement = select(Transaction).where( + Transaction.accountId == transaction["accountId"], + Transaction.transactionId == transaction["transactionId"], ) - exists = cursor.fetchone()[0] > 0 + existing = session.exec(statement).first() - cursor.execute( - insert_sql, - ( - transaction["accountId"], - transaction["transactionId"], - transaction.get("internalTransactionId"), - transaction["institutionId"], - transaction["iban"], - transaction["transactionDate"], - transaction["description"], - transaction["transactionValue"], - transaction["transactionCurrency"], - transaction["transactionStatus"], - json.dumps(transaction["rawTransaction"]), - ), - ) - - # Only add to new_transactions if it didn't exist before - if not exists: + if existing: + # Update existing transaction + existing.internalTransactionId = transaction.get( + "internalTransactionId" + ) + existing.institutionId = transaction["institutionId"] + existing.iban = transaction["iban"] + existing.transactionDate = transaction["transactionDate"] + existing.description = transaction["description"] + existing.transactionValue = transaction["transactionValue"] + existing.transactionCurrency = transaction[ + "transactionCurrency" + ] + existing.transactionStatus = transaction["transactionStatus"] + existing.rawTransaction = transaction["rawTransaction"] + else: + # Add new transaction + db_transaction = Transaction( + accountId=transaction["accountId"], + transactionId=transaction["transactionId"], + internalTransactionId=transaction.get( + "internalTransactionId" + ), + institutionId=transaction["institutionId"], + iban=transaction["iban"], + transactionDate=transaction["transactionDate"], + description=transaction["description"], + transactionValue=transaction["transactionValue"], + transactionCurrency=transaction["transactionCurrency"], + transactionStatus=transaction["transactionStatus"], + rawTransaction=transaction["rawTransaction"], + ) + session.add(db_transaction) new_transactions.append(transaction) - except sqlite3.IntegrityError as e: - logger.warning( - f"Failed to insert transaction {transaction.get('transactionId')}: {e}" - ) - continue - - conn.commit() - conn.close() + session.commit() logger.info( - f"Persisted {len(new_transactions)} new transactions to SQLite for account {account_id}" + f"Persisted {len(new_transactions)} new transactions for account {account_id}" ) return new_transactions except Exception as e: - logger.error(f"Failed to persist transactions to SQLite: {e}") + logger.error(f"Failed to persist transactions: {e}") raise async def _persist_account_details_sqlite( self, account_data: Dict[str, Any] ) -> None: - """Persist account details to SQLite""" + """Persist account details using SQLModel""" try: - # Use the sqlite_db module function self._persist_account(account_data) - - logger.info( - f"Persisted account details to SQLite for account {account_data['id']}" - ) + logger.info(f"Persisted account details for account {account_data['id']}") except Exception as e: - logger.error(f"Failed to persist account details to SQLite: {e}") + logger.error(f"Failed to persist account details: {e}") raise def _get_transactions( @@ -494,270 +390,251 @@ class DatabaseService: max_amount=None, search=None, ): - """Get transactions from SQLite database with optional filtering""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return [] + """Get transactions from database with optional filtering using SQLModel""" + try: + with get_session() as session: + statement = select(Transaction) - # Build query with filters - query = "SELECT * FROM transactions WHERE 1=1" - params = [] + # Apply filters + if account_id: + statement = statement.where(Transaction.accountId == account_id) - if account_id: - query += " AND accountId = ?" - params.append(account_id) - - if date_from: - query += " AND transactionDate >= ?" - params.append(date_from) - - if date_to: - query += " AND transactionDate <= ?" - params.append(date_to) - - if min_amount is not None: - query += " AND transactionValue >= ?" - params.append(min_amount) - - if max_amount is not None: - query += " AND transactionValue <= ?" - params.append(max_amount) - - if search: - query += " AND description LIKE ?" - params.append(f"%{search}%") - - # Add ordering and pagination - query += " ORDER BY transactionDate DESC" - - if limit: - query += " LIMIT ?" - params.append(limit) - - if offset: - query += " OFFSET ?" - params.append(offset) - - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute(query, tuple(params)) - rows = cursor.fetchall() - - # Convert to list of dicts and parse JSON fields - transactions = [] - for row in rows: - transaction = dict(row) - if transaction["rawTransaction"]: - transaction["rawTransaction"] = json.loads( - transaction["rawTransaction"] + if date_from: + statement = statement.where( + Transaction.transactionDate >= date_from ) - transactions.append(transaction) - return transactions + if date_to: + statement = statement.where(Transaction.transactionDate <= date_to) + + if min_amount is not None: + statement = statement.where( + Transaction.transactionValue >= min_amount + ) + + if max_amount is not None: + statement = statement.where( + Transaction.transactionValue <= max_amount + ) + + if search: + statement = statement.where( + col(Transaction.description).contains(search) + ) + + # Add ordering + statement = statement.order_by(desc(col(Transaction.transactionDate))) + + # Add pagination + if limit: + statement = statement.limit(limit) + + if offset: + statement = statement.offset(offset) + + results = session.exec(statement).all() + + # Convert to list of dicts + transactions = [] + for row in results: + transaction = row.model_dump() + transactions.append(transaction) + + return transactions + except Exception as e: + logger.error(f"Failed to get transactions: {e}") + return [] def _get_balances(self, account_id=None): - """Get latest balances from SQLite database""" - db_path = path_manager.get_database_path() - if not db_path.exists(): + """Get latest balances from database using SQLModel""" + try: + with get_session() as session: + # Subquery to get max timestamp for each account_id and type + subquery = ( + select( + Balance.account_id, + Balance.type, + func.max(Balance.timestamp).label("max_timestamp"), + ) + .group_by(Balance.account_id, Balance.type) + .subquery() + ) + + # Main query to get latest balances + statement = select(Balance).join( + subquery, + and_( + col(Balance.account_id) == subquery.c.account_id, + col(Balance.type) == subquery.c.type, + col(Balance.timestamp) == subquery.c.max_timestamp, + ), + ) + + if account_id: + statement = statement.where(Balance.account_id == account_id) + + statement = statement.order_by(Balance.account_id, Balance.type) + + results = session.exec(statement).all() + return [balance.model_dump() for balance in results] + except Exception as e: + logger.error(f"Failed to get balances: {e}") return [] - # Get latest balance for each account_id and type combination - query = """ - SELECT * FROM balances b1 - WHERE b1.timestamp = ( - SELECT MAX(b2.timestamp) - FROM balances b2 - WHERE b2.account_id = b1.account_id AND b2.type = b1.type - ) - """ - params = [] - - if account_id: - query += " AND b1.account_id = ?" - params.append(account_id) - - query += " ORDER BY b1.account_id, b1.type" - - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute(query, tuple(params)) - rows = cursor.fetchall() - return [dict(row) for row in rows] - def _get_account_summary(self, account_id): - """Get basic account info from transactions table (avoids GoCardless API call)""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return None + """Get basic account info from transactions table""" + try: + with get_session() as session: + statement = ( + select( + Transaction.accountId, + Transaction.institutionId, + Transaction.iban, + ) + .where(Transaction.accountId == account_id) + .order_by(desc(col(Transaction.transactionDate))) + .limit(1) + ) - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute( - """ - SELECT DISTINCT accountId, institutionId, iban - FROM transactions - WHERE accountId = ? - ORDER BY transactionDate DESC - LIMIT 1 - """, - (account_id,), - ) - row = cursor.fetchone() - return dict(row) if row else None + result = session.exec(statement).first() + if result: + return { + "accountId": result[0], + "institutionId": result[1], + "iban": result[2], + } + return None + except Exception as e: + logger.error(f"Failed to get account summary: {e}") + return None def _get_transaction_count(self, account_id=None, **filters): - """Get total count of transactions matching filters""" - db_path = path_manager.get_database_path() - if not db_path.exists(): + """Get total count of transactions matching filters using SQLModel""" + try: + with get_session() as session: + statement = select(func.count()).select_from(Transaction) + + # Apply filters + if account_id: + statement = statement.where(Transaction.accountId == account_id) + + if filters.get("date_from"): + statement = statement.where( + Transaction.transactionDate >= filters["date_from"] + ) + + if filters.get("date_to"): + statement = statement.where( + Transaction.transactionDate <= filters["date_to"] + ) + + if filters.get("min_amount") is not None: + statement = statement.where( + Transaction.transactionValue >= filters["min_amount"] + ) + + if filters.get("max_amount") is not None: + statement = statement.where( + Transaction.transactionValue <= filters["max_amount"] + ) + + if filters.get("search"): + statement = statement.where( + col(Transaction.description).contains(filters["search"]) + ) + + count = session.exec(statement).one() + return count + except Exception as e: + logger.error(f"Failed to get transaction count: {e}") return 0 - query = "SELECT COUNT(*) FROM transactions WHERE 1=1" - params = [] - - if account_id: - query += " AND accountId = ?" - params.append(account_id) - - # Add same filters as get_transactions - if filters.get("date_from"): - query += " AND transactionDate >= ?" - params.append(filters["date_from"]) - - if filters.get("date_to"): - query += " AND transactionDate <= ?" - params.append(filters["date_to"]) - - if filters.get("min_amount") is not None: - query += " AND transactionValue >= ?" - params.append(filters["min_amount"]) - - if filters.get("max_amount") is not None: - query += " AND transactionValue <= ?" - params.append(filters["max_amount"]) - - if filters.get("search"): - query += " AND description LIKE ?" - params.append(f"%{filters['search']}%") - - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute(query, tuple(params)) - count = cursor.fetchone()[0] - return count - def _persist_account(self, account_data: dict): - """Persist account details to SQLite database""" - db_path = path_manager.get_database_path() + """Persist account details using SQLModel""" + try: + with get_session() as session: + # Check if account exists + statement = select(Account).where(Account.id == account_data["id"]) + existing = session.exec(statement).first() - with get_db_connection(db_path) as conn: - cursor = conn.cursor() + if existing: + # Preserve display_name if not provided in new data + display_name = account_data.get( + "display_name", existing.display_name + ) - # Create the accounts table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS accounts ( - id TEXT PRIMARY KEY, - institution_id TEXT, - status TEXT, - iban TEXT, - name TEXT, - currency TEXT, - created DATETIME, - last_accessed DATETIME, - last_updated DATETIME, - display_name TEXT, - logo TEXT - )""" - ) + # Update existing account + existing.institution_id = account_data["institution_id"] + existing.status = account_data["status"] + existing.iban = account_data.get("iban") + existing.name = account_data.get("name") + existing.currency = account_data.get("currency") + existing.created = account_data["created"] + existing.last_accessed = account_data.get("last_accessed") + existing.last_updated = account_data.get( + "last_updated", account_data["created"] + ) + existing.display_name = display_name + existing.logo = account_data.get("logo") + else: + # Create new account + db_account = Account( + id=account_data["id"], + institution_id=account_data["institution_id"], + status=account_data["status"], + iban=account_data.get("iban"), + name=account_data.get("name"), + currency=account_data.get("currency"), + created=account_data["created"], + last_accessed=account_data.get("last_accessed"), + last_updated=account_data.get( + "last_updated", account_data["created"] + ), + display_name=account_data.get("display_name"), + logo=account_data.get("logo"), + ) + session.add(db_account) - # Create indexes for accounts table - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_institution_id - ON accounts(institution_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_status - ON accounts(status)""" - ) - - # First, check if account exists and preserve display_name - cursor.execute( - "SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],) - ) - existing_row = cursor.fetchone() - existing_display_name = existing_row[0] if existing_row else None - - # Use existing display_name if not provided in account_data - display_name = account_data.get("display_name", existing_display_name) - - # Insert or replace account data - cursor.execute( - """INSERT OR REPLACE INTO accounts ( - id, - institution_id, - status, - iban, - name, - currency, - created, - last_accessed, - last_updated, - display_name, - logo - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - account_data["id"], - account_data["institution_id"], - account_data["status"], - account_data.get("iban"), - account_data.get("name"), - account_data.get("currency"), - account_data["created"], - account_data.get("last_accessed"), - account_data.get("last_updated", account_data["created"]), - display_name, - account_data.get("logo"), - ), - ) - conn.commit() - - return account_data + session.commit() + return account_data + except Exception as e: + logger.error(f"Failed to persist account: {e}") + raise def _get_accounts(self, account_ids=None): - """Get account details from SQLite database""" - db_path = path_manager.get_database_path() - if not db_path.exists(): + """Get account details using SQLModel""" + try: + with get_session() as session: + statement = select(Account) + + if account_ids: + statement = statement.where(col(Account.id).in_(account_ids)) + + statement = statement.order_by(desc(col(Account.created))) + + results = session.exec(statement).all() + return [account.model_dump() for account in results] + except Exception as e: + logger.error(f"Failed to get accounts: {e}") return [] - query = "SELECT * FROM accounts" - params = [] - - if account_ids: - placeholders = ",".join("?" * len(account_ids)) - query += f" WHERE id IN ({placeholders})" - params.extend(account_ids) - - query += " ORDER BY created DESC" - - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute(query, tuple(params)) - rows = cursor.fetchall() - return [dict(row) for row in rows] - def _get_account(self, account_id: str): - """Get specific account details from SQLite database""" - db_path = path_manager.get_database_path() - if not db_path.exists(): + """Get specific account details using SQLModel""" + try: + with get_session() as session: + statement = select(Account).where(Account.id == account_id) + result = session.exec(statement).first() + return result.model_dump() if result else None + except Exception as e: + logger.error(f"Failed to get account: {e}") return None - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,)) - row = cursor.fetchone() - return dict(row) if row else None - def _get_historical_balances(self, account_id=None, days=365): """Get historical balance progression based on transaction history""" + # This method uses complex CTEs and window functions that are better kept as raw SQL + # for performance and readability + from leggen.services.database_helpers import get_db_connection + db_path = path_manager.get_database_path() if not db_path.exists(): return [] @@ -782,7 +659,7 @@ class DatabaseService: WHERE b1.timestamp = ( SELECT MAX(b2.timestamp) FROM balances b2 - WHERE b2.account_id = b1.account_id AND b2.type = b1.type + WHERE b2.account_id = b1.account_id AND b1.type = b2.type ) {account_filter} AND b1.type = 'closingBooked' -- Focus on closingBooked for charts @@ -825,11 +702,15 @@ class DatabaseService: # Format the query with conditional filter formatted_query = query.format(account_filter=account_filter) - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute(formatted_query, tuple(params)) - rows = cursor.fetchall() - return [dict(row) for row in rows] + try: + with get_db_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute(formatted_query, tuple(params)) + rows = cursor.fetchall() + return [dict(row) for row in rows] + except Exception as e: + logger.error(f"Failed to get historical balances: {e}") + return [] async def get_monthly_transaction_stats_from_db( self, @@ -862,7 +743,9 @@ class DatabaseService: date_from: Optional[str] = None, date_to: Optional[str] = None, ) -> List[Dict[str, Any]]: - """Get monthly transaction statistics from SQLite database""" + """Get monthly transaction statistics - using raw SQL for date aggregation""" + from leggen.services.database_helpers import get_db_connection + db_path = path_manager.get_database_path() if not db_path.exists(): return [] @@ -897,75 +780,65 @@ class DatabaseService: ORDER BY month ASC """ - with get_db_connection(db_path) as conn: - cursor = conn.cursor() - cursor.execute(query, tuple(params)) - rows = cursor.fetchall() + try: + with get_db_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, tuple(params)) + rows = cursor.fetchall() - # Convert to desired format with proper month display - monthly_stats = [] - for row in rows: - # Convert YYYY-MM to display format like "Mar 2024" - year, month_num = row["month"].split("-") - month_date = datetime.strptime(f"{year}-{month_num}-01", "%Y-%m-%d") - display_month = month_date.strftime("%b %Y") + # Convert to desired format with proper month display + monthly_stats = [] + for row in rows: + # Convert YYYY-MM to display format like "Mar 2024" + year, month_num = row["month"].split("-") + month_date = datetime.strptime(f"{year}-{month_num}-01", "%Y-%m-%d") + display_month = month_date.strftime("%b %Y") - monthly_stats.append( - { - "month": display_month, - "income": round(row["income"], 2), - "expenses": round(row["expenses"], 2), - "net": round(row["net"], 2), - } - ) + monthly_stats.append( + { + "month": display_month, + "income": round(row["income"], 2), + "expenses": round(row["expenses"], 2), + "net": round(row["net"], 2), + } + ) - return monthly_stats + return monthly_stats + except Exception as e: + logger.error(f"Failed to get monthly transaction stats: {e}") + return [] async def persist_sync_operation(self, sync_operation: Dict[str, Any]) -> int: - """Persist sync operation to database and return the ID""" + """Persist sync operation to database using SQLModel""" if not self.sqlite_enabled: logger.warning("SQLite database disabled, cannot persist sync operation") return 0 try: - import json - import sqlite3 + with get_session() as session: + db_sync = SyncOperation( + started_at=sync_operation.get("started_at"), + completed_at=sync_operation.get("completed_at"), + success=sync_operation.get("success"), + accounts_processed=sync_operation.get("accounts_processed", 0), + transactions_added=sync_operation.get("transactions_added", 0), + transactions_updated=sync_operation.get("transactions_updated", 0), + balances_updated=sync_operation.get("balances_updated", 0), + duration_seconds=sync_operation.get("duration_seconds"), + errors=json.dumps(sync_operation.get("errors", [])), + logs=json.dumps(sync_operation.get("logs", [])), + trigger_type=sync_operation.get("trigger_type", "manual"), + ) + session.add(db_sync) + session.commit() + session.refresh(db_sync) - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() + operation_id = db_sync.id + if operation_id is None: + raise ValueError("Failed to get operation ID after insert") - # Insert sync operation - cursor.execute( - """INSERT INTO sync_operations ( - started_at, completed_at, success, accounts_processed, - transactions_added, transactions_updated, balances_updated, - duration_seconds, errors, logs, trigger_type - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - sync_operation.get("started_at"), - sync_operation.get("completed_at"), - sync_operation.get("success"), - sync_operation.get("accounts_processed", 0), - sync_operation.get("transactions_added", 0), - sync_operation.get("transactions_updated", 0), - sync_operation.get("balances_updated", 0), - sync_operation.get("duration_seconds"), - json.dumps(sync_operation.get("errors", [])), - json.dumps(sync_operation.get("logs", [])), - sync_operation.get("trigger_type", "manual"), - ), - ) - - operation_id = cursor.lastrowid - if operation_id is None: - raise ValueError("Failed to get operation ID after insert") - - conn.commit() - conn.close() - - logger.debug(f"Persisted sync operation with ID: {operation_id}") - return operation_id + logger.debug(f"Persisted sync operation with ID: {operation_id}") + return operation_id except Exception as e: logger.error(f"Failed to persist sync operation: {e}") @@ -974,50 +847,39 @@ class DatabaseService: async def get_sync_operations( self, limit: int = 50, offset: int = 0 ) -> List[Dict[str, Any]]: - """Get sync operations from database""" + """Get sync operations from database using SQLModel""" if not self.sqlite_enabled: logger.warning("SQLite database disabled, cannot get sync operations") return [] try: - import json - import sqlite3 + with get_session() as session: + statement = ( + select(SyncOperation) + .order_by(desc(col(SyncOperation.started_at))) + .limit(limit) + .offset(offset) + ) - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() + results = session.exec(statement).all() - # Get sync operations ordered by started_at descending - cursor.execute( - """SELECT id, started_at, completed_at, success, accounts_processed, - transactions_added, transactions_updated, balances_updated, - duration_seconds, errors, logs, trigger_type - FROM sync_operations - ORDER BY started_at DESC - LIMIT ? OFFSET ?""", - (limit, offset), - ) + operations = [] + for sync_op in results: + operation = sync_op.model_dump() + # Parse JSON fields + if operation["errors"]: + operation["errors"] = json.loads(operation["errors"]) + else: + operation["errors"] = [] - operations = [] - for row in cursor.fetchall(): - operation = { - "id": row[0], - "started_at": row[1], - "completed_at": row[2], - "success": bool(row[3]) if row[3] is not None else None, - "accounts_processed": row[4], - "transactions_added": row[5], - "transactions_updated": row[6], - "balances_updated": row[7], - "duration_seconds": row[8], - "errors": json.loads(row[9]) if row[9] else [], - "logs": json.loads(row[10]) if row[10] else [], - "trigger_type": row[11], - } - operations.append(operation) + if operation["logs"]: + operation["logs"] = json.loads(operation["logs"]) + else: + operation["logs"] = [] - conn.close() - return operations + operations.append(operation) + + return operations except Exception as e: logger.error(f"Failed to get sync operations: {e}") diff --git a/tests/unit/test_configurable_paths.py b/tests/unit/test_configurable_paths.py index 556cc72..aef5eb7 100644 --- a/tests/unit/test_configurable_paths.py +++ b/tests/unit/test_configurable_paths.py @@ -106,6 +106,11 @@ class TestConfigurablePaths: # Set custom database path path_manager.set_database_path(test_db_path) + # Initialize database tables for the custom path + from leggen.services.database import init_database + + init_database() + # Test database operations using DatabaseService database_service = DatabaseService() balance_data = { diff --git a/tests/unit/test_database_service.py b/tests/unit/test_database_service.py index 221d36f..eceed9c 100644 --- a/tests/unit/test_database_service.py +++ b/tests/unit/test_database_service.py @@ -1,16 +1,54 @@ """Tests for database service.""" +import tempfile from datetime import datetime +from pathlib import Path from unittest.mock import patch import pytest +from leggen.services.database import init_database from leggen.services.database_service import DatabaseService +from leggen.utils.paths import path_manager @pytest.fixture -def database_service(): - """Create a database service instance for testing.""" +def test_db_path(): + """Create a temporary test database.""" + import os + + # Create a writable temporary file + fd, temp_path = tempfile.mkstemp(suffix=".db") + os.close(fd) # Close the file descriptor + db_path = Path(temp_path) + + # Set the test database path + original_path = path_manager._database_path + path_manager._database_path = db_path + + # Reset the engine to use the new database path + import leggen.services.database as db_module + + original_engine = db_module._engine + db_module._engine = None + + # Initialize database tables + init_database() + + yield db_path + + # Cleanup - close any sessions first + if db_module._engine: + db_module._engine.dispose() + db_module._engine = original_engine + path_manager._database_path = original_path + if db_path.exists(): + db_path.unlink() + + +@pytest.fixture +def database_service(test_db_path): + """Create a database service instance for testing with real database.""" return DatabaseService() @@ -282,6 +320,7 @@ class TestDatabaseService: """Test successful balance persistence.""" balance_data = { "institution_id": "REVOLUT_REVOLT21", + "account_status": "active", "iban": "LT313250081177977789", "balances": [ { @@ -291,26 +330,23 @@ class TestDatabaseService: ], } - with patch("sqlite3.connect") as mock_connect: - mock_conn = mock_connect.return_value - mock_cursor = mock_conn.cursor.return_value + # Test actual persistence + await database_service._persist_balance_sqlite("test-account-123", balance_data) - await database_service._persist_balance_sqlite( - "test-account-123", balance_data - ) - - # Verify database operations - mock_connect.assert_called() - mock_cursor.execute.assert_called() # Table creation and insert - mock_conn.commit.assert_called_once() - mock_conn.close.assert_called_once() + # Verify balance was persisted + balances = await database_service.get_balances_from_db("test-account-123") + assert len(balances) == 1 + assert balances[0]["account_id"] == "test-account-123" + assert balances[0]["amount"] == 1000.0 + assert balances[0]["currency"] == "EUR" async def test_persist_balance_sqlite_error(self, database_service): """Test handling error during balance persistence.""" balance_data = {"balances": []} - with patch("sqlite3.connect") as mock_connect: - mock_connect.side_effect = Exception("Database error") + # Mock get_session to raise an error + with patch("leggen.services.database_service.get_session") as mock_session: + mock_session.side_effect = Exception("Database error") with pytest.raises(Exception, match="Database error"): await database_service._persist_balance_sqlite( @@ -321,52 +357,48 @@ class TestDatabaseService: self, database_service, sample_transactions_db_format ): """Test successful transaction persistence.""" - with patch("sqlite3.connect") as mock_connect: - mock_conn = mock_connect.return_value - mock_cursor = mock_conn.cursor.return_value - # Mock fetchone to return (0,) indicating transaction doesn't exist yet - mock_cursor.fetchone.return_value = (0,) + result = await database_service._persist_transactions_sqlite( + "test-account-123", sample_transactions_db_format + ) - result = await database_service._persist_transactions_sqlite( - "test-account-123", sample_transactions_db_format - ) + # Should return all transactions as new + assert len(result) == 2 - # Should return the transactions (assuming no duplicates) - assert len(result) >= 0 # Could be empty if all are duplicates - - # Verify database operations - mock_connect.assert_called() - mock_cursor.execute.assert_called() - mock_conn.commit.assert_called_once() - mock_conn.close.assert_called_once() + # Verify transactions were persisted + transactions = await database_service.get_transactions_from_db( + account_id="test-account-123" + ) + assert len(transactions) == 2 + assert transactions[0]["accountId"] == "test-account-123" async def test_persist_transactions_sqlite_duplicate_detection( self, database_service, sample_transactions_db_format ): """Test that existing transactions are not returned as new.""" - with patch("sqlite3.connect") as mock_connect: - mock_conn = mock_connect.return_value - mock_cursor = mock_conn.cursor.return_value - # Mock fetchone to return (1,) indicating transaction already exists - mock_cursor.fetchone.return_value = (1,) + # First insert + result1 = await database_service._persist_transactions_sqlite( + "test-account-123", sample_transactions_db_format + ) + assert len(result1) == 2 - result = await database_service._persist_transactions_sqlite( - "test-account-123", sample_transactions_db_format - ) + # Second insert (duplicates) + result2 = await database_service._persist_transactions_sqlite( + "test-account-123", sample_transactions_db_format + ) - # Should return empty list since all transactions already exist - assert len(result) == 0 + # Should return empty list since all transactions already exist + assert len(result2) == 0 - # Verify database operations still happened (INSERT OR REPLACE executed) - mock_connect.assert_called() - mock_cursor.execute.assert_called() - mock_conn.commit.assert_called_once() - mock_conn.close.assert_called_once() + # Verify still only 2 transactions in database + transactions = await database_service.get_transactions_from_db( + account_id="test-account-123" + ) + assert len(transactions) == 2 async def test_persist_transactions_sqlite_error(self, database_service): """Test handling error during transaction persistence.""" - with patch("sqlite3.connect") as mock_connect: - mock_connect.side_effect = Exception("Database error") + with patch("leggen.services.database_service.get_session") as mock_session: + mock_session.side_effect = Exception("Database error") with pytest.raises(Exception, match="Database error"): await database_service._persist_transactions_sqlite(