From 7d9744a40e7898e5bbe52e2e9f54317aa5c1cdd6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Sep 2025 22:45:01 +0000 Subject: [PATCH] refactor(core): Integrate directory creation with database path retrieval and remove backup file. Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com> --- leggen/database/sqlite.py | 3 - leggen/utils/paths.py | 28 +- leggend/services/database_service.py | 4 - tests/unit/test_sqlite_database.py.bak | 364 ------------------------- 4 files changed, 18 insertions(+), 381 deletions(-) delete mode 100644 tests/unit/test_sqlite_database.py.bak diff --git a/leggen/database/sqlite.py b/leggen/database/sqlite.py index 744302a..8fbfcde 100644 --- a/leggen/database/sqlite.py +++ b/leggen/database/sqlite.py @@ -11,7 +11,6 @@ from leggen.utils.paths import path_manager def persist_balances(ctx: click.Context, balance: dict): # Connect to SQLite database db_path = path_manager.get_database_path() - path_manager.ensure_database_dir_exists() conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() @@ -108,7 +107,6 @@ def persist_balances(ctx: click.Context, balance: dict): def persist_transactions(ctx: click.Context, account: str, transactions: list) -> list: # Connect to SQLite database db_path = path_manager.get_database_path() - path_manager.ensure_database_dir_exists() conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() @@ -404,7 +402,6 @@ def get_transaction_count(account_id=None, **filters): def persist_account(account_data: dict): """Persist account details to SQLite database""" db_path = path_manager.get_database_path() - path_manager.ensure_database_dir_exists() conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() diff --git a/leggen/utils/paths.py b/leggen/utils/paths.py index 40241eb..714b81d 100644 --- a/leggen/utils/paths.py +++ b/leggen/utils/paths.py @@ -34,17 +34,21 @@ class PathManager: return self.get_config_dir() / "config.toml" def get_database_path(self) -> Path: - """Get the database file path.""" + """Get the database file path and ensure the directory exists.""" if self._database_path is not None: - return self._database_path + db_path = self._database_path + else: + # Check environment variable first + database_path = os.environ.get("LEGGEN_DATABASE_PATH") + if database_path: + db_path = Path(database_path) + else: + # Default to config_dir/leggen.db + db_path = self.get_config_dir() / "leggen.db" - # Check environment variable first - database_path = os.environ.get("LEGGEN_DATABASE_PATH") - if database_path: - return Path(database_path) - - # Default to config_dir/leggen.db - return self.get_config_dir() / "leggen.db" + # Ensure the directory exists + db_path.parent.mkdir(parents=True, exist_ok=True) + return db_path def set_database_path(self, path: Path) -> None: """Set the database file path.""" @@ -59,7 +63,11 @@ class PathManager: self.get_config_dir().mkdir(parents=True, exist_ok=True) def ensure_database_dir_exists(self) -> None: - """Ensure the database directory exists.""" + """Ensure the database directory exists. + + Note: get_database_path() now automatically ensures the directory exists, + so this method is mainly for explicit directory creation in tests. + """ self.get_database_path().parent.mkdir(parents=True, exist_ok=True) diff --git a/leggend/services/database_service.py b/leggend/services/database_service.py index 4edf192..b05cc8d 100644 --- a/leggend/services/database_service.py +++ b/leggend/services/database_service.py @@ -695,7 +695,6 @@ class DatabaseService: import sqlite3 db_path = path_manager.get_database_path() - path_manager.ensure_database_dir_exists() conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() @@ -775,7 +774,6 @@ class DatabaseService: import json db_path = path_manager.get_database_path() - path_manager.ensure_database_dir_exists() conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() @@ -874,8 +872,6 @@ class DatabaseService: ) -> None: """Persist account details to SQLite""" try: - path_manager.ensure_database_dir_exists() - # Use the sqlite_db module function sqlite_db.persist_account(account_data) diff --git a/tests/unit/test_sqlite_database.py.bak b/tests/unit/test_sqlite_database.py.bak deleted file mode 100644 index 0d0da67..0000000 --- a/tests/unit/test_sqlite_database.py.bak +++ /dev/null @@ -1,364 +0,0 @@ -"""Tests for SQLite database functions.""" - -import pytest -import tempfile -from pathlib import Path -from unittest.mock import patch -from datetime import datetime - -import leggen.database.sqlite as sqlite_db - - -@pytest.fixture -def temp_db_path(): - """Create a temporary database file for testing.""" - import uuid - - with tempfile.TemporaryDirectory() as tmpdir: - db_path = Path(tmpdir) / f"test_{uuid.uuid4().hex}.db" - yield db_path - - -@pytest.fixture -def mock_home_db_path(temp_db_path): - """Mock the database path to use temp file.""" - from leggen.utils.paths import path_manager - - # Set the path manager to use the temporary database - original_database_path = path_manager._database_path - path_manager.set_database_path(temp_db_path) - - try: - yield temp_db_path - finally: - # Restore original path - path_manager._database_path = original_database_path - - -@pytest.fixture -def sample_transactions(): - """Sample transaction data for testing.""" - return [ - { - "transactionId": "bank-txn-001", # NEW: stable bank-provided ID - "internalTransactionId": "txn-001", - "institutionId": "REVOLUT_REVOLT21", - "iban": "LT313250081177977789", - "transactionDate": datetime(2025, 9, 1, 9, 30), - "description": "Coffee Shop Payment", - "transactionValue": -10.50, - "transactionCurrency": "EUR", - "transactionStatus": "booked", - "accountId": "test-account-123", - "rawTransaction": {"transactionId": "bank-txn-001", "some": "data"}, - }, - { - "transactionId": "bank-txn-002", # NEW: stable bank-provided ID - "internalTransactionId": "txn-002", - "institutionId": "REVOLUT_REVOLT21", - "iban": "LT313250081177977789", - "transactionDate": datetime(2025, 9, 2, 14, 15), - "description": "Grocery Store", - "transactionValue": -45.30, - "transactionCurrency": "EUR", - "transactionStatus": "booked", - "accountId": "test-account-123", - "rawTransaction": {"transactionId": "bank-txn-002", "other": "data"}, - }, - ] - - -@pytest.fixture -def sample_balance(): - """Sample balance data for testing.""" - return { - "account_id": "test-account-123", - "bank": "REVOLUT_REVOLT21", - "status": "active", - "iban": "LT313250081177977789", - "amount": 1000.00, - "currency": "EUR", - "type": "interimAvailable", - "timestamp": datetime.now(), - } - - -class MockContext: - """Mock context for testing.""" - - -class TestSQLiteDatabase: - """Test SQLite database operations.""" - - def test_persist_transactions(self, mock_home_db_path, sample_transactions): - """Test persisting transactions to database.""" - ctx = MockContext() - - # Persist transactions - new_transactions = sqlite_db.persist_transactions( - ctx, "test-account-123", sample_transactions - ) - - # Should return all transactions as new - assert len(new_transactions) == 2 - assert new_transactions[0]["internalTransactionId"] == "txn-001" - - def test_persist_transactions_duplicates( - self, mock_home_db_path, sample_transactions - ): - """Test handling duplicate transactions.""" - ctx = MockContext() - - # Insert transactions twice - new_transactions_1 = sqlite_db.persist_transactions( - ctx, "test-account-123", sample_transactions - ) - new_transactions_2 = sqlite_db.persist_transactions( - ctx, "test-account-123", sample_transactions - ) - - # First time should return all as new - assert len(new_transactions_1) == 2 - # Second time should also return all (INSERT OR REPLACE behavior with composite key) - assert len(new_transactions_2) == 2 - - def test_get_transactions_all(self, mock_home_db_path, sample_transactions): - """Test retrieving all transactions.""" - ctx = MockContext() - - # Insert test data - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get all transactions - transactions = sqlite_db.get_transactions() - - assert len(transactions) == 2 - assert ( - transactions[0]["internalTransactionId"] == "txn-002" - ) # Ordered by date DESC - assert transactions[1]["internalTransactionId"] == "txn-001" - - def test_get_transactions_filtered_by_account( - self, mock_home_db_path, sample_transactions - ): - """Test filtering transactions by account ID.""" - ctx = MockContext() - - # Add transaction for different account - other_account_transaction = sample_transactions[0].copy() - other_account_transaction["internalTransactionId"] = "txn-003" - other_account_transaction["accountId"] = "other-account" - - all_transactions = sample_transactions + [other_account_transaction] - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", all_transactions) - - # Filter by account - transactions = sqlite_db.get_transactions(account_id="test-account-123") - - assert len(transactions) == 2 - for txn in transactions: - assert txn["accountId"] == "test-account-123" - - def test_get_transactions_with_pagination( - self, mock_home_db_path, sample_transactions - ): - """Test transaction pagination.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get first page - transactions_page1 = sqlite_db.get_transactions(limit=1, offset=0) - assert len(transactions_page1) == 1 - - # Get second page - transactions_page2 = sqlite_db.get_transactions(limit=1, offset=1) - assert len(transactions_page2) == 1 - - # Should be different transactions - assert ( - transactions_page1[0]["internalTransactionId"] - != transactions_page2[0]["internalTransactionId"] - ) - - def test_get_transactions_with_amount_filter( - self, mock_home_db_path, sample_transactions - ): - """Test filtering transactions by amount.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Filter by minimum amount (should exclude coffee shop payment) - transactions = sqlite_db.get_transactions(min_amount=-20.0) - assert len(transactions) == 1 - assert transactions[0]["transactionValue"] == -10.50 - - def test_get_transactions_with_search(self, mock_home_db_path, sample_transactions): - """Test searching transactions by description.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Search for "Coffee" - transactions = sqlite_db.get_transactions(search="Coffee") - assert len(transactions) == 1 - assert "Coffee" in transactions[0]["description"] - - def test_get_transactions_empty_database(self, mock_home_db_path): - """Test getting transactions from empty database.""" - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - transactions = sqlite_db.get_transactions() - assert transactions == [] - - def test_get_transactions_nonexistent_database(self): - """Test getting transactions when database doesn't exist.""" - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = Path("/nonexistent") - - transactions = sqlite_db.get_transactions() - assert transactions == [] - - def test_persist_balances(self, mock_home_db_path, sample_balance): - """Test persisting balance data.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - result = sqlite_db.persist_balances(ctx, sample_balance) - - # Should return the balance data - assert result["account_id"] == "test-account-123" - - def test_get_balances(self, mock_home_db_path, sample_balance): - """Test retrieving balances.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - # Insert test balance - sqlite_db.persist_balances(ctx, sample_balance) - - # Get balances - balances = sqlite_db.get_balances() - - assert len(balances) == 1 - assert balances[0]["account_id"] == "test-account-123" - assert balances[0]["amount"] == 1000.00 - - def test_get_balances_filtered_by_account(self, mock_home_db_path, sample_balance): - """Test filtering balances by account ID.""" - ctx = MockContext() - - # Create balance for different account - other_balance = sample_balance.copy() - other_balance["account_id"] = "other-account" - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_balances(ctx, sample_balance) - sqlite_db.persist_balances(ctx, other_balance) - - # Filter by account - balances = sqlite_db.get_balances(account_id="test-account-123") - - assert len(balances) == 1 - assert balances[0]["account_id"] == "test-account-123" - - def test_get_account_summary(self, mock_home_db_path, sample_transactions): - """Test getting account summary from transactions.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - summary = sqlite_db.get_account_summary("test-account-123") - - assert summary is not None - assert summary["accountId"] == "test-account-123" - assert summary["institutionId"] == "REVOLUT_REVOLT21" - assert summary["iban"] == "LT313250081177977789" - - def test_get_account_summary_nonexistent(self, mock_home_db_path): - """Test getting summary for nonexistent account.""" - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - summary = sqlite_db.get_account_summary("nonexistent") - assert summary is None - - def test_get_transaction_count(self, mock_home_db_path, sample_transactions): - """Test getting transaction count.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get total count - count = sqlite_db.get_transaction_count() - assert count == 2 - - # Get count for specific account - count_filtered = sqlite_db.get_transaction_count( - account_id="test-account-123" - ) - assert count_filtered == 2 - - # Get count for nonexistent account - count_none = sqlite_db.get_transaction_count(account_id="nonexistent") - assert count_none == 0 - - def test_get_transaction_count_with_filters( - self, mock_home_db_path, sample_transactions - ): - """Test getting transaction count with filters.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Filter by search - count = sqlite_db.get_transaction_count(search="Coffee") - assert count == 1 - - # Filter by amount - count = sqlite_db.get_transaction_count(min_amount=-20.0) - assert count == 1 - - def test_database_indexes_created(self, mock_home_db_path, sample_transactions): - """Test that database indexes are created properly.""" - ctx = MockContext() - - with patch("pathlib.Path.home") as mock_home: - mock_home.return_value = mock_home_db_path.parent / ".." - - # Persist transactions to create tables and indexes - sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions) - - # Get transactions to ensure we can query the table (indexes working) - transactions = sqlite_db.get_transactions(account_id="test-account-123") - assert len(transactions) == 2