diff --git a/leggen/database/sqlite.py b/leggen/database/sqlite.py deleted file mode 100644 index 9ada689..0000000 --- a/leggen/database/sqlite.py +++ /dev/null @@ -1,658 +0,0 @@ -import json -import sqlite3 -from sqlite3 import IntegrityError - -import click - -from leggen.utils.text import success, warning -from leggen.utils.paths import path_manager - - -def persist_balances(ctx: click.Context, balance: dict): - # Connect to SQLite database - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # Create the accounts table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS accounts ( - id TEXT PRIMARY KEY, - institution_id TEXT, - status TEXT, - iban TEXT, - name TEXT, - currency TEXT, - created DATETIME, - last_accessed DATETIME, - last_updated DATETIME - )""" - ) - - # Create indexes for accounts table - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_institution_id - ON accounts(institution_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_status - ON accounts(status)""" - ) - - # Create the balances table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS balances ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - account_id TEXT, - bank TEXT, - status TEXT, - iban TEXT, - amount REAL, - currency TEXT, - type TEXT, - timestamp DATETIME - )""" - ) - - # Create indexes for better performance - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_account_id - ON balances(account_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_timestamp - ON balances(timestamp)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp - ON balances(account_id, type, timestamp)""" - ) - - # Insert balance into SQLite database - try: - cursor.execute( - """INSERT INTO balances ( - account_id, - bank, - status, - iban, - amount, - currency, - type, - timestamp - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - ( - balance["account_id"], - balance["bank"], - balance["status"], - balance["iban"], - balance["amount"], - balance["currency"], - balance["type"], - balance["timestamp"], - ), - ) - except IntegrityError: - warning(f"[{balance['account_id']}] Skipped duplicate balance") - - # Commit changes and close the connection - conn.commit() - conn.close() - - success(f"[{balance['account_id']}] Inserted balance of type {balance['type']}") - - return balance - - -def persist_transactions(ctx: click.Context, account: str, transactions: list) -> list: - # Connect to SQLite database - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # Create the transactions table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS transactions ( - accountId TEXT NOT NULL, - transactionId TEXT NOT NULL, - internalTransactionId TEXT, - institutionId TEXT, - iban TEXT, - transactionDate DATETIME, - description TEXT, - transactionValue REAL, - transactionCurrency TEXT, - transactionStatus TEXT, - rawTransaction JSON, - PRIMARY KEY (accountId, transactionId) - )""" - ) - - # Create indexes for better performance - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_internal_id - ON transactions(internalTransactionId)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_date - ON transactions(transactionDate)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_account_date - ON transactions(accountId, transactionDate)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_transactions_amount - ON transactions(transactionValue)""" - ) - - # Insert transactions into SQLite database - duplicates_count = 0 - - # Prepare an SQL statement for inserting data - insert_sql = """INSERT OR REPLACE INTO transactions ( - accountId, - transactionId, - internalTransactionId, - institutionId, - iban, - transactionDate, - description, - transactionValue, - transactionCurrency, - transactionStatus, - rawTransaction - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""" - - new_transactions = [] - - for transaction in transactions: - try: - cursor.execute( - insert_sql, - ( - transaction["accountId"], - transaction["transactionId"], - transaction.get("internalTransactionId"), - transaction["institutionId"], - transaction["iban"], - transaction["transactionDate"], - transaction["description"], - transaction["transactionValue"], - transaction["transactionCurrency"], - transaction["transactionStatus"], - json.dumps(transaction["rawTransaction"]), - ), - ) - new_transactions.append(transaction) - except IntegrityError: - # A transaction with the same ID already exists, indicating a duplicate - duplicates_count += 1 - - # Commit changes and close the connection - conn.commit() - conn.close() - - success(f"[{account}] Inserted {len(new_transactions)} new transactions") - if duplicates_count: - warning(f"[{account}] Skipped {duplicates_count} duplicate transactions") - - return new_transactions - - -def get_transactions( - account_id=None, - limit=100, - offset=0, - date_from=None, - date_to=None, - min_amount=None, - max_amount=None, - search=None, -): - """Get transactions from SQLite database with optional filtering""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return [] - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row # Enable dict-like access - cursor = conn.cursor() - - # Build query with filters - query = "SELECT * FROM transactions WHERE 1=1" - params = [] - - if account_id: - query += " AND accountId = ?" - params.append(account_id) - - if date_from: - query += " AND transactionDate >= ?" - params.append(date_from) - - if date_to: - query += " AND transactionDate <= ?" - params.append(date_to) - - if min_amount is not None: - query += " AND transactionValue >= ?" - params.append(min_amount) - - if max_amount is not None: - query += " AND transactionValue <= ?" - params.append(max_amount) - - if search: - query += " AND description LIKE ?" - params.append(f"%{search}%") - - # Add ordering and pagination - query += " ORDER BY transactionDate DESC" - - if limit: - query += " LIMIT ?" - params.append(limit) - - if offset: - query += " OFFSET ?" - params.append(offset) - - try: - cursor.execute(query, params) - rows = cursor.fetchall() - - # Convert to list of dicts and parse JSON fields - transactions = [] - for row in rows: - transaction = dict(row) - if transaction["rawTransaction"]: - transaction["rawTransaction"] = json.loads( - transaction["rawTransaction"] - ) - transactions.append(transaction) - - conn.close() - return transactions - - except Exception as e: - conn.close() - raise e - - -def get_balances(account_id=None): - """Get latest balances from SQLite database""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return [] - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - # Get latest balance for each account_id and type combination - query = """ - SELECT * FROM balances b1 - WHERE b1.timestamp = ( - SELECT MAX(b2.timestamp) - FROM balances b2 - WHERE b2.account_id = b1.account_id AND b2.type = b1.type - ) - """ - params = [] - - if account_id: - query += " AND b1.account_id = ?" - params.append(account_id) - - query += " ORDER BY b1.account_id, b1.type" - - try: - cursor.execute(query, params) - rows = cursor.fetchall() - - balances = [dict(row) for row in rows] - conn.close() - return balances - - except Exception as e: - conn.close() - raise e - - -def get_account_summary(account_id): - """Get basic account info from transactions table (avoids GoCardless API call)""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return None - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - try: - # Get account info from most recent transaction - cursor.execute( - """ - SELECT DISTINCT accountId, institutionId, iban - FROM transactions - WHERE accountId = ? - ORDER BY transactionDate DESC - LIMIT 1 - """, - (account_id,), - ) - - row = cursor.fetchone() - conn.close() - - if row: - return dict(row) - return None - - except Exception as e: - conn.close() - raise e - - -def get_transaction_count(account_id=None, **filters): - """Get total count of transactions matching filters""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return 0 - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - query = "SELECT COUNT(*) FROM transactions WHERE 1=1" - params = [] - - if account_id: - query += " AND accountId = ?" - params.append(account_id) - - # Add same filters as get_transactions - if filters.get("date_from"): - query += " AND transactionDate >= ?" - params.append(filters["date_from"]) - - if filters.get("date_to"): - query += " AND transactionDate <= ?" - params.append(filters["date_to"]) - - if filters.get("min_amount") is not None: - query += " AND transactionValue >= ?" - params.append(filters["min_amount"]) - - if filters.get("max_amount") is not None: - query += " AND transactionValue <= ?" - params.append(filters["max_amount"]) - - if filters.get("search"): - query += " AND description LIKE ?" - params.append(f"%{filters['search']}%") - - try: - cursor.execute(query, params) - count = cursor.fetchone()[0] - conn.close() - return count - - except Exception as e: - conn.close() - raise e - - -def persist_account(account_data: dict): - """Persist account details to SQLite database""" - db_path = path_manager.get_database_path() - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # Create the accounts table if it doesn't exist - cursor.execute( - """CREATE TABLE IF NOT EXISTS accounts ( - id TEXT PRIMARY KEY, - institution_id TEXT, - status TEXT, - iban TEXT, - name TEXT, - currency TEXT, - created DATETIME, - last_accessed DATETIME, - last_updated DATETIME - )""" - ) - - # Create indexes for accounts table - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_institution_id - ON accounts(institution_id)""" - ) - cursor.execute( - """CREATE INDEX IF NOT EXISTS idx_accounts_status - ON accounts(status)""" - ) - - try: - # Insert or replace account data - cursor.execute( - """INSERT OR REPLACE INTO accounts ( - id, - institution_id, - status, - iban, - name, - currency, - created, - last_accessed, - last_updated - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - account_data["id"], - account_data["institution_id"], - account_data["status"], - account_data.get("iban"), - account_data.get("name"), - account_data.get("currency"), - account_data["created"], - account_data.get("last_accessed"), - account_data.get("last_updated", account_data["created"]), - ), - ) - conn.commit() - conn.close() - - success(f"[{account_data['id']}] Account details persisted to database") - return account_data - - except Exception as e: - conn.close() - raise e - - -def get_accounts(account_ids=None): - """Get account details from SQLite database""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return [] - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - query = "SELECT * FROM accounts" - params = [] - - if account_ids: - placeholders = ",".join("?" * len(account_ids)) - query += f" WHERE id IN ({placeholders})" - params.extend(account_ids) - - query += " ORDER BY created DESC" - - try: - cursor.execute(query, params) - rows = cursor.fetchall() - - accounts = [dict(row) for row in rows] - conn.close() - return accounts - - except Exception as e: - conn.close() - raise e - - -def get_account(account_id: str): - """Get specific account details from SQLite database""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return None - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - try: - cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,)) - row = cursor.fetchone() - conn.close() - - if row: - return dict(row) - return None - - except Exception as e: - conn.close() - raise e - - -def get_historical_balances(account_id=None, days=365): - """Get historical balance progression based on transaction history""" - db_path = path_manager.get_database_path() - if not db_path.exists(): - return [] - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - try: - # Get current balance for each account/type to use as the final balance - current_balances_query = """ - SELECT account_id, type, amount, currency - FROM balances b1 - WHERE b1.timestamp = ( - SELECT MAX(b2.timestamp) - FROM balances b2 - WHERE b2.account_id = b1.account_id AND b2.type = b1.type - ) - """ - params = [] - - if account_id: - current_balances_query += " AND b1.account_id = ?" - params.append(account_id) - - cursor.execute(current_balances_query, params) - current_balances = { - (row["account_id"], row["type"]): { - "amount": row["amount"], - "currency": row["currency"], - } - for row in cursor.fetchall() - } - - # Get transactions for the specified period, ordered by date descending - from datetime import datetime, timedelta - - cutoff_date = (datetime.now() - timedelta(days=days)).isoformat() - - transactions_query = """ - SELECT accountId, transactionDate, transactionValue - FROM transactions - WHERE transactionDate >= ? - """ - - if account_id: - transactions_query += " AND accountId = ?" - params = [cutoff_date, account_id] - else: - params = [cutoff_date] - - transactions_query += " ORDER BY transactionDate DESC" - - cursor.execute(transactions_query, params) - transactions = cursor.fetchall() - - # Calculate historical balances by working backwards from current balance - historical_balances = [] - account_running_balances: dict[str, dict[str, float]] = {} - - # Initialize running balances with current balances - for (acc_id, balance_type), balance_info in current_balances.items(): - if acc_id not in account_running_balances: - account_running_balances[acc_id] = {} - account_running_balances[acc_id][balance_type] = balance_info["amount"] - - # Group transactions by date - from collections import defaultdict - - transactions_by_date = defaultdict(list) - - for txn in transactions: - date_str = txn["transactionDate"][:10] # Extract just the date part - transactions_by_date[date_str].append(txn) - - # Generate historical balance points - # Start from today and work backwards - current_date = datetime.now().date() - - for day_offset in range(0, days, 7): # Sample every 7 days for performance - target_date = current_date - timedelta(days=day_offset) - target_date_str = target_date.isoformat() - - # For each account, create balance entries - for acc_id in account_running_balances: - for balance_type in [ - "closingBooked" - ]: # Focus on closingBooked for the chart - if balance_type in account_running_balances[acc_id]: - balance_amount = account_running_balances[acc_id][balance_type] - currency = current_balances.get((acc_id, balance_type), {}).get( - "currency", "EUR" - ) - - historical_balances.append( - { - "id": f"{acc_id}_{balance_type}_{target_date_str}", - "account_id": acc_id, - "balance_amount": balance_amount, - "balance_type": balance_type, - "currency": currency, - "reference_date": target_date_str, - "created_at": None, - "updated_at": None, - } - ) - - # Subtract transactions that occurred on this date and later dates - # to simulate going back in time - for date_str in list(transactions_by_date.keys()): - if date_str >= target_date_str: - for txn in transactions_by_date[date_str]: - acc_id = txn["accountId"] - amount = txn["transactionValue"] - - if acc_id in account_running_balances: - for balance_type in account_running_balances[acc_id]: - account_running_balances[acc_id][balance_type] -= amount - - # Remove processed transactions to avoid double-processing - del transactions_by_date[date_str] - - conn.close() - - # Sort by date for proper chronological order - historical_balances.sort(key=lambda x: x["reference_date"]) - - return historical_balances - - except Exception as e: - conn.close() - raise e diff --git a/leggen/services/database_service.py b/leggen/services/database_service.py index 035defc..cf990df 100644 --- a/leggen/services/database_service.py +++ b/leggen/services/database_service.py @@ -1,18 +1,21 @@ -from datetime import datetime +from datetime import datetime, timedelta from typing import List, Dict, Any, Optional import sqlite3 +import json +from collections import defaultdict from loguru import logger from leggen.utils.config import config -import leggen.database.sqlite as sqlite_db from leggen.utils.paths import path_manager +from leggen.services.transaction_processor import TransactionProcessor class DatabaseService: def __init__(self): self.db_config = config.database_config self.sqlite_enabled = self.db_config.get("sqlite", True) + self.transaction_processor = TransactionProcessor() async def persist_balance( self, account_id: str, balance_data: Dict[str, Any] @@ -41,79 +44,9 @@ class DatabaseService: transaction_data: Dict[str, Any], ) -> List[Dict[str, Any]]: """Process raw transaction data into standardized format""" - transactions = [] - - # Process booked transactions - for transaction in transaction_data.get("transactions", {}).get("booked", []): - processed = self._process_single_transaction( - account_id, account_info, transaction, "booked" - ) - transactions.append(processed) - - # Process pending transactions - for transaction in transaction_data.get("transactions", {}).get("pending", []): - processed = self._process_single_transaction( - account_id, account_info, transaction, "pending" - ) - transactions.append(processed) - - return transactions - - def _process_single_transaction( - self, - account_id: str, - account_info: Dict[str, Any], - transaction: Dict[str, Any], - status: str, - ) -> Dict[str, Any]: - """Process a single transaction into standardized format""" - # Extract dates - booked_date = transaction.get("bookingDateTime") or transaction.get( - "bookingDate" + return self.transaction_processor.process_transactions( + account_id, account_info, transaction_data ) - value_date = transaction.get("valueDateTime") or transaction.get("valueDate") - - if booked_date and value_date: - min_date = min( - datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date) - ) - else: - date_str = booked_date or value_date - if not date_str: - raise ValueError("No valid date found in transaction") - min_date = datetime.fromisoformat(date_str) - - # Extract amount and currency - transaction_amount = transaction.get("transactionAmount", {}) - amount = float(transaction_amount.get("amount", 0)) - currency = transaction_amount.get("currency", "") - - # Extract description - description = transaction.get( - "remittanceInformationUnstructured", - ",".join(transaction.get("remittanceInformationUnstructuredArray", [])), - ) - - # Extract transaction IDs - transactionId is now primary, internalTransactionId is reference - transaction_id = transaction.get("transactionId") - internal_transaction_id = transaction.get("internalTransactionId") - - if not transaction_id: - raise ValueError("Transaction missing required transactionId field") - - return { - "accountId": account_id, - "transactionId": transaction_id, - "internalTransactionId": internal_transaction_id, - "institutionId": account_info["institution_id"], - "iban": account_info.get("iban", "N/A"), - "transactionDate": min_date, - "description": description, - "transactionValue": amount, - "transactionCurrency": currency, - "transactionStatus": status, - "rawTransaction": transaction, - } async def get_transactions_from_db( self, @@ -132,7 +65,7 @@ class DatabaseService: return [] try: - transactions = sqlite_db.get_transactions( + transactions = self._get_transactions( account_id=account_id, limit=limit, # Pass limit as-is, None means no limit offset=offset or 0, @@ -172,7 +105,7 @@ class DatabaseService: # Remove None values filters = {k: v for k, v in filters.items() if v is not None} - count = sqlite_db.get_transaction_count(account_id=account_id, **filters) + count = self._get_transaction_count(account_id=account_id, **filters) logger.debug(f"Total transaction count: {count}") return count except Exception as e: @@ -188,7 +121,7 @@ class DatabaseService: return [] try: - balances = sqlite_db.get_balances(account_id=account_id) + balances = self._get_balances(account_id=account_id) logger.debug(f"Retrieved {len(balances)} balances from database") return balances except Exception as e: @@ -204,9 +137,7 @@ class DatabaseService: return [] try: - balances = sqlite_db.get_historical_balances( - account_id=account_id, days=days - ) + balances = self._get_historical_balances(account_id=account_id, days=days) logger.debug( f"Retrieved {len(balances)} historical balance points from database" ) @@ -223,7 +154,7 @@ class DatabaseService: return None try: - summary = sqlite_db.get_account_summary(account_id) + summary = self._get_account_summary(account_id) if summary: logger.debug( f"Retrieved account summary from database for {account_id}" @@ -250,7 +181,7 @@ class DatabaseService: return [] try: - accounts = sqlite_db.get_accounts(account_ids=account_ids) + accounts = self._get_accounts(account_ids=account_ids) logger.debug(f"Retrieved {len(accounts)} accounts from database") return accounts except Exception as e: @@ -266,7 +197,7 @@ class DatabaseService: return None try: - account = sqlite_db.get_account(account_id) + account = self._get_account(account_id) if account: logger.debug( f"Retrieved account details from database for {account_id}" @@ -893,7 +824,7 @@ class DatabaseService: """Persist account details to SQLite""" try: # Use the sqlite_db module function - sqlite_db.persist_account(account_data) + self._persist_account(account_data) logger.info( f"Persisted account details to SQLite for account {account_data['id']}" @@ -901,3 +832,453 @@ class DatabaseService: except Exception as e: logger.error(f"Failed to persist account details to SQLite: {e}") raise + + def _get_transactions( + self, + account_id=None, + limit=100, + offset=0, + date_from=None, + date_to=None, + min_amount=None, + max_amount=None, + search=None, + ): + """Get transactions from SQLite database with optional filtering""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return [] + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row # Enable dict-like access + cursor = conn.cursor() + + # Build query with filters + query = "SELECT * FROM transactions WHERE 1=1" + params = [] + + if account_id: + query += " AND accountId = ?" + params.append(account_id) + + if date_from: + query += " AND transactionDate >= ?" + params.append(date_from) + + if date_to: + query += " AND transactionDate <= ?" + params.append(date_to) + + if min_amount is not None: + query += " AND transactionValue >= ?" + params.append(min_amount) + + if max_amount is not None: + query += " AND transactionValue <= ?" + params.append(max_amount) + + if search: + query += " AND description LIKE ?" + params.append(f"%{search}%") + + # Add ordering and pagination + query += " ORDER BY transactionDate DESC" + + if limit: + query += " LIMIT ?" + params.append(limit) + + if offset: + query += " OFFSET ?" + params.append(offset) + + try: + cursor.execute(query, params) + rows = cursor.fetchall() + + # Convert to list of dicts and parse JSON fields + transactions = [] + for row in rows: + transaction = dict(row) + if transaction["rawTransaction"]: + transaction["rawTransaction"] = json.loads( + transaction["rawTransaction"] + ) + transactions.append(transaction) + + conn.close() + return transactions + + except Exception as e: + conn.close() + raise e + + def _get_balances(self, account_id=None): + """Get latest balances from SQLite database""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return [] + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + # Get latest balance for each account_id and type combination + query = """ + SELECT * FROM balances b1 + WHERE b1.timestamp = ( + SELECT MAX(b2.timestamp) + FROM balances b2 + WHERE b2.account_id = b1.account_id AND b2.type = b1.type + ) + """ + params = [] + + if account_id: + query += " AND b1.account_id = ?" + params.append(account_id) + + query += " ORDER BY b1.account_id, b1.type" + + try: + cursor.execute(query, params) + rows = cursor.fetchall() + + balances = [dict(row) for row in rows] + conn.close() + return balances + + except Exception as e: + conn.close() + raise e + + def _get_account_summary(self, account_id): + """Get basic account info from transactions table (avoids GoCardless API call)""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return None + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + try: + # Get account info from most recent transaction + cursor.execute( + """ + SELECT DISTINCT accountId, institutionId, iban + FROM transactions + WHERE accountId = ? + ORDER BY transactionDate DESC + LIMIT 1 + """, + (account_id,), + ) + + row = cursor.fetchone() + conn.close() + + if row: + return dict(row) + return None + + except Exception as e: + conn.close() + raise e + + def _get_transaction_count(self, account_id=None, **filters): + """Get total count of transactions matching filters""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return 0 + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + query = "SELECT COUNT(*) FROM transactions WHERE 1=1" + params = [] + + if account_id: + query += " AND accountId = ?" + params.append(account_id) + + # Add same filters as get_transactions + if filters.get("date_from"): + query += " AND transactionDate >= ?" + params.append(filters["date_from"]) + + if filters.get("date_to"): + query += " AND transactionDate <= ?" + params.append(filters["date_to"]) + + if filters.get("min_amount") is not None: + query += " AND transactionValue >= ?" + params.append(filters["min_amount"]) + + if filters.get("max_amount") is not None: + query += " AND transactionValue <= ?" + params.append(filters["max_amount"]) + + if filters.get("search"): + query += " AND description LIKE ?" + params.append(f"%{filters['search']}%") + + try: + cursor.execute(query, params) + count = cursor.fetchone()[0] + conn.close() + return count + + except Exception as e: + conn.close() + raise e + + def _persist_account(self, account_data: dict): + """Persist account details to SQLite database""" + db_path = path_manager.get_database_path() + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + # Create the accounts table if it doesn't exist + cursor.execute( + """CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + institution_id TEXT, + status TEXT, + iban TEXT, + name TEXT, + currency TEXT, + created DATETIME, + last_accessed DATETIME, + last_updated DATETIME + )""" + ) + + # Create indexes for accounts table + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_accounts_institution_id + ON accounts(institution_id)""" + ) + cursor.execute( + """CREATE INDEX IF NOT EXISTS idx_accounts_status + ON accounts(status)""" + ) + + try: + # Insert or replace account data + cursor.execute( + """INSERT OR REPLACE INTO accounts ( + id, + institution_id, + status, + iban, + name, + currency, + created, + last_accessed, + last_updated + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + account_data["id"], + account_data["institution_id"], + account_data["status"], + account_data.get("iban"), + account_data.get("name"), + account_data.get("currency"), + account_data["created"], + account_data.get("last_accessed"), + account_data.get("last_updated", account_data["created"]), + ), + ) + conn.commit() + conn.close() + + return account_data + + except Exception as e: + conn.close() + raise e + + def _get_accounts(self, account_ids=None): + """Get account details from SQLite database""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return [] + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + query = "SELECT * FROM accounts" + params = [] + + if account_ids: + placeholders = ",".join("?" * len(account_ids)) + query += f" WHERE id IN ({placeholders})" + params.extend(account_ids) + + query += " ORDER BY created DESC" + + try: + cursor.execute(query, params) + rows = cursor.fetchall() + + accounts = [dict(row) for row in rows] + conn.close() + return accounts + + except Exception as e: + conn.close() + raise e + + def _get_account(self, account_id: str): + """Get specific account details from SQLite database""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return None + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + try: + cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,)) + row = cursor.fetchone() + conn.close() + + if row: + return dict(row) + return None + + except Exception as e: + conn.close() + raise e + + def _get_historical_balances(self, account_id=None, days=365): + """Get historical balance progression based on transaction history""" + db_path = path_manager.get_database_path() + if not db_path.exists(): + return [] + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + try: + # Get current balance for each account/type to use as the final balance + current_balances_query = """ + SELECT account_id, type, amount, currency + FROM balances b1 + WHERE b1.timestamp = ( + SELECT MAX(b2.timestamp) + FROM balances b2 + WHERE b2.account_id = b1.account_id AND b2.type = b1.type + ) + """ + params = [] + + if account_id: + current_balances_query += " AND b1.account_id = ?" + params.append(account_id) + + cursor.execute(current_balances_query, params) + current_balances = { + (row["account_id"], row["type"]): { + "amount": row["amount"], + "currency": row["currency"], + } + for row in cursor.fetchall() + } + + # Get transactions for the specified period, ordered by date descending + cutoff_date = (datetime.now() - timedelta(days=days)).isoformat() + + transactions_query = """ + SELECT accountId, transactionDate, transactionValue + FROM transactions + WHERE transactionDate >= ? + """ + + if account_id: + transactions_query += " AND accountId = ?" + params = [cutoff_date, account_id] + else: + params = [cutoff_date] + + transactions_query += " ORDER BY transactionDate DESC" + + cursor.execute(transactions_query, params) + transactions = cursor.fetchall() + + # Calculate historical balances by working backwards from current balance + historical_balances = [] + account_running_balances: dict[str, dict[str, float]] = {} + + # Initialize running balances with current balances + for (acc_id, balance_type), balance_info in current_balances.items(): + if acc_id not in account_running_balances: + account_running_balances[acc_id] = {} + account_running_balances[acc_id][balance_type] = balance_info["amount"] + + # Group transactions by date + transactions_by_date = defaultdict(list) + + for txn in transactions: + date_str = txn["transactionDate"][:10] # Extract just the date part + transactions_by_date[date_str].append(txn) + + # Generate historical balance points + # Start from today and work backwards + current_date = datetime.now().date() + + for day_offset in range(0, days, 7): # Sample every 7 days for performance + target_date = current_date - timedelta(days=day_offset) + target_date_str = target_date.isoformat() + + # For each account, create balance entries + for acc_id in account_running_balances: + for balance_type in [ + "closingBooked" + ]: # Focus on closingBooked for the chart + if balance_type in account_running_balances[acc_id]: + balance_amount = account_running_balances[acc_id][ + balance_type + ] + currency = current_balances.get( + (acc_id, balance_type), {} + ).get("currency", "EUR") + + historical_balances.append( + { + "id": f"{acc_id}_{balance_type}_{target_date_str}", + "account_id": acc_id, + "balance_amount": balance_amount, + "balance_type": balance_type, + "currency": currency, + "reference_date": target_date_str, + "created_at": None, + "updated_at": None, + } + ) + + # Subtract transactions that occurred on this date and later dates + # to simulate going back in time + for date_str in list(transactions_by_date.keys()): + if date_str >= target_date_str: + for txn in transactions_by_date[date_str]: + acc_id = txn["accountId"] + amount = txn["transactionValue"] + + if acc_id in account_running_balances: + for balance_type in account_running_balances[acc_id]: + account_running_balances[acc_id][balance_type] -= ( + amount + ) + + # Remove processed transactions to avoid double-processing + del transactions_by_date[date_str] + + conn.close() + + # Sort by date for proper chronological order + historical_balances.sort(key=lambda x: x["reference_date"]) + + return historical_balances + + except Exception as e: + conn.close() + raise e diff --git a/leggen/services/transaction_processor.py b/leggen/services/transaction_processor.py new file mode 100644 index 0000000..fadda42 --- /dev/null +++ b/leggen/services/transaction_processor.py @@ -0,0 +1,87 @@ +from datetime import datetime +from typing import List, Dict, Any + + +class TransactionProcessor: + """Handles processing and transformation of raw transaction data""" + + def process_transactions( + self, + account_id: str, + account_info: Dict[str, Any], + transaction_data: Dict[str, Any], + ) -> List[Dict[str, Any]]: + """Process raw transaction data into standardized format""" + transactions = [] + + # Process booked transactions + for transaction in transaction_data.get("transactions", {}).get("booked", []): + processed = self._process_single_transaction( + account_id, account_info, transaction, "booked" + ) + transactions.append(processed) + + # Process pending transactions + for transaction in transaction_data.get("transactions", {}).get("pending", []): + processed = self._process_single_transaction( + account_id, account_info, transaction, "pending" + ) + transactions.append(processed) + + return transactions + + def _process_single_transaction( + self, + account_id: str, + account_info: Dict[str, Any], + transaction: Dict[str, Any], + status: str, + ) -> Dict[str, Any]: + """Process a single transaction into standardized format""" + # Extract dates + booked_date = transaction.get("bookingDateTime") or transaction.get( + "bookingDate" + ) + value_date = transaction.get("valueDateTime") or transaction.get("valueDate") + + if booked_date and value_date: + min_date = min( + datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date) + ) + else: + date_str = booked_date or value_date + if not date_str: + raise ValueError("No valid date found in transaction") + min_date = datetime.fromisoformat(date_str) + + # Extract amount and currency + transaction_amount = transaction.get("transactionAmount", {}) + amount = float(transaction_amount.get("amount", 0)) + currency = transaction_amount.get("currency", "") + + # Extract description + description = transaction.get( + "remittanceInformationUnstructured", + ",".join(transaction.get("remittanceInformationUnstructuredArray", [])), + ) + + # Extract transaction IDs - transactionId is now primary, internalTransactionId is reference + transaction_id = transaction.get("transactionId") + internal_transaction_id = transaction.get("internalTransactionId") + + if not transaction_id: + raise ValueError("Transaction missing required transactionId field") + + return { + "accountId": account_id, + "transactionId": transaction_id, + "internalTransactionId": internal_transaction_id, + "institutionId": account_info["institution_id"], + "iban": account_info.get("iban", "N/A"), + "transactionDate": min_date, + "description": description, + "transactionValue": amount, + "transactionCurrency": currency, + "transactionStatus": status, + "rawTransaction": transaction, + } diff --git a/leggen/utils/database.py b/leggen/utils/database.py deleted file mode 100644 index a0a68b7..0000000 --- a/leggen/utils/database.py +++ /dev/null @@ -1,132 +0,0 @@ -from datetime import datetime - -import click - -import leggen.database.sqlite as sqlite_engine -from leggen.utils.text import info, warning - - -def persist_balance(ctx: click.Context, account: str, balance: dict) -> None: - sqlite = ctx.obj.get("database", {}).get("sqlite", True) - - if not sqlite: - warning("SQLite database is disabled, skipping balance saving") - return - - info(f"[{account}] Fetched balances, saving to SQLite") - sqlite_engine.persist_balances(ctx, balance) - - -def persist_transactions(ctx: click.Context, account: str, transactions: list) -> list: - sqlite = ctx.obj.get("database", {}).get("sqlite", True) - - if not sqlite: - warning("SQLite database is disabled, skipping transaction saving") - # WARNING: This will return the transactions list as is, without saving it to any database - # Possible duplicate notifications will be sent if the filters are enabled - return transactions - - info(f"[{account}] Fetched {len(transactions)} transactions, saving to SQLite") - return sqlite_engine.persist_transactions(ctx, account, transactions) - - -def save_transactions(ctx: click.Context, account: str) -> list: - import requests - - api_url = ctx.obj.get("api_url", "http://localhost:8000") - - info(f"[{account}] Getting account details") - res = requests.get(f"{api_url}/accounts/{account}") - res.raise_for_status() - account_info = res.json() - - info(f"[{account}] Getting transactions") - transactions = [] - - res = requests.get(f"{api_url}/accounts/{account}/transactions/") - res.raise_for_status() - account_transactions = res.json().get("transactions", []) - - for transaction in account_transactions.get("booked", []): - booked_date = transaction.get("bookingDateTime") or transaction.get( - "bookingDate" - ) - value_date = transaction.get("valueDateTime") or transaction.get("valueDate") - if booked_date and value_date: - min_date = min( - datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date) - ) - else: - min_date = datetime.fromisoformat(booked_date or value_date) - - transactionValue = float( - transaction.get("transactionAmount", {}).get("amount", 0) - ) - currency = transaction.get("transactionAmount", {}).get("currency", "") - - description = transaction.get( - "remittanceInformationUnstructured", - ",".join(transaction.get("remittanceInformationUnstructuredArray", [])), - ) - - # Extract transaction ID, using transactionId as fallback when internalTransactionId is missing - transaction_id = transaction.get("internalTransactionId") or transaction.get( - "transactionId" - ) - - t = { - "internalTransactionId": transaction_id, - "institutionId": account_info["institution_id"], - "iban": account_info.get("iban", "N/A"), - "transactionDate": min_date, - "description": description, - "transactionValue": transactionValue, - "transactionCurrency": currency, - "transactionStatus": "booked", - "accountId": account, - "rawTransaction": transaction, - } - transactions.append(t) - - for transaction in account_transactions.get("pending", []): - booked_date = transaction.get("bookingDateTime") or transaction.get( - "bookingDate" - ) - value_date = transaction.get("valueDateTime") or transaction.get("valueDate") - if booked_date and value_date: - min_date = min( - datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date) - ) - else: - min_date = datetime.fromisoformat(booked_date or value_date) - - transactionValue = float( - transaction.get("transactionAmount", {}).get("amount", 0) - ) - currency = transaction.get("transactionAmount", {}).get("currency", "") - - description = transaction.get( - "remittanceInformationUnstructured", - ",".join(transaction.get("remittanceInformationUnstructuredArray", [])), - ) - - # Extract transaction ID, using transactionId as fallback when internalTransactionId is missing - transaction_id = transaction.get("internalTransactionId") or transaction.get( - "transactionId" - ) - - t = { - "internalTransactionId": transaction_id, - "institutionId": account_info["institution_id"], - "iban": account_info.get("iban", "N/A"), - "transactionDate": min_date, - "description": description, - "transactionValue": transactionValue, - "transactionCurrency": currency, - "transactionStatus": "pending", - "accountId": account, - "rawTransaction": transaction, - } - transactions.append(t) - - return persist_transactions(ctx, account, transactions) diff --git a/tests/unit/test_configurable_paths.py b/tests/unit/test_configurable_paths.py index 452aeb7..40915e2 100644 --- a/tests/unit/test_configurable_paths.py +++ b/tests/unit/test_configurable_paths.py @@ -7,11 +7,7 @@ from pathlib import Path from unittest.mock import patch from leggen.utils.paths import path_manager -from leggen.database.sqlite import persist_balances, get_balances - - -class MockContext: - """Mock context for testing.""" +from leggen.services.database_service import DatabaseService @pytest.mark.unit @@ -109,24 +105,31 @@ class TestConfigurablePaths: # Set custom database path path_manager.set_database_path(test_db_path) - # Test database operations - ctx = MockContext() - balance = { - "account_id": "test-account", - "bank": "TEST_BANK", - "status": "active", + # Test database operations using DatabaseService + database_service = DatabaseService() + balance_data = { + "balances": [ + { + "balanceAmount": {"amount": "1000.0", "currency": "EUR"}, + "balanceType": "available", + } + ], + "institution_id": "TEST_BANK", + "account_status": "active", "iban": "TEST_IBAN", - "amount": 1000.0, - "currency": "EUR", - "type": "available", - "timestamp": "2023-01-01T00:00:00", } - # Persist balance - persist_balances(ctx, balance) + # Use the internal balance persistence method since the test needs direct database access + import asyncio + + asyncio.run( + database_service._persist_balance_sqlite("test-account", balance_data) + ) # Retrieve balances - balances = get_balances() + balances = asyncio.run( + database_service.get_balances_from_db("test-account") + ) assert len(balances) == 1 assert balances[0]["account_id"] == "test-account" diff --git a/tests/unit/test_database_service.py b/tests/unit/test_database_service.py index 8061159..cc158e9 100644 --- a/tests/unit/test_database_service.py +++ b/tests/unit/test_database_service.py @@ -83,7 +83,9 @@ class TestDatabaseService: self, database_service, sample_transactions_db_format ): """Test successful retrieval of transactions from database.""" - with patch("leggen.database.sqlite.get_transactions") as mock_get_transactions: + with patch.object( + database_service, "_get_transactions" + ) as mock_get_transactions: mock_get_transactions.return_value = sample_transactions_db_format result = await database_service.get_transactions_from_db( @@ -107,7 +109,9 @@ class TestDatabaseService: self, database_service, sample_transactions_db_format ): """Test retrieving transactions with filters.""" - with patch("leggen.database.sqlite.get_transactions") as mock_get_transactions: + with patch.object( + database_service, "_get_transactions" + ) as mock_get_transactions: mock_get_transactions.return_value = sample_transactions_db_format result = await database_service.get_transactions_from_db( @@ -143,7 +147,9 @@ class TestDatabaseService: async def test_get_transactions_from_db_error(self, database_service): """Test handling error when getting transactions.""" - with patch("leggen.database.sqlite.get_transactions") as mock_get_transactions: + with patch.object( + database_service, "_get_transactions" + ) as mock_get_transactions: mock_get_transactions.side_effect = Exception("Database error") result = await database_service.get_transactions_from_db() @@ -152,7 +158,7 @@ class TestDatabaseService: async def test_get_transaction_count_from_db_success(self, database_service): """Test successful retrieval of transaction count.""" - with patch("leggen.database.sqlite.get_transaction_count") as mock_get_count: + with patch.object(database_service, "_get_transaction_count") as mock_get_count: mock_get_count.return_value = 42 result = await database_service.get_transaction_count_from_db( @@ -164,7 +170,7 @@ class TestDatabaseService: async def test_get_transaction_count_from_db_with_filters(self, database_service): """Test getting transaction count with filters.""" - with patch("leggen.database.sqlite.get_transaction_count") as mock_get_count: + with patch.object(database_service, "_get_transaction_count") as mock_get_count: mock_get_count.return_value = 15 result = await database_service.get_transaction_count_from_db( @@ -194,7 +200,7 @@ class TestDatabaseService: async def test_get_transaction_count_from_db_error(self, database_service): """Test handling error when getting count.""" - with patch("leggen.database.sqlite.get_transaction_count") as mock_get_count: + with patch.object(database_service, "_get_transaction_count") as mock_get_count: mock_get_count.side_effect = Exception("Database error") result = await database_service.get_transaction_count_from_db() @@ -205,7 +211,7 @@ class TestDatabaseService: self, database_service, sample_balances_db_format ): """Test successful retrieval of balances from database.""" - with patch("leggen.database.sqlite.get_balances") as mock_get_balances: + with patch.object(database_service, "_get_balances") as mock_get_balances: mock_get_balances.return_value = sample_balances_db_format result = await database_service.get_balances_from_db( @@ -227,7 +233,7 @@ class TestDatabaseService: async def test_get_balances_from_db_error(self, database_service): """Test handling error when getting balances.""" - with patch("leggen.database.sqlite.get_balances") as mock_get_balances: + with patch.object(database_service, "_get_balances") as mock_get_balances: mock_get_balances.side_effect = Exception("Database error") result = await database_service.get_balances_from_db() @@ -242,7 +248,7 @@ class TestDatabaseService: "iban": "LT313250081177977789", } - with patch("leggen.database.sqlite.get_account_summary") as mock_get_summary: + with patch.object(database_service, "_get_account_summary") as mock_get_summary: mock_get_summary.return_value = mock_summary result = await database_service.get_account_summary_from_db( @@ -262,7 +268,7 @@ class TestDatabaseService: async def test_get_account_summary_from_db_error(self, database_service): """Test handling error when getting summary.""" - with patch("leggen.database.sqlite.get_account_summary") as mock_get_summary: + with patch.object(database_service, "_get_account_summary") as mock_get_summary: mock_get_summary.side_effect = Exception("Database error") result = await database_service.get_account_summary_from_db( diff --git a/tests/unit/test_sqlite_database.py b/tests/unit/test_sqlite_database.py deleted file mode 100644 index dcf88d5..0000000 --- a/tests/unit/test_sqlite_database.py +++ /dev/null @@ -1,364 +0,0 @@ -"""Tests for SQLite database functions.""" - -import pytest -import tempfile -from pathlib import Path -from unittest.mock import patch -from datetime import datetime - -import leggen.database.sqlite as sqlite_db - - -@pytest.fixture -def temp_db_path(): - """Create a temporary database file for testing.""" - import uuid - - with tempfile.TemporaryDirectory() as tmpdir: - db_path = Path(tmpdir) / f"test_{uuid.uuid4().hex}.db" - yield db_path - - -@pytest.fixture -def mock_home_db_path(temp_db_path): - """Mock the database path to use temp file.""" - from leggen.utils.paths import path_manager - - # Set the path manager to use the temporary database - original_database_path = path_manager._database_path - path_manager.set_database_path(temp_db_path) - - try: - yield temp_db_path - finally: - # Restore original path - path_manager._database_path = original_database_path - - -@pytest.fixture -def sample_transactions(): - """Sample transaction data for testing.""" - return [ - { - "transactionId": "bank-txn-001", # NEW: stable bank-provided ID - "internalTransactionId": "txn-001", - "institutionId": "REVOLUT_REVOLT21", - "iban": "LT313250081177977789", - "transactionDate": datetime(2025, 9, 1, 9, 30), - "description": "Coffee Shop Payment", - "transactionValue": -10.50, - "transactionCurrency": "EUR", - "transactionStatus": "booked", - "accountId": "test-account-123", - "rawTransaction": {"transactionId": "bank-txn-001", "some": "data"}, - }, - { - "transactionId": "bank-txn-002", # NEW: stable bank-provided ID - "internalTransactionId": "txn-002", - "institutionId": "REVOLUT_REVOLT21", - "iban": "LT313250081177977789", - "transactionDate": datetime(2025, 9, 2, 14, 15), - "description": "Grocery Store", - "transactionValue": -45.30, - "transactionCurrency": "EUR", - "transactionStatus": "booked", - "accountId": "test-account-123", - "rawTransaction": {"transactionId": "bank-txn-002", "other": "data"}, - }, - ] - - -@pytest.fixture -def sample_balance(): - """Sample balance data for testing.""" - return { - "account_id": "test-account-123", - "bank": "REVOLUT_REVOLT21", - "status": "active", - "iban": "LT313250081177977789", - "amount": 1000.00, - "currency": "EUR", - "type": "interimAvailable", - "timestamp": datetime.now(), - } - - -class MockContext: - """Mock context for testing.""" - - -class TestSQLiteDatabase: - """Test SQLite database operations.""" - - def test_persist_transactions(self, mock_home_db_path, sample_transactions): - """Test persisting transactions to database.""" - ctx = MockContext() - - # Persist transactions - new_transactions = sqlite_db.persist_transactions( - ctx, "test-account-123", sample_transactions - ) - - # Should return all transactions as new - assert len(new_transactions) == 2 - assert new_transactions[0]["internalTransactionId"] == "txn-001" - - def test_persist_transactions_duplicates( - self, mock_home_db_path, sample_transactions - ): - """Test handling duplicate transactions.""" - ctx = MockContext() - - # Insert transactions twice - new_transactions_1 = sqlite_db.persist_transactions( - ctx, "test-account-123", sample_transactions - ) - new_transactions_2 = sqlite_db.persist_transactions( - ctx, "test-account-123", sample_transactions - ) - - # First time should return all as new - assert len(new_transactions_1) == 2 - # Second time should also return all (INSERT OR REPLACE behavior with composite key) - assert len(new_transactions_2) == 2 - - def test_get_transactions_all(self, mock_home_db_path, sample_transactions): - """Test retrieving all transactions.""" - ctx = MockContext() - - # Insert test data - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get all transactions - transactions = sqlite_db.get_transactions() - - assert len(transactions) == 2 - assert ( - transactions[0]["internalTransactionId"] == "txn-002" - ) # Ordered by date DESC - assert transactions[1]["internalTransactionId"] == "txn-001" - - def test_get_transactions_filtered_by_account( - self, mock_home_db_path, sample_transactions - ): - """Test filtering transactions by account ID.""" - ctx = MockContext() - - # Add transaction for different account - other_account_transaction = sample_transactions[0].copy() - other_account_transaction["internalTransactionId"] = "txn-003" - other_account_transaction["accountId"] = "other-account" - - all_transactions = sample_transactions + [other_account_transaction] - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", all_transactions) - - # Filter by account - transactions = sqlite_db.get_transactions(account_id="test-account-123") - - assert len(transactions) == 2 - for txn in transactions: - assert txn["accountId"] == "test-account-123" - - def test_get_transactions_with_pagination( - self, mock_home_db_path, sample_transactions - ): - """Test transaction pagination.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get first page - transactions_page1 = sqlite_db.get_transactions(limit=1, offset=0) - assert len(transactions_page1) == 1 - - # Get second page - transactions_page2 = sqlite_db.get_transactions(limit=1, offset=1) - assert len(transactions_page2) == 1 - - # Should be different transactions - assert ( - transactions_page1[0]["internalTransactionId"] - != transactions_page2[0]["internalTransactionId"] - ) - - def test_get_transactions_with_amount_filter( - self, mock_home_db_path, sample_transactions - ): - """Test filtering transactions by amount.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Filter by minimum amount (should exclude coffee shop payment) - transactions = sqlite_db.get_transactions(min_amount=-20.0) - assert len(transactions) == 1 - assert transactions[0]["transactionValue"] == -10.50 - - def test_get_transactions_with_search(self, mock_home_db_path, sample_transactions): - """Test searching transactions by description.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Search for "Coffee" - transactions = sqlite_db.get_transactions(search="Coffee") - assert len(transactions) == 1 - assert "Coffee" in transactions[0]["description"] - - def test_get_transactions_empty_database(self, mock_home_db_path): - """Test getting transactions from empty database.""" - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - transactions = sqlite_db.get_transactions() - assert transactions == [] - - def test_get_transactions_nonexistent_database(self): - """Test getting transactions when database doesn't exist.""" - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = Path("/nonexistent") - - transactions = sqlite_db.get_transactions() - assert transactions == [] - - def test_persist_balances(self, mock_home_db_path, sample_balance): - """Test persisting balance data.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - result = sqlite_db.persist_balances(ctx, sample_balance) - - # Should return the balance data - assert result["account_id"] == "test-account-123" - - def test_get_balances(self, mock_home_db_path, sample_balance): - """Test retrieving balances.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - # Insert test balance - sqlite_db.persist_balances(ctx, sample_balance) - - # Get balances - balances = sqlite_db.get_balances() - - assert len(balances) == 1 - assert balances[0]["account_id"] == "test-account-123" - assert balances[0]["amount"] == 1000.00 - - def test_get_balances_filtered_by_account(self, mock_home_db_path, sample_balance): - """Test filtering balances by account ID.""" - ctx = MockContext() - - # Create balance for different account - other_balance = sample_balance.copy() - other_balance["account_id"] = "other-account" - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_balances(ctx, sample_balance) - sqlite_db.persist_balances(ctx, other_balance) - - # Filter by account - balances = sqlite_db.get_balances(account_id="test-account-123") - - assert len(balances) == 1 - assert balances[0]["account_id"] == "test-account-123" - - def test_get_account_summary(self, mock_home_db_path, sample_transactions): - """Test getting account summary from transactions.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - summary = sqlite_db.get_account_summary("test-account-123") - - assert summary is not None - assert summary["accountId"] == "test-account-123" - assert summary["institutionId"] == "REVOLUT_REVOLT21" - assert summary["iban"] == "LT313250081177977789" - - def test_get_account_summary_nonexistent(self, mock_home_db_path): - """Test getting summary for nonexistent account.""" - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - summary = sqlite_db.get_account_summary("nonexistent") - assert summary is None - - def test_get_transaction_count(self, mock_home_db_path, sample_transactions): - """Test getting transaction count.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get total count - count = sqlite_db.get_transaction_count() - assert count == 2 - - # Get count for specific account - count_filtered = sqlite_db.get_transaction_count( - account_id="test-account-123" - ) - assert count_filtered == 2 - - # Get count for nonexistent account - count_none = sqlite_db.get_transaction_count(account_id="nonexistent") - assert count_none == 0 - - def test_get_transaction_count_with_filters( - self, mock_home_db_path, sample_transactions - ): - """Test getting transaction count with filters.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Filter by search - count = sqlite_db.get_transaction_count(search="Coffee") - assert count == 1 - - # Filter by amount - count = sqlite_db.get_transaction_count(min_amount=-20.0) - assert count == 1 - - def test_database_indexes_created(self, mock_home_db_path, sample_transactions): - """Test that database indexes are created properly.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - # Persist transactions to create tables and indexes - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get transactions to ensure we can query the table (indexes working) - transactions = sqlite_db.get_transactions(account_id="test-account-123") - assert len(transactions) == 2