From 9e9b1cf15f6762f6f363ebbc60279634147086ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elisi=C3=A1rio=20Couto?= Date: Tue, 9 Dec 2025 00:50:02 +0000 Subject: [PATCH] refactor(api): Update all modified files with dependency injection changes. --- leggen/api/routes/accounts.py | 70 ++++--- leggen/api/routes/sync.py | 7 +- leggen/api/routes/transactions.py | 24 ++- leggen/commands/server.py | 6 +- leggen/services/database_service.py | 20 +- leggen/services/sync_service.py | 37 ++-- tests/conftest.py | 32 +++ tests/unit/test_analytics_fix.py | 135 +++++++------ tests/unit/test_api_accounts.py | 205 ++++++++++++-------- tests/unit/test_api_transactions.py | 267 ++++++++++++++++---------- tests/unit/test_sync_notifications.py | 32 ++- 11 files changed, 505 insertions(+), 330 deletions(-) diff --git a/leggen/api/routes/accounts.py b/leggen/api/routes/accounts.py index 84dccd9..0c6b9c9 100644 --- a/leggen/api/routes/accounts.py +++ b/leggen/api/routes/accounts.py @@ -3,6 +3,12 @@ from typing import List, Optional, Union from fastapi import APIRouter, HTTPException, Query from loguru import logger +from leggen.api.dependencies import ( + AccountRepo, + AnalyticsProc, + BalanceRepo, + TransactionRepo, +) from leggen.api.models.accounts import ( AccountBalance, AccountDetails, @@ -10,28 +16,27 @@ from leggen.api.models.accounts import ( Transaction, TransactionSummary, ) -from leggen.services.database_service import DatabaseService router = APIRouter() -database_service = DatabaseService() @router.get("/accounts") -async def get_all_accounts() -> List[AccountDetails]: +async def get_all_accounts( + account_repo: AccountRepo, + balance_repo: BalanceRepo, +) -> List[AccountDetails]: """Get all connected accounts from database""" try: accounts = [] # Get all account details from database - db_accounts = await database_service.get_accounts_from_db() + db_accounts = account_repo.get_accounts() # Process accounts found in database for db_account in db_accounts: try: # Get latest balances from database for this account - balances_data = await database_service.get_balances_from_db( - db_account["id"] - ) + balances_data = balance_repo.get_balances(db_account["id"]) # Process balances balances = [] @@ -77,11 +82,15 @@ async def get_all_accounts() -> List[AccountDetails]: @router.get("/accounts/{account_id}") -async def get_account_details(account_id: str) -> AccountDetails: +async def get_account_details( + account_id: str, + account_repo: AccountRepo, + balance_repo: BalanceRepo, +) -> AccountDetails: """Get details for a specific account from database""" try: # Get account details from database - db_account = await database_service.get_account_details_from_db(account_id) + db_account = account_repo.get_account(account_id) if not db_account: raise HTTPException( @@ -89,7 +98,7 @@ async def get_account_details(account_id: str) -> AccountDetails: ) # Get latest balances from database for this account - balances_data = await database_service.get_balances_from_db(account_id) + balances_data = balance_repo.get_balances(account_id) # Process balances balances = [] @@ -129,11 +138,14 @@ async def get_account_details(account_id: str) -> AccountDetails: @router.get("/accounts/{account_id}/balances") -async def get_account_balances(account_id: str) -> List[AccountBalance]: +async def get_account_balances( + account_id: str, + balance_repo: BalanceRepo, +) -> List[AccountBalance]: """Get balances for a specific account from database""" try: # Get balances from database instead of GoCardless API - db_balances = await database_service.get_balances_from_db(account_id=account_id) + db_balances = balance_repo.get_balances(account_id=account_id) balances = [] for balance in db_balances: @@ -158,19 +170,20 @@ async def get_account_balances(account_id: str) -> List[AccountBalance]: @router.get("/balances") -async def get_all_balances() -> List[dict]: +async def get_all_balances( + account_repo: AccountRepo, + balance_repo: BalanceRepo, +) -> List[dict]: """Get all balances from all accounts in database""" try: # Get all accounts first to iterate through them - db_accounts = await database_service.get_accounts_from_db() + db_accounts = account_repo.get_accounts() all_balances = [] for db_account in db_accounts: try: # Get balances for this account - db_balances = await database_service.get_balances_from_db( - account_id=db_account["id"] - ) + db_balances = balance_repo.get_balances(account_id=db_account["id"]) # Process balances and add account info for balance in db_balances: @@ -205,6 +218,7 @@ async def get_all_balances() -> List[dict]: @router.get("/balances/history") async def get_historical_balances( + analytics_proc: AnalyticsProc, days: Optional[int] = Query( default=365, le=1095, ge=1, description="Number of days of history to retrieve" ), @@ -214,9 +228,12 @@ async def get_historical_balances( ) -> List[dict]: """Get historical balance progression calculated from transaction history""" try: + from leggen.utils.paths import path_manager + # Get historical balances from database - historical_balances = await database_service.get_historical_balances_from_db( - account_id=account_id, days=days or 365 + db_path = path_manager.get_database_path() + historical_balances = analytics_proc.calculate_historical_balances( + db_path, account_id=account_id, days=days or 365 ) return historical_balances @@ -231,6 +248,7 @@ async def get_historical_balances( @router.get("/accounts/{account_id}/transactions") async def get_account_transactions( account_id: str, + transaction_repo: TransactionRepo, limit: Optional[int] = Query(default=100, le=500), offset: Optional[int] = Query(default=0, ge=0), summary_only: bool = Query( @@ -240,10 +258,10 @@ async def get_account_transactions( """Get transactions for a specific account from database""" try: # Get transactions from database instead of GoCardless API - db_transactions = await database_service.get_transactions_from_db( + db_transactions = transaction_repo.get_transactions( account_id=account_id, limit=limit, - offset=offset, + offset=offset or 0, ) data: Union[List[TransactionSummary], List[Transaction]] @@ -294,11 +312,15 @@ async def get_account_transactions( @router.put("/accounts/{account_id}") -async def update_account_details(account_id: str, update_data: AccountUpdate) -> dict: +async def update_account_details( + account_id: str, + update_data: AccountUpdate, + account_repo: AccountRepo, +) -> dict: """Update account details (currently only display_name)""" try: # Get current account details - current_account = await database_service.get_account_details_from_db(account_id) + current_account = account_repo.get_account(account_id) if not current_account: raise HTTPException( @@ -311,7 +333,7 @@ async def update_account_details(account_id: str, update_data: AccountUpdate) -> updated_account_data["display_name"] = update_data.display_name # Persist updated account details - await database_service.persist_account_details(updated_account_data) + account_repo.persist(updated_account_data) return {"id": account_id, "display_name": update_data.display_name} diff --git a/leggen/api/routes/sync.py b/leggen/api/routes/sync.py index 4a505e4..201dd07 100644 --- a/leggen/api/routes/sync.py +++ b/leggen/api/routes/sync.py @@ -198,9 +198,10 @@ async def stop_scheduler() -> dict: async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict: """Get sync operations history""" try: - operations = await sync_service.database.get_sync_operations( - limit=limit, offset=offset - ) + from leggen.repositories import SyncRepository + + sync_repo = SyncRepository() + operations = sync_repo.get_operations(limit=limit, offset=offset) return {"operations": operations, "count": len(operations)} diff --git a/leggen/api/routes/transactions.py b/leggen/api/routes/transactions.py index 0347576..d67a515 100644 --- a/leggen/api/routes/transactions.py +++ b/leggen/api/routes/transactions.py @@ -4,16 +4,16 @@ from typing import List, Optional, Union from fastapi import APIRouter, HTTPException, Query from loguru import logger +from leggen.api.dependencies import AnalyticsProc, TransactionRepo from leggen.api.models.accounts import Transaction, TransactionSummary from leggen.api.models.common import PaginatedResponse -from leggen.services.database_service import DatabaseService router = APIRouter() -database_service = DatabaseService() @router.get("/transactions") async def get_all_transactions( + transaction_repo: TransactionRepo, page: int = Query(default=1, ge=1, description="Page number (1-based)"), per_page: int = Query(default=50, le=500, description="Items per page"), summary_only: bool = Query( @@ -43,7 +43,7 @@ async def get_all_transactions( limit = per_page # Get transactions from database instead of GoCardless API - db_transactions = await database_service.get_transactions_from_db( + db_transactions = transaction_repo.get_transactions( account_id=account_id, limit=limit, offset=offset, @@ -55,7 +55,7 @@ async def get_all_transactions( ) # Get total count for pagination info (respecting the same filters) - total_transactions = await database_service.get_transaction_count_from_db( + total_transactions = transaction_repo.get_count( account_id=account_id, date_from=date_from, date_to=date_to, @@ -119,6 +119,7 @@ async def get_all_transactions( @router.get("/transactions/stats") async def get_transaction_stats( + transaction_repo: TransactionRepo, days: int = Query(default=30, description="Number of days to include in stats"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"), ) -> dict: @@ -133,7 +134,7 @@ async def get_transaction_stats( date_to = end_date.isoformat() # Get transactions from database - recent_transactions = await database_service.get_transactions_from_db( + recent_transactions = transaction_repo.get_transactions( account_id=account_id, date_from=date_from, date_to=date_to, @@ -198,6 +199,7 @@ async def get_transaction_stats( @router.get("/transactions/analytics") async def get_transactions_for_analytics( + transaction_repo: TransactionRepo, days: int = Query(default=365, description="Number of days to include"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"), ) -> List[dict]: @@ -212,7 +214,7 @@ async def get_transactions_for_analytics( date_to = end_date.isoformat() # Get ALL transactions from database (no limit for analytics) - transactions = await database_service.get_transactions_from_db( + transactions = transaction_repo.get_transactions( account_id=account_id, date_from=date_from, date_to=date_to, @@ -244,11 +246,14 @@ async def get_transactions_for_analytics( @router.get("/transactions/monthly-stats") async def get_monthly_transaction_stats( + analytics_proc: AnalyticsProc, days: int = Query(default=365, description="Number of days to include"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"), ) -> List[dict]: """Get monthly transaction statistics aggregated by the database""" try: + from leggen.utils.paths import path_manager + # Date range for monthly stats end_date = datetime.now() start_date = end_date - timedelta(days=days) @@ -258,10 +263,9 @@ async def get_monthly_transaction_stats( date_to = end_date.isoformat() # Get monthly aggregated stats from database - monthly_stats = await database_service.get_monthly_transaction_stats_from_db( - account_id=account_id, - date_from=date_from, - date_to=date_to, + db_path = path_manager.get_database_path() + monthly_stats = analytics_proc.calculate_monthly_stats( + db_path, account_id=account_id, date_from=date_from, date_to=date_to ) return monthly_stats diff --git a/leggen/commands/server.py b/leggen/commands/server.py index 689e8a4..b27d5d4 100644 --- a/leggen/commands/server.py +++ b/leggen/commands/server.py @@ -28,10 +28,10 @@ async def lifespan(app: FastAPI): # Run database migrations try: - from leggen.services.database_service import DatabaseService + from leggen.api.dependencies import get_migration_repository - db_service = DatabaseService() - await db_service.run_migrations_if_needed() + migrations = get_migration_repository() + await migrations.run_all_migrations() logger.info("Database migrations completed") except Exception as e: logger.error(f"Database migration failed: {e}") diff --git a/leggen/services/database_service.py b/leggen/services/database_service.py index 00a3235..751f76a 100644 --- a/leggen/services/database_service.py +++ b/leggen/services/database_service.py @@ -1,3 +1,11 @@ +""" +DEPRECATED: DatabaseService is deprecated in favor of direct repository usage via dependency injection. + +This module is kept for backward compatibility with existing tests. +New code should use repositories directly via leggen.api.dependencies. +""" + +import warnings from functools import wraps from typing import Any, Dict, List, Optional @@ -38,9 +46,19 @@ def require_sqlite(func): class DatabaseService: - """Simplified database service using repository pattern""" + """ + DEPRECATED: Use repositories directly via dependency injection. + + This class is maintained for backward compatibility with existing tests. + For new code, inject repositories using leggen.api.dependencies. + """ def __init__(self): + warnings.warn( + "DatabaseService is deprecated. Use repositories via dependency injection.", + DeprecationWarning, + stacklevel=2, + ) self.db_config = config.database_config self.sqlite_enabled = self.db_config.get("sqlite", True) diff --git a/leggen/services/sync_service.py b/leggen/services/sync_service.py index f10e450..b0aaafe 100644 --- a/leggen/services/sync_service.py +++ b/leggen/services/sync_service.py @@ -4,12 +4,17 @@ from typing import List from loguru import logger from leggen.api.models.sync import SyncResult, SyncStatus +from leggen.repositories import ( + AccountRepository, + BalanceRepository, + SyncRepository, + TransactionRepository, +) from leggen.services.data_processors import ( AccountEnricher, BalanceTransformer, TransactionProcessor, ) -from leggen.services.database_service import DatabaseService from leggen.services.gocardless_service import GoCardlessService from leggen.services.notification_service import NotificationService @@ -20,9 +25,14 @@ EXPIRED_DAYS_LEFT = 0 class SyncService: def __init__(self): self.gocardless = GoCardlessService() - self.database = DatabaseService() self.notifications = NotificationService() + # Repositories + self.accounts = AccountRepository() + self.balances = BalanceRepository() + self.transactions = TransactionRepository() + self.sync = SyncRepository() + # Data processors self.account_enricher = AccountEnricher() self.balance_transformer = BalanceTransformer() @@ -104,21 +114,22 @@ class SyncService: ) # Persist enriched account details to database - await self.database.persist_account_details( - enriched_account_details - ) + self.accounts.persist(enriched_account_details) # Merge account metadata into balances for persistence balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances( balances, enriched_account_details ) - await self.database.persist_balance( - account_id, balances_with_account_info + balance_rows = ( + self.balance_transformer.transform_to_database_format( + account_id, balances_with_account_info + ) ) + self.balances.persist(account_id, balance_rows) balances_updated += len(balances.get("balances", [])) elif account_details: # Fallback: persist account details without currency if balances failed - await self.database.persist_account_details(account_details) + self.accounts.persist(account_details) # Get and save transactions transactions = await self.gocardless.get_account_transactions( @@ -130,7 +141,7 @@ class SyncService: account_id, account_details, transactions ) ) - new_transactions = await self.database.persist_transactions( + new_transactions = self.transactions.persist( account_id, processed_transactions ) transactions_added += len(new_transactions) @@ -184,9 +195,7 @@ class SyncService: # Persist sync operation to database try: - operation_id = await self.database.persist_sync_operation( - sync_operation - ) + operation_id = self.sync.persist(sync_operation) logger.debug(f"Saved sync operation with ID: {operation_id}") except Exception as e: logger.error(f"Failed to persist sync operation: {e}") @@ -235,9 +244,7 @@ class SyncService: ) try: - operation_id = await self.database.persist_sync_operation( - sync_operation - ) + operation_id = self.sync.persist(sync_operation) logger.debug(f"Saved failed sync operation with ID: {operation_id}") except Exception as persist_error: logger.error( diff --git a/tests/conftest.py b/tests/conftest.py index a4ab62f..a429712 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,6 +126,38 @@ def api_client(fastapi_app): return TestClient(fastapi_app) +@pytest.fixture +def mock_account_repo(): + """Create mock AccountRepository for testing.""" + from unittest.mock import MagicMock + + return MagicMock() + + +@pytest.fixture +def mock_balance_repo(): + """Create mock BalanceRepository for testing.""" + from unittest.mock import MagicMock + + return MagicMock() + + +@pytest.fixture +def mock_transaction_repo(): + """Create mock TransactionRepository for testing.""" + from unittest.mock import MagicMock + + return MagicMock() + + +@pytest.fixture +def mock_analytics_proc(): + """Create mock AnalyticsProcessor for testing.""" + from unittest.mock import MagicMock + + return MagicMock() + + @pytest.fixture def mock_db_path(temp_db_path): """Mock the database path to use temporary database for testing.""" diff --git a/tests/unit/test_analytics_fix.py b/tests/unit/test_analytics_fix.py index ad12d34..c499abc 100644 --- a/tests/unit/test_analytics_fix.py +++ b/tests/unit/test_analytics_fix.py @@ -1,13 +1,13 @@ """Tests for analytics fixes to ensure all transactions are used in statistics.""" from datetime import datetime, timedelta -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock import pytest from fastapi.testclient import TestClient +from leggen.api.dependencies import get_transaction_repository from leggen.commands.server import create_app -from leggen.services.database_service import DatabaseService class TestAnalyticsFix: @@ -19,11 +19,11 @@ class TestAnalyticsFix: return TestClient(app) @pytest.fixture - def mock_database_service(self): - return Mock(spec=DatabaseService) + def mock_transaction_repo(self): + return Mock() @pytest.mark.asyncio - async def test_transaction_stats_uses_all_transactions(self, mock_database_service): + async def test_transaction_stats_uses_all_transactions(self, mock_transaction_repo): """Test that transaction stats endpoint uses all transactions (not limited to 100)""" # Mock data for 600 transactions (simulating the issue) mock_transactions = [] @@ -42,53 +42,50 @@ class TestAnalyticsFix: } ) - mock_database_service.get_transactions_from_db = AsyncMock( - return_value=mock_transactions + mock_transaction_repo.get_transactions.return_value = mock_transactions + + app = create_app() + app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + client = TestClient(app) + + response = client.get("/api/v1/transactions/stats?days=365") + + assert response.status_code == 200 + data = response.json() + + # Verify that limit=None was passed to get all transactions + mock_transaction_repo.get_transactions.assert_called_once() + call_args = mock_transaction_repo.get_transactions.call_args + assert call_args.kwargs.get("limit") is None, ( + "Stats endpoint should pass limit=None to get all transactions" ) - # Test that the endpoint calls get_transactions_from_db with limit=None - with patch( - "leggen.api.routes.transactions.database_service", mock_database_service - ): - app = create_app() - client = TestClient(app) + # Verify that the response contains stats for all 600 transactions + stats = data + assert stats["total_transactions"] == 600, ( + "Should process all 600 transactions, not just 100" + ) - response = client.get("/api/v1/transactions/stats?days=365") + # Verify calculations are correct for all transactions + expected_income = sum( + txn["transactionValue"] + for txn in mock_transactions + if txn["transactionValue"] > 0 + ) + expected_expenses = sum( + abs(txn["transactionValue"]) + for txn in mock_transactions + if txn["transactionValue"] < 0 + ) - assert response.status_code == 200 - data = response.json() - - # Verify that limit=None was passed to get all transactions - mock_database_service.get_transactions_from_db.assert_called_once() - call_args = mock_database_service.get_transactions_from_db.call_args - assert call_args.kwargs.get("limit") is None, ( - "Stats endpoint should pass limit=None to get all transactions" - ) - - # Verify that the response contains stats for all 600 transactions - stats = data - assert stats["total_transactions"] == 600, ( - "Should process all 600 transactions, not just 100" - ) - - # Verify calculations are correct for all transactions - expected_income = sum( - txn["transactionValue"] - for txn in mock_transactions - if txn["transactionValue"] > 0 - ) - expected_expenses = sum( - abs(txn["transactionValue"]) - for txn in mock_transactions - if txn["transactionValue"] < 0 - ) - - assert stats["total_income"] == expected_income - assert stats["total_expenses"] == expected_expenses + assert stats["total_income"] == expected_income + assert stats["total_expenses"] == expected_expenses @pytest.mark.asyncio async def test_analytics_endpoint_returns_all_transactions( - self, mock_database_service + self, mock_transaction_repo ): """Test that the new analytics endpoint returns all transactions without pagination""" # Mock data for 600 transactions @@ -108,30 +105,28 @@ class TestAnalyticsFix: } ) - mock_database_service.get_transactions_from_db = AsyncMock( - return_value=mock_transactions + mock_transaction_repo.get_transactions.return_value = mock_transactions + + app = create_app() + app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + client = TestClient(app) + + response = client.get("/api/v1/transactions/analytics?days=365") + + assert response.status_code == 200 + data = response.json() + + # Verify that limit=None was passed to get all transactions + mock_transaction_repo.get_transactions.assert_called_once() + call_args = mock_transaction_repo.get_transactions.call_args + assert call_args.kwargs.get("limit") is None, ( + "Analytics endpoint should pass limit=None" ) - with patch( - "leggen.api.routes.transactions.database_service", mock_database_service - ): - app = create_app() - client = TestClient(app) - - response = client.get("/api/v1/transactions/analytics?days=365") - - assert response.status_code == 200 - data = response.json() - - # Verify that limit=None was passed to get all transactions - mock_database_service.get_transactions_from_db.assert_called_once() - call_args = mock_database_service.get_transactions_from_db.call_args - assert call_args.kwargs.get("limit") is None, ( - "Analytics endpoint should pass limit=None" - ) - - # Verify that all 600 transactions are returned - transactions_data = data - assert len(transactions_data) == 600, ( - "Analytics endpoint should return all 600 transactions" - ) + # Verify that all 600 transactions are returned + transactions_data = data + assert len(transactions_data) == 600, ( + "Analytics endpoint should return all 600 transactions" + ) diff --git a/tests/unit/test_api_accounts.py b/tests/unit/test_api_accounts.py index 5afc96e..620ec3c 100644 --- a/tests/unit/test_api_accounts.py +++ b/tests/unit/test_api_accounts.py @@ -4,6 +4,12 @@ from unittest.mock import patch import pytest +from leggen.api.dependencies import ( + get_account_repository, + get_balance_repository, + get_transaction_repository, +) + @pytest.mark.api class TestAccountsAPI: @@ -11,11 +17,14 @@ class TestAccountsAPI: def test_get_all_accounts_success( self, + fastapi_app, api_client, mock_config, mock_auth_token, sample_account_data, mock_db_path, + mock_account_repo, + mock_balance_repo, ): """Test successful retrieval of all accounts from database.""" mock_accounts = [ @@ -45,19 +54,21 @@ class TestAccountsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_accounts_from_db", - return_value=mock_accounts, - ), - patch( - "leggen.api.routes.accounts.database_service.get_balances_from_db", - return_value=mock_balances, - ), - ): + mock_account_repo.get_accounts.return_value = mock_accounts + mock_balance_repo.get_balances.return_value = mock_balances + + fastapi_app.dependency_overrides[get_account_repository] = ( + lambda: mock_account_repo + ) + fastapi_app.dependency_overrides[get_balance_repository] = ( + lambda: mock_balance_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/accounts") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data) == 1 @@ -69,11 +80,14 @@ class TestAccountsAPI: def test_get_account_details_success( self, + fastapi_app, api_client, mock_config, mock_auth_token, sample_account_data, mock_db_path, + mock_account_repo, + mock_balance_repo, ): """Test successful retrieval of specific account details from database.""" mock_account = { @@ -101,19 +115,21 @@ class TestAccountsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_account_details_from_db", - return_value=mock_account, - ), - patch( - "leggen.api.routes.accounts.database_service.get_balances_from_db", - return_value=mock_balances, - ), - ): + mock_account_repo.get_account.return_value = mock_account + mock_balance_repo.get_balances.return_value = mock_balances + + fastapi_app.dependency_overrides[get_account_repository] = ( + lambda: mock_account_repo + ) + fastapi_app.dependency_overrides[get_balance_repository] = ( + lambda: mock_balance_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/accounts/test-account-123") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert data["id"] == "test-account-123" @@ -121,7 +137,13 @@ class TestAccountsAPI: assert len(data["balances"]) == 1 def test_get_account_balances_success( - self, api_client, mock_config, mock_auth_token, mock_db_path + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_db_path, + mock_balance_repo, ): """Test successful retrieval of account balances from database.""" mock_balances = [ @@ -149,15 +171,17 @@ class TestAccountsAPI: }, ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_balances_from_db", - return_value=mock_balances, - ), - ): + mock_balance_repo.get_balances.return_value = mock_balances + + fastapi_app.dependency_overrides[get_balance_repository] = ( + lambda: mock_balance_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/accounts/test-account-123/balances") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data) == 2 @@ -167,12 +191,14 @@ class TestAccountsAPI: def test_get_account_transactions_success( self, + fastapi_app, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data, mock_db_path, + mock_transaction_repo, ): """Test successful retrieval of account transactions from database.""" mock_transactions = [ @@ -191,21 +217,19 @@ class TestAccountsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_transactions_from_db", - return_value=mock_transactions, - ), - patch( - "leggen.api.routes.accounts.database_service.get_transaction_count_from_db", - return_value=1, - ), - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get( "/api/v1/accounts/test-account-123/transactions?summary_only=true" ) + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data) == 1 @@ -218,12 +242,14 @@ class TestAccountsAPI: def test_get_account_transactions_full_details( self, + fastapi_app, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data, mock_db_path, + mock_transaction_repo, ): """Test retrieval of full transaction details from database.""" mock_transactions = [ @@ -242,21 +268,19 @@ class TestAccountsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_transactions_from_db", - return_value=mock_transactions, - ), - patch( - "leggen.api.routes.accounts.database_service.get_transaction_count_from_db", - return_value=1, - ), - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get( "/api/v1/accounts/test-account-123/transactions?summary_only=false" ) + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data) == 1 @@ -268,22 +292,36 @@ class TestAccountsAPI: assert "raw_transaction" in transaction def test_get_account_not_found( - self, api_client, mock_config, mock_auth_token, mock_db_path + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_db_path, + mock_account_repo, ): """Test handling of non-existent account.""" - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_account_details_from_db", - return_value=None, - ), - ): + mock_account_repo.get_account.return_value = None + + fastapi_app.dependency_overrides[get_account_repository] = ( + lambda: mock_account_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/accounts/nonexistent") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 404 def test_update_account_display_name_success( - self, api_client, mock_config, mock_auth_token, mock_db_path + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_db_path, + mock_account_repo, ): """Test successful update of account display name.""" mock_account = { @@ -297,41 +335,48 @@ class TestAccountsAPI: "last_accessed": "2025-09-01T09:30:00Z", } - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_account_details_from_db", - return_value=mock_account, - ), - patch( - "leggen.api.routes.accounts.database_service.persist_account_details", - return_value=None, - ), - ): + mock_account_repo.get_account.return_value = mock_account + mock_account_repo.persist.return_value = mock_account + + fastapi_app.dependency_overrides[get_account_repository] = ( + lambda: mock_account_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.put( "/api/v1/accounts/test-account-123", json={"display_name": "My Custom Account Name"}, ) + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert data["id"] == "test-account-123" assert data["display_name"] == "My Custom Account Name" def test_update_account_not_found( - self, api_client, mock_config, mock_auth_token, mock_db_path + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_db_path, + mock_account_repo, ): """Test updating non-existent account.""" - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.accounts.database_service.get_account_details_from_db", - return_value=None, - ), - ): + mock_account_repo.get_account.return_value = None + + fastapi_app.dependency_overrides[get_account_repository] = ( + lambda: mock_account_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.put( "/api/v1/accounts/nonexistent", json={"display_name": "New Name"}, ) + fastapi_app.dependency_overrides.clear() + assert response.status_code == 404 diff --git a/tests/unit/test_api_transactions.py b/tests/unit/test_api_transactions.py index f6e0def..b5fc004 100644 --- a/tests/unit/test_api_transactions.py +++ b/tests/unit/test_api_transactions.py @@ -5,13 +5,20 @@ from unittest.mock import patch import pytest +from leggen.api.dependencies import get_transaction_repository + @pytest.mark.api class TestTransactionsAPI: """Test transaction-related API endpoints.""" def test_get_all_transactions_success( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test successful retrieval of all transactions from database.""" mock_transactions = [ @@ -43,19 +50,17 @@ class TestTransactionsAPI: }, ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=mock_transactions, - ), - patch( - "leggen.api.routes.transactions.database_service.get_transaction_count_from_db", - return_value=2, - ), - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + mock_transaction_repo.get_count.return_value = len(mock_transactions) + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions?summary_only=true") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data["data"]) == 2 @@ -70,7 +75,12 @@ class TestTransactionsAPI: assert transaction["account_id"] == "test-account-123" def test_get_all_transactions_full_details( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test retrieval of full transaction details from database.""" mock_transactions = [ @@ -89,19 +99,17 @@ class TestTransactionsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=mock_transactions, - ), - patch( - "leggen.api.routes.transactions.database_service.get_transaction_count_from_db", - return_value=1, - ), - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + mock_transaction_repo.get_count.return_value = len(mock_transactions) + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions?summary_only=false") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data["data"]) == 1 @@ -114,7 +122,12 @@ class TestTransactionsAPI: assert "raw_transaction" in transaction def test_get_transactions_with_filters( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test getting transactions with various filters.""" mock_transactions = [ @@ -133,17 +146,14 @@ class TestTransactionsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=mock_transactions, - ) as mock_get_transactions, - patch( - "leggen.api.routes.transactions.database_service.get_transaction_count_from_db", - return_value=1, - ), - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + mock_transaction_repo.get_count.return_value = 1 + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get( "/api/v1/transactions?" "account_id=test-account-123&" @@ -156,10 +166,12 @@ class TestTransactionsAPI: "per_page=10" ) + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 - # Verify the database service was called with correct filters - mock_get_transactions.assert_called_once_with( + # Verify the repository was called with correct filters + mock_transaction_repo.get_transactions.assert_called_once_with( account_id="test-account-123", limit=10, offset=10, # (page-1) * per_page = (2-1) * 10 = 10 @@ -171,22 +183,26 @@ class TestTransactionsAPI: ) def test_get_transactions_empty_result( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test getting transactions when database returns empty result.""" - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=[], - ), - patch( - "leggen.api.routes.transactions.database_service.get_transaction_count_from_db", - return_value=0, - ), - ): + mock_transaction_repo.get_transactions.return_value = [] + mock_transaction_repo.get_count.return_value = 0 + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert len(data["data"]) == 0 @@ -195,23 +211,37 @@ class TestTransactionsAPI: assert data["total_pages"] == 0 def test_get_transactions_database_error( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test handling database error when getting transactions.""" - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - side_effect=Exception("Database connection failed"), - ), - ): + mock_transaction_repo.get_transactions.side_effect = Exception( + "Database connection failed" + ) + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 500 assert "Failed to get transactions" in response.json()["detail"] def test_get_transaction_stats_success( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test successful retrieval of transaction statistics from database.""" mock_transactions = [ @@ -238,15 +268,16 @@ class TestTransactionsAPI: }, ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=mock_transactions, - ), - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions/stats?days=30") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() @@ -264,7 +295,12 @@ class TestTransactionsAPI: assert data["average_transaction"] == expected_avg def test_get_transaction_stats_with_account_filter( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test getting transaction stats filtered by account.""" mock_transactions = [ @@ -277,37 +313,46 @@ class TestTransactionsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=mock_transactions, - ) as mock_get_transactions, - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get( "/api/v1/transactions/stats?account_id=test-account-123" ) + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 - # Verify the database service was called with account filter - mock_get_transactions.assert_called_once() - call_kwargs = mock_get_transactions.call_args.kwargs + # Verify the repository was called with account filter + mock_transaction_repo.get_transactions.assert_called_once() + call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs assert call_kwargs["account_id"] == "test-account-123" def test_get_transaction_stats_empty_result( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test getting stats when no transactions match criteria.""" - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=[], - ), - ): + mock_transaction_repo.get_transactions.return_value = [] + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions/stats") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() @@ -319,23 +364,37 @@ class TestTransactionsAPI: assert data["accounts_included"] == 0 def test_get_transaction_stats_database_error( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test handling database error when getting stats.""" - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - side_effect=Exception("Database connection failed"), - ), - ): + mock_transaction_repo.get_transactions.side_effect = Exception( + "Database connection failed" + ) + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions/stats") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 500 assert "Failed to get transaction stats" in response.json()["detail"] def test_get_transaction_stats_custom_period( - self, api_client, mock_config, mock_auth_token + self, + fastapi_app, + api_client, + mock_config, + mock_auth_token, + mock_transaction_repo, ): """Test getting transaction stats for custom time period.""" mock_transactions = [ @@ -348,21 +407,23 @@ class TestTransactionsAPI: } ] - with ( - patch("leggen.utils.config.config", mock_config), - patch( - "leggen.api.routes.transactions.database_service.get_transactions_from_db", - return_value=mock_transactions, - ) as mock_get_transactions, - ): + mock_transaction_repo.get_transactions.return_value = mock_transactions + + fastapi_app.dependency_overrides[get_transaction_repository] = ( + lambda: mock_transaction_repo + ) + + with patch("leggen.utils.config.config", mock_config): response = api_client.get("/api/v1/transactions/stats?days=7") + fastapi_app.dependency_overrides.clear() + assert response.status_code == 200 data = response.json() assert data["period_days"] == 7 # Verify the date range was calculated correctly for 7 days - mock_get_transactions.assert_called_once() - call_kwargs = mock_get_transactions.call_args.kwargs + mock_transaction_repo.get_transactions.assert_called_once() + call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs assert "date_from" in call_kwargs assert "date_to" in call_kwargs diff --git a/tests/unit/test_sync_notifications.py b/tests/unit/test_sync_notifications.py index eedf730..c5f728e 100644 --- a/tests/unit/test_sync_notifications.py +++ b/tests/unit/test_sync_notifications.py @@ -27,9 +27,7 @@ class TestSyncNotifications: patch.object( sync_service.notifications, "send_sync_failure_notification" ) as mock_send_notification, - patch.object( - sync_service.database, "persist_sync_operation", return_value=1 - ), + patch.object(sync_service.sync, "persist", return_value=1), ): # Setup: One requisition with one account that will fail mock_get_requisitions.return_value = { @@ -69,9 +67,7 @@ class TestSyncNotifications: patch.object( sync_service.notifications, "send_expiry_notification" ) as mock_send_expiry, - patch.object( - sync_service.database, "persist_sync_operation", return_value=1 - ), + patch.object(sync_service.sync, "persist", return_value=1), ): # Setup: One expired requisition mock_get_requisitions.return_value = { @@ -112,9 +108,7 @@ class TestSyncNotifications: patch.object( sync_service.notifications, "send_sync_failure_notification" ) as mock_send_notification, - patch.object( - sync_service.database, "persist_sync_operation", return_value=1 - ), + patch.object(sync_service.sync, "persist", return_value=1), ): # Setup: One requisition with two accounts that will fail mock_get_requisitions.return_value = { @@ -160,17 +154,15 @@ class TestSyncNotifications: sync_service.notifications, "send_sync_failure_notification" ) as mock_send_notification, patch.object(sync_service.notifications, "send_transaction_notifications"), - patch.object(sync_service.database, "persist_account_details"), - patch.object(sync_service.database, "persist_balance"), + patch.object(sync_service.accounts, "persist"), + patch.object(sync_service.balances, "persist"), patch.object( - sync_service.database, "process_transactions", return_value=[] - ), - patch.object( - sync_service.database, "persist_transactions", return_value=[] - ), - patch.object( - sync_service.database, "persist_sync_operation", return_value=1 + sync_service.transaction_processor, + "process_transactions", + return_value=[], ), + patch.object(sync_service.transactions, "persist", return_value=[]), + patch.object(sync_service.sync, "persist", return_value=1), ): # Setup: One requisition with one account that succeeds mock_get_requisitions.return_value = { @@ -222,9 +214,7 @@ class TestSyncNotifications: patch.object( sync_service.notifications, "_send_telegram_sync_failure" ) as mock_telegram_notification, - patch.object( - sync_service.database, "persist_sync_operation", return_value=1 - ), + patch.object(sync_service.sync, "persist", return_value=1), ): # Setup: One requisition with one account that will fail mock_get_requisitions.return_value = {