refactor(api): Update all modified files with dependency injection changes.

This commit is contained in:
Elisiário Couto
2025-12-09 00:50:02 +00:00
committed by Elisiário Couto
parent 9dc6357905
commit 9e9b1cf15f
11 changed files with 505 additions and 330 deletions

View File

@@ -3,6 +3,12 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.dependencies import (
AccountRepo,
AnalyticsProc,
BalanceRepo,
TransactionRepo,
)
from leggen.api.models.accounts import ( from leggen.api.models.accounts import (
AccountBalance, AccountBalance,
AccountDetails, AccountDetails,
@@ -10,28 +16,27 @@ from leggen.api.models.accounts import (
Transaction, Transaction,
TransactionSummary, TransactionSummary,
) )
from leggen.services.database_service import DatabaseService
router = APIRouter() router = APIRouter()
database_service = DatabaseService()
@router.get("/accounts") @router.get("/accounts")
async def get_all_accounts() -> List[AccountDetails]: async def get_all_accounts(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[AccountDetails]:
"""Get all connected accounts from database""" """Get all connected accounts from database"""
try: try:
accounts = [] accounts = []
# Get all account details from database # Get all account details from database
db_accounts = await database_service.get_accounts_from_db() db_accounts = account_repo.get_accounts()
# Process accounts found in database # Process accounts found in database
for db_account in db_accounts: for db_account in db_accounts:
try: try:
# Get latest balances from database for this account # Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db( balances_data = balance_repo.get_balances(db_account["id"])
db_account["id"]
)
# Process balances # Process balances
balances = [] balances = []
@@ -77,11 +82,15 @@ async def get_all_accounts() -> List[AccountDetails]:
@router.get("/accounts/{account_id}") @router.get("/accounts/{account_id}")
async def get_account_details(account_id: str) -> AccountDetails: async def get_account_details(
account_id: str,
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> AccountDetails:
"""Get details for a specific account from database""" """Get details for a specific account from database"""
try: try:
# Get account details from database # Get account details from database
db_account = await database_service.get_account_details_from_db(account_id) db_account = account_repo.get_account(account_id)
if not db_account: if not db_account:
raise HTTPException( raise HTTPException(
@@ -89,7 +98,7 @@ async def get_account_details(account_id: str) -> AccountDetails:
) )
# Get latest balances from database for this account # Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db(account_id) balances_data = balance_repo.get_balances(account_id)
# Process balances # Process balances
balances = [] balances = []
@@ -129,11 +138,14 @@ async def get_account_details(account_id: str) -> AccountDetails:
@router.get("/accounts/{account_id}/balances") @router.get("/accounts/{account_id}/balances")
async def get_account_balances(account_id: str) -> List[AccountBalance]: async def get_account_balances(
account_id: str,
balance_repo: BalanceRepo,
) -> List[AccountBalance]:
"""Get balances for a specific account from database""" """Get balances for a specific account from database"""
try: try:
# Get balances from database instead of GoCardless API # Get balances from database instead of GoCardless API
db_balances = await database_service.get_balances_from_db(account_id=account_id) db_balances = balance_repo.get_balances(account_id=account_id)
balances = [] balances = []
for balance in db_balances: for balance in db_balances:
@@ -158,19 +170,20 @@ async def get_account_balances(account_id: str) -> List[AccountBalance]:
@router.get("/balances") @router.get("/balances")
async def get_all_balances() -> List[dict]: async def get_all_balances(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[dict]:
"""Get all balances from all accounts in database""" """Get all balances from all accounts in database"""
try: try:
# Get all accounts first to iterate through them # Get all accounts first to iterate through them
db_accounts = await database_service.get_accounts_from_db() db_accounts = account_repo.get_accounts()
all_balances = [] all_balances = []
for db_account in db_accounts: for db_account in db_accounts:
try: try:
# Get balances for this account # Get balances for this account
db_balances = await database_service.get_balances_from_db( db_balances = balance_repo.get_balances(account_id=db_account["id"])
account_id=db_account["id"]
)
# Process balances and add account info # Process balances and add account info
for balance in db_balances: for balance in db_balances:
@@ -205,6 +218,7 @@ async def get_all_balances() -> List[dict]:
@router.get("/balances/history") @router.get("/balances/history")
async def get_historical_balances( async def get_historical_balances(
analytics_proc: AnalyticsProc,
days: Optional[int] = Query( days: Optional[int] = Query(
default=365, le=1095, ge=1, description="Number of days of history to retrieve" default=365, le=1095, ge=1, description="Number of days of history to retrieve"
), ),
@@ -214,9 +228,12 @@ async def get_historical_balances(
) -> List[dict]: ) -> List[dict]:
"""Get historical balance progression calculated from transaction history""" """Get historical balance progression calculated from transaction history"""
try: try:
from leggen.utils.paths import path_manager
# Get historical balances from database # Get historical balances from database
historical_balances = await database_service.get_historical_balances_from_db( db_path = path_manager.get_database_path()
account_id=account_id, days=days or 365 historical_balances = analytics_proc.calculate_historical_balances(
db_path, account_id=account_id, days=days or 365
) )
return historical_balances return historical_balances
@@ -231,6 +248,7 @@ async def get_historical_balances(
@router.get("/accounts/{account_id}/transactions") @router.get("/accounts/{account_id}/transactions")
async def get_account_transactions( async def get_account_transactions(
account_id: str, account_id: str,
transaction_repo: TransactionRepo,
limit: Optional[int] = Query(default=100, le=500), limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0), offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query( summary_only: bool = Query(
@@ -240,10 +258,10 @@ async def get_account_transactions(
"""Get transactions for a specific account from database""" """Get transactions for a specific account from database"""
try: try:
# Get transactions from database instead of GoCardless API # Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db( db_transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
limit=limit, limit=limit,
offset=offset, offset=offset or 0,
) )
data: Union[List[TransactionSummary], List[Transaction]] data: Union[List[TransactionSummary], List[Transaction]]
@@ -294,11 +312,15 @@ async def get_account_transactions(
@router.put("/accounts/{account_id}") @router.put("/accounts/{account_id}")
async def update_account_details(account_id: str, update_data: AccountUpdate) -> dict: async def update_account_details(
account_id: str,
update_data: AccountUpdate,
account_repo: AccountRepo,
) -> dict:
"""Update account details (currently only display_name)""" """Update account details (currently only display_name)"""
try: try:
# Get current account details # Get current account details
current_account = await database_service.get_account_details_from_db(account_id) current_account = account_repo.get_account(account_id)
if not current_account: if not current_account:
raise HTTPException( raise HTTPException(
@@ -311,7 +333,7 @@ async def update_account_details(account_id: str, update_data: AccountUpdate) ->
updated_account_data["display_name"] = update_data.display_name updated_account_data["display_name"] = update_data.display_name
# Persist updated account details # Persist updated account details
await database_service.persist_account_details(updated_account_data) account_repo.persist(updated_account_data)
return {"id": account_id, "display_name": update_data.display_name} return {"id": account_id, "display_name": update_data.display_name}

View File

@@ -198,9 +198,10 @@ async def stop_scheduler() -> dict:
async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict: async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict:
"""Get sync operations history""" """Get sync operations history"""
try: try:
operations = await sync_service.database.get_sync_operations( from leggen.repositories import SyncRepository
limit=limit, offset=offset
) sync_repo = SyncRepository()
operations = sync_repo.get_operations(limit=limit, offset=offset)
return {"operations": operations, "count": len(operations)} return {"operations": operations, "count": len(operations)}

View File

@@ -4,16 +4,16 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.dependencies import AnalyticsProc, TransactionRepo
from leggen.api.models.accounts import Transaction, TransactionSummary from leggen.api.models.accounts import Transaction, TransactionSummary
from leggen.api.models.common import PaginatedResponse from leggen.api.models.common import PaginatedResponse
from leggen.services.database_service import DatabaseService
router = APIRouter() router = APIRouter()
database_service = DatabaseService()
@router.get("/transactions") @router.get("/transactions")
async def get_all_transactions( async def get_all_transactions(
transaction_repo: TransactionRepo,
page: int = Query(default=1, ge=1, description="Page number (1-based)"), page: int = Query(default=1, ge=1, description="Page number (1-based)"),
per_page: int = Query(default=50, le=500, description="Items per page"), per_page: int = Query(default=50, le=500, description="Items per page"),
summary_only: bool = Query( summary_only: bool = Query(
@@ -43,7 +43,7 @@ async def get_all_transactions(
limit = per_page limit = per_page
# Get transactions from database instead of GoCardless API # Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db( db_transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
limit=limit, limit=limit,
offset=offset, offset=offset,
@@ -55,7 +55,7 @@ async def get_all_transactions(
) )
# Get total count for pagination info (respecting the same filters) # Get total count for pagination info (respecting the same filters)
total_transactions = await database_service.get_transaction_count_from_db( total_transactions = transaction_repo.get_count(
account_id=account_id, account_id=account_id,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -119,6 +119,7 @@ async def get_all_transactions(
@router.get("/transactions/stats") @router.get("/transactions/stats")
async def get_transaction_stats( async def get_transaction_stats(
transaction_repo: TransactionRepo,
days: int = Query(default=30, description="Number of days to include in stats"), days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> dict: ) -> dict:
@@ -133,7 +134,7 @@ async def get_transaction_stats(
date_to = end_date.isoformat() date_to = end_date.isoformat()
# Get transactions from database # Get transactions from database
recent_transactions = await database_service.get_transactions_from_db( recent_transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -198,6 +199,7 @@ async def get_transaction_stats(
@router.get("/transactions/analytics") @router.get("/transactions/analytics")
async def get_transactions_for_analytics( async def get_transactions_for_analytics(
transaction_repo: TransactionRepo,
days: int = Query(default=365, description="Number of days to include"), days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]: ) -> List[dict]:
@@ -212,7 +214,7 @@ async def get_transactions_for_analytics(
date_to = end_date.isoformat() date_to = end_date.isoformat()
# Get ALL transactions from database (no limit for analytics) # Get ALL transactions from database (no limit for analytics)
transactions = await database_service.get_transactions_from_db( transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -244,11 +246,14 @@ async def get_transactions_for_analytics(
@router.get("/transactions/monthly-stats") @router.get("/transactions/monthly-stats")
async def get_monthly_transaction_stats( async def get_monthly_transaction_stats(
analytics_proc: AnalyticsProc,
days: int = Query(default=365, description="Number of days to include"), days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]: ) -> List[dict]:
"""Get monthly transaction statistics aggregated by the database""" """Get monthly transaction statistics aggregated by the database"""
try: try:
from leggen.utils.paths import path_manager
# Date range for monthly stats # Date range for monthly stats
end_date = datetime.now() end_date = datetime.now()
start_date = end_date - timedelta(days=days) start_date = end_date - timedelta(days=days)
@@ -258,10 +263,9 @@ async def get_monthly_transaction_stats(
date_to = end_date.isoformat() date_to = end_date.isoformat()
# Get monthly aggregated stats from database # Get monthly aggregated stats from database
monthly_stats = await database_service.get_monthly_transaction_stats_from_db( db_path = path_manager.get_database_path()
account_id=account_id, monthly_stats = analytics_proc.calculate_monthly_stats(
date_from=date_from, db_path, account_id=account_id, date_from=date_from, date_to=date_to
date_to=date_to,
) )
return monthly_stats return monthly_stats

View File

@@ -28,10 +28,10 @@ async def lifespan(app: FastAPI):
# Run database migrations # Run database migrations
try: try:
from leggen.services.database_service import DatabaseService from leggen.api.dependencies import get_migration_repository
db_service = DatabaseService() migrations = get_migration_repository()
await db_service.run_migrations_if_needed() await migrations.run_all_migrations()
logger.info("Database migrations completed") logger.info("Database migrations completed")
except Exception as e: except Exception as e:
logger.error(f"Database migration failed: {e}") logger.error(f"Database migration failed: {e}")

View File

@@ -1,3 +1,11 @@
"""
DEPRECATED: DatabaseService is deprecated in favor of direct repository usage via dependency injection.
This module is kept for backward compatibility with existing tests.
New code should use repositories directly via leggen.api.dependencies.
"""
import warnings
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -38,9 +46,19 @@ def require_sqlite(func):
class DatabaseService: class DatabaseService:
"""Simplified database service using repository pattern""" """
DEPRECATED: Use repositories directly via dependency injection.
This class is maintained for backward compatibility with existing tests.
For new code, inject repositories using leggen.api.dependencies.
"""
def __init__(self): def __init__(self):
warnings.warn(
"DatabaseService is deprecated. Use repositories via dependency injection.",
DeprecationWarning,
stacklevel=2,
)
self.db_config = config.database_config self.db_config = config.database_config
self.sqlite_enabled = self.db_config.get("sqlite", True) self.sqlite_enabled = self.db_config.get("sqlite", True)

View File

@@ -4,12 +4,17 @@ from typing import List
from loguru import logger from loguru import logger
from leggen.api.models.sync import SyncResult, SyncStatus from leggen.api.models.sync import SyncResult, SyncStatus
from leggen.repositories import (
AccountRepository,
BalanceRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import ( from leggen.services.data_processors import (
AccountEnricher, AccountEnricher,
BalanceTransformer, BalanceTransformer,
TransactionProcessor, TransactionProcessor,
) )
from leggen.services.database_service import DatabaseService
from leggen.services.gocardless_service import GoCardlessService from leggen.services.gocardless_service import GoCardlessService
from leggen.services.notification_service import NotificationService from leggen.services.notification_service import NotificationService
@@ -20,9 +25,14 @@ EXPIRED_DAYS_LEFT = 0
class SyncService: class SyncService:
def __init__(self): def __init__(self):
self.gocardless = GoCardlessService() self.gocardless = GoCardlessService()
self.database = DatabaseService()
self.notifications = NotificationService() self.notifications = NotificationService()
# Repositories
self.accounts = AccountRepository()
self.balances = BalanceRepository()
self.transactions = TransactionRepository()
self.sync = SyncRepository()
# Data processors # Data processors
self.account_enricher = AccountEnricher() self.account_enricher = AccountEnricher()
self.balance_transformer = BalanceTransformer() self.balance_transformer = BalanceTransformer()
@@ -104,21 +114,22 @@ class SyncService:
) )
# Persist enriched account details to database # Persist enriched account details to database
await self.database.persist_account_details( self.accounts.persist(enriched_account_details)
enriched_account_details
)
# Merge account metadata into balances for persistence # Merge account metadata into balances for persistence
balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances( balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances(
balances, enriched_account_details balances, enriched_account_details
) )
await self.database.persist_balance( balance_rows = (
account_id, balances_with_account_info self.balance_transformer.transform_to_database_format(
account_id, balances_with_account_info
)
) )
self.balances.persist(account_id, balance_rows)
balances_updated += len(balances.get("balances", [])) balances_updated += len(balances.get("balances", []))
elif account_details: elif account_details:
# Fallback: persist account details without currency if balances failed # Fallback: persist account details without currency if balances failed
await self.database.persist_account_details(account_details) self.accounts.persist(account_details)
# Get and save transactions # Get and save transactions
transactions = await self.gocardless.get_account_transactions( transactions = await self.gocardless.get_account_transactions(
@@ -130,7 +141,7 @@ class SyncService:
account_id, account_details, transactions account_id, account_details, transactions
) )
) )
new_transactions = await self.database.persist_transactions( new_transactions = self.transactions.persist(
account_id, processed_transactions account_id, processed_transactions
) )
transactions_added += len(new_transactions) transactions_added += len(new_transactions)
@@ -184,9 +195,7 @@ class SyncService:
# Persist sync operation to database # Persist sync operation to database
try: try:
operation_id = await self.database.persist_sync_operation( operation_id = self.sync.persist(sync_operation)
sync_operation
)
logger.debug(f"Saved sync operation with ID: {operation_id}") logger.debug(f"Saved sync operation with ID: {operation_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to persist sync operation: {e}") logger.error(f"Failed to persist sync operation: {e}")
@@ -235,9 +244,7 @@ class SyncService:
) )
try: try:
operation_id = await self.database.persist_sync_operation( operation_id = self.sync.persist(sync_operation)
sync_operation
)
logger.debug(f"Saved failed sync operation with ID: {operation_id}") logger.debug(f"Saved failed sync operation with ID: {operation_id}")
except Exception as persist_error: except Exception as persist_error:
logger.error( logger.error(

View File

@@ -126,6 +126,38 @@ def api_client(fastapi_app):
return TestClient(fastapi_app) return TestClient(fastapi_app)
@pytest.fixture
def mock_account_repo():
"""Create mock AccountRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_balance_repo():
"""Create mock BalanceRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_transaction_repo():
"""Create mock TransactionRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_analytics_proc():
"""Create mock AnalyticsProcessor for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture @pytest.fixture
def mock_db_path(temp_db_path): def mock_db_path(temp_db_path):
"""Mock the database path to use temporary database for testing.""" """Mock the database path to use temporary database for testing."""

View File

@@ -1,13 +1,13 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics.""" """Tests for analytics fixes to ensure all transactions are used in statistics."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import Mock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from leggen.api.dependencies import get_transaction_repository
from leggen.commands.server import create_app from leggen.commands.server import create_app
from leggen.services.database_service import DatabaseService
class TestAnalyticsFix: class TestAnalyticsFix:
@@ -19,11 +19,11 @@ class TestAnalyticsFix:
return TestClient(app) return TestClient(app)
@pytest.fixture @pytest.fixture
def mock_database_service(self): def mock_transaction_repo(self):
return Mock(spec=DatabaseService) return Mock()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transaction_stats_uses_all_transactions(self, mock_database_service): async def test_transaction_stats_uses_all_transactions(self, mock_transaction_repo):
"""Test that transaction stats endpoint uses all transactions (not limited to 100)""" """Test that transaction stats endpoint uses all transactions (not limited to 100)"""
# Mock data for 600 transactions (simulating the issue) # Mock data for 600 transactions (simulating the issue)
mock_transactions = [] mock_transactions = []
@@ -42,53 +42,50 @@ class TestAnalyticsFix:
} }
) )
mock_database_service.get_transactions_from_db = AsyncMock( mock_transaction_repo.get_transactions.return_value = mock_transactions
return_value=mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
) )
# Test that the endpoint calls get_transactions_from_db with limit=None # Verify that the response contains stats for all 600 transactions
with patch( stats = data
"leggen.api.routes.transactions.database_service", mock_database_service assert stats["total_transactions"] == 600, (
): "Should process all 600 transactions, not just 100"
app = create_app() )
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365") # Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert response.status_code == 200 assert stats["total_income"] == expected_income
data = response.json() assert stats["total_expenses"] == expected_expenses
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
)
# Verify that the response contains stats for all 600 transactions
stats = data
assert stats["total_transactions"] == 600, (
"Should process all 600 transactions, not just 100"
)
# Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_analytics_endpoint_returns_all_transactions( async def test_analytics_endpoint_returns_all_transactions(
self, mock_database_service self, mock_transaction_repo
): ):
"""Test that the new analytics endpoint returns all transactions without pagination""" """Test that the new analytics endpoint returns all transactions without pagination"""
# Mock data for 600 transactions # Mock data for 600 transactions
@@ -108,30 +105,28 @@ class TestAnalyticsFix:
} }
) )
mock_database_service.get_transactions_from_db = AsyncMock( mock_transaction_repo.get_transactions.return_value = mock_transactions
return_value=mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
) )
with patch( # Verify that all 600 transactions are returned
"leggen.api.routes.transactions.database_service", mock_database_service transactions_data = data
): assert len(transactions_data) == 600, (
app = create_app() "Analytics endpoint should return all 600 transactions"
client = TestClient(app) )
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
)
# Verify that all 600 transactions are returned
transactions_data = data
assert len(transactions_data) == 600, (
"Analytics endpoint should return all 600 transactions"
)

View File

@@ -4,6 +4,12 @@ from unittest.mock import patch
import pytest import pytest
from leggen.api.dependencies import (
get_account_repository,
get_balance_repository,
get_transaction_repository,
)
@pytest.mark.api @pytest.mark.api
class TestAccountsAPI: class TestAccountsAPI:
@@ -11,11 +17,14 @@ class TestAccountsAPI:
def test_get_all_accounts_success( def test_get_all_accounts_success(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
mock_db_path, mock_db_path,
mock_account_repo,
mock_balance_repo,
): ):
"""Test successful retrieval of all accounts from database.""" """Test successful retrieval of all accounts from database."""
mock_accounts = [ mock_accounts = [
@@ -45,19 +54,21 @@ class TestAccountsAPI:
} }
] ]
with ( mock_account_repo.get_accounts.return_value = mock_accounts
patch("leggen.utils.config.config", mock_config), mock_balance_repo.get_balances.return_value = mock_balances
patch(
"leggen.api.routes.accounts.database_service.get_accounts_from_db", fastapi_app.dependency_overrides[get_account_repository] = (
return_value=mock_accounts, lambda: mock_account_repo
), )
patch( fastapi_app.dependency_overrides[get_balance_repository] = (
"leggen.api.routes.accounts.database_service.get_balances_from_db", lambda: mock_balance_repo
return_value=mock_balances, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts") response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
@@ -69,11 +80,14 @@ class TestAccountsAPI:
def test_get_account_details_success( def test_get_account_details_success(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
mock_db_path, mock_db_path,
mock_account_repo,
mock_balance_repo,
): ):
"""Test successful retrieval of specific account details from database.""" """Test successful retrieval of specific account details from database."""
mock_account = { mock_account = {
@@ -101,19 +115,21 @@ class TestAccountsAPI:
} }
] ]
with ( mock_account_repo.get_account.return_value = mock_account
patch("leggen.utils.config.config", mock_config), mock_balance_repo.get_balances.return_value = mock_balances
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db", fastapi_app.dependency_overrides[get_account_repository] = (
return_value=mock_account, lambda: mock_account_repo
), )
patch( fastapi_app.dependency_overrides[get_balance_repository] = (
"leggen.api.routes.accounts.database_service.get_balances_from_db", lambda: mock_balance_repo
return_value=mock_balances, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123") response = api_client.get("/api/v1/accounts/test-account-123")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["id"] == "test-account-123" assert data["id"] == "test-account-123"
@@ -121,7 +137,13 @@ class TestAccountsAPI:
assert len(data["balances"]) == 1 assert len(data["balances"]) == 1
def test_get_account_balances_success( def test_get_account_balances_success(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_balance_repo,
): ):
"""Test successful retrieval of account balances from database.""" """Test successful retrieval of account balances from database."""
mock_balances = [ mock_balances = [
@@ -149,15 +171,17 @@ class TestAccountsAPI:
}, },
] ]
with ( mock_balance_repo.get_balances.return_value = mock_balances
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_balance_repository] = (
"leggen.api.routes.accounts.database_service.get_balances_from_db", lambda: mock_balance_repo
return_value=mock_balances, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/balances") response = api_client.get("/api/v1/accounts/test-account-123/balances")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 2 assert len(data) == 2
@@ -167,12 +191,14 @@ class TestAccountsAPI:
def test_get_account_transactions_success( def test_get_account_transactions_success(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
sample_transaction_data, sample_transaction_data,
mock_db_path, mock_db_path,
mock_transaction_repo,
): ):
"""Test successful retrieval of account transactions from database.""" """Test successful retrieval of account transactions from database."""
mock_transactions = [ mock_transactions = [
@@ -191,21 +217,19 @@ class TestAccountsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.accounts.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get( response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true" "/api/v1/accounts/test-account-123/transactions?summary_only=true"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
@@ -218,12 +242,14 @@ class TestAccountsAPI:
def test_get_account_transactions_full_details( def test_get_account_transactions_full_details(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
sample_transaction_data, sample_transaction_data,
mock_db_path, mock_db_path,
mock_transaction_repo,
): ):
"""Test retrieval of full transaction details from database.""" """Test retrieval of full transaction details from database."""
mock_transactions = [ mock_transactions = [
@@ -242,21 +268,19 @@ class TestAccountsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.accounts.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get( response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false" "/api/v1/accounts/test-account-123/transactions?summary_only=false"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
@@ -268,22 +292,36 @@ class TestAccountsAPI:
assert "raw_transaction" in transaction assert "raw_transaction" in transaction
def test_get_account_not_found( def test_get_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
): ):
"""Test handling of non-existent account.""" """Test handling of non-existent account."""
with ( mock_account_repo.get_account.return_value = None
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_account_repository] = (
"leggen.api.routes.accounts.database_service.get_account_details_from_db", lambda: mock_account_repo
return_value=None, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/nonexistent") response = api_client.get("/api/v1/accounts/nonexistent")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404 assert response.status_code == 404
def test_update_account_display_name_success( def test_update_account_display_name_success(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
): ):
"""Test successful update of account display name.""" """Test successful update of account display name."""
mock_account = { mock_account = {
@@ -297,41 +335,48 @@ class TestAccountsAPI:
"last_accessed": "2025-09-01T09:30:00Z", "last_accessed": "2025-09-01T09:30:00Z",
} }
with ( mock_account_repo.get_account.return_value = mock_account
patch("leggen.utils.config.config", mock_config), mock_account_repo.persist.return_value = mock_account
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db", fastapi_app.dependency_overrides[get_account_repository] = (
return_value=mock_account, lambda: mock_account_repo
), )
patch(
"leggen.api.routes.accounts.database_service.persist_account_details", with patch("leggen.utils.config.config", mock_config):
return_value=None,
),
):
response = api_client.put( response = api_client.put(
"/api/v1/accounts/test-account-123", "/api/v1/accounts/test-account-123",
json={"display_name": "My Custom Account Name"}, json={"display_name": "My Custom Account Name"},
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["id"] == "test-account-123" assert data["id"] == "test-account-123"
assert data["display_name"] == "My Custom Account Name" assert data["display_name"] == "My Custom Account Name"
def test_update_account_not_found( def test_update_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
): ):
"""Test updating non-existent account.""" """Test updating non-existent account."""
with ( mock_account_repo.get_account.return_value = None
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_account_repository] = (
"leggen.api.routes.accounts.database_service.get_account_details_from_db", lambda: mock_account_repo
return_value=None, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.put( response = api_client.put(
"/api/v1/accounts/nonexistent", "/api/v1/accounts/nonexistent",
json={"display_name": "New Name"}, json={"display_name": "New Name"},
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404 assert response.status_code == 404

View File

@@ -5,13 +5,20 @@ from unittest.mock import patch
import pytest import pytest
from leggen.api.dependencies import get_transaction_repository
@pytest.mark.api @pytest.mark.api
class TestTransactionsAPI: class TestTransactionsAPI:
"""Test transaction-related API endpoints.""" """Test transaction-related API endpoints."""
def test_get_all_transactions_success( def test_get_all_transactions_success(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test successful retrieval of all transactions from database.""" """Test successful retrieval of all transactions from database."""
mock_transactions = [ mock_transactions = [
@@ -43,19 +50,17 @@ class TestTransactionsAPI:
}, },
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = len(mock_transactions)
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=2,
),
):
response = api_client.get("/api/v1/transactions?summary_only=true") response = api_client.get("/api/v1/transactions?summary_only=true")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["data"]) == 2 assert len(data["data"]) == 2
@@ -70,7 +75,12 @@ class TestTransactionsAPI:
assert transaction["account_id"] == "test-account-123" assert transaction["account_id"] == "test-account-123"
def test_get_all_transactions_full_details( def test_get_all_transactions_full_details(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test retrieval of full transaction details from database.""" """Test retrieval of full transaction details from database."""
mock_transactions = [ mock_transactions = [
@@ -89,19 +99,17 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = len(mock_transactions)
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get("/api/v1/transactions?summary_only=false") response = api_client.get("/api/v1/transactions?summary_only=false")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["data"]) == 1 assert len(data["data"]) == 1
@@ -114,7 +122,12 @@ class TestTransactionsAPI:
assert "raw_transaction" in transaction assert "raw_transaction" in transaction
def test_get_transactions_with_filters( def test_get_transactions_with_filters(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transactions with various filters.""" """Test getting transactions with various filters."""
mock_transactions = [ mock_transactions = [
@@ -133,17 +146,14 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = 1
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db", fastapi_app.dependency_overrides[get_transaction_repository] = (
return_value=mock_transactions, lambda: mock_transaction_repo
) as mock_get_transactions, )
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db", with patch("leggen.utils.config.config", mock_config):
return_value=1,
),
):
response = api_client.get( response = api_client.get(
"/api/v1/transactions?" "/api/v1/transactions?"
"account_id=test-account-123&" "account_id=test-account-123&"
@@ -156,10 +166,12 @@ class TestTransactionsAPI:
"per_page=10" "per_page=10"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
# Verify the database service was called with correct filters # Verify the repository was called with correct filters
mock_get_transactions.assert_called_once_with( mock_transaction_repo.get_transactions.assert_called_once_with(
account_id="test-account-123", account_id="test-account-123",
limit=10, limit=10,
offset=10, # (page-1) * per_page = (2-1) * 10 = 10 offset=10, # (page-1) * per_page = (2-1) * 10 = 10
@@ -171,22 +183,26 @@ class TestTransactionsAPI:
) )
def test_get_transactions_empty_result( def test_get_transactions_empty_result(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transactions when database returns empty result.""" """Test getting transactions when database returns empty result."""
with ( mock_transaction_repo.get_transactions.return_value = []
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = 0
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db", fastapi_app.dependency_overrides[get_transaction_repository] = (
return_value=[], lambda: mock_transaction_repo
), )
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db", with patch("leggen.utils.config.config", mock_config):
return_value=0,
),
):
response = api_client.get("/api/v1/transactions") response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["data"]) == 0 assert len(data["data"]) == 0
@@ -195,23 +211,37 @@ class TestTransactionsAPI:
assert data["total_pages"] == 0 assert data["total_pages"] == 0
def test_get_transactions_database_error( def test_get_transactions_database_error(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test handling database error when getting transactions.""" """Test handling database error when getting transactions."""
with ( mock_transaction_repo.get_transactions.side_effect = Exception(
patch("leggen.utils.config.config", mock_config), "Database connection failed"
patch( )
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"), fastapi_app.dependency_overrides[get_transaction_repository] = (
), lambda: mock_transaction_repo
): )
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions") response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500 assert response.status_code == 500
assert "Failed to get transactions" in response.json()["detail"] assert "Failed to get transactions" in response.json()["detail"]
def test_get_transaction_stats_success( def test_get_transaction_stats_success(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test successful retrieval of transaction statistics from database.""" """Test successful retrieval of transaction statistics from database."""
mock_transactions = [ mock_transactions = [
@@ -238,15 +268,16 @@ class TestTransactionsAPI:
}, },
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), fastapi_app.dependency_overrides[get_transaction_repository] = (
patch( lambda: mock_transaction_repo
"leggen.api.routes.transactions.database_service.get_transactions_from_db", )
return_value=mock_transactions,
), with patch("leggen.utils.config.config", mock_config):
):
response = api_client.get("/api/v1/transactions/stats?days=30") response = api_client.get("/api/v1/transactions/stats?days=30")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -264,7 +295,12 @@ class TestTransactionsAPI:
assert data["average_transaction"] == expected_avg assert data["average_transaction"] == expected_avg
def test_get_transaction_stats_with_account_filter( def test_get_transaction_stats_with_account_filter(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transaction stats filtered by account.""" """Test getting transaction stats filtered by account."""
mock_transactions = [ mock_transactions = [
@@ -277,37 +313,46 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
) as mock_get_transactions,
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get( response = api_client.get(
"/api/v1/transactions/stats?account_id=test-account-123" "/api/v1/transactions/stats?account_id=test-account-123"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
# Verify the database service was called with account filter # Verify the repository was called with account filter
mock_get_transactions.assert_called_once() mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert call_kwargs["account_id"] == "test-account-123" assert call_kwargs["account_id"] == "test-account-123"
def test_get_transaction_stats_empty_result( def test_get_transaction_stats_empty_result(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting stats when no transactions match criteria.""" """Test getting stats when no transactions match criteria."""
with ( mock_transaction_repo.get_transactions.return_value = []
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=[], )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats") response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -319,23 +364,37 @@ class TestTransactionsAPI:
assert data["accounts_included"] == 0 assert data["accounts_included"] == 0
def test_get_transaction_stats_database_error( def test_get_transaction_stats_database_error(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test handling database error when getting stats.""" """Test handling database error when getting stats."""
with ( mock_transaction_repo.get_transactions.side_effect = Exception(
patch("leggen.utils.config.config", mock_config), "Database connection failed"
patch( )
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"), fastapi_app.dependency_overrides[get_transaction_repository] = (
), lambda: mock_transaction_repo
): )
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats") response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500 assert response.status_code == 500
assert "Failed to get transaction stats" in response.json()["detail"] assert "Failed to get transaction stats" in response.json()["detail"]
def test_get_transaction_stats_custom_period( def test_get_transaction_stats_custom_period(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transaction stats for custom time period.""" """Test getting transaction stats for custom time period."""
mock_transactions = [ mock_transactions = [
@@ -348,21 +407,23 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
) as mock_get_transactions,
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats?days=7") response = api_client.get("/api/v1/transactions/stats?days=7")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["period_days"] == 7 assert data["period_days"] == 7
# Verify the date range was calculated correctly for 7 days # Verify the date range was calculated correctly for 7 days
mock_get_transactions.assert_called_once() mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert "date_from" in call_kwargs assert "date_from" in call_kwargs
assert "date_to" in call_kwargs assert "date_to" in call_kwargs

View File

@@ -27,9 +27,7 @@ class TestSyncNotifications:
patch.object( patch.object(
sync_service.notifications, "send_sync_failure_notification" sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification, ) as mock_send_notification,
patch.object( patch.object(sync_service.sync, "persist", return_value=1),
sync_service.database, "persist_sync_operation", return_value=1
),
): ):
# Setup: One requisition with one account that will fail # Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = { mock_get_requisitions.return_value = {
@@ -69,9 +67,7 @@ class TestSyncNotifications:
patch.object( patch.object(
sync_service.notifications, "send_expiry_notification" sync_service.notifications, "send_expiry_notification"
) as mock_send_expiry, ) as mock_send_expiry,
patch.object( patch.object(sync_service.sync, "persist", return_value=1),
sync_service.database, "persist_sync_operation", return_value=1
),
): ):
# Setup: One expired requisition # Setup: One expired requisition
mock_get_requisitions.return_value = { mock_get_requisitions.return_value = {
@@ -112,9 +108,7 @@ class TestSyncNotifications:
patch.object( patch.object(
sync_service.notifications, "send_sync_failure_notification" sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification, ) as mock_send_notification,
patch.object( patch.object(sync_service.sync, "persist", return_value=1),
sync_service.database, "persist_sync_operation", return_value=1
),
): ):
# Setup: One requisition with two accounts that will fail # Setup: One requisition with two accounts that will fail
mock_get_requisitions.return_value = { mock_get_requisitions.return_value = {
@@ -160,17 +154,15 @@ class TestSyncNotifications:
sync_service.notifications, "send_sync_failure_notification" sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification, ) as mock_send_notification,
patch.object(sync_service.notifications, "send_transaction_notifications"), patch.object(sync_service.notifications, "send_transaction_notifications"),
patch.object(sync_service.database, "persist_account_details"), patch.object(sync_service.accounts, "persist"),
patch.object(sync_service.database, "persist_balance"), patch.object(sync_service.balances, "persist"),
patch.object( patch.object(
sync_service.database, "process_transactions", return_value=[] sync_service.transaction_processor,
), "process_transactions",
patch.object( return_value=[],
sync_service.database, "persist_transactions", return_value=[]
),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
), ),
patch.object(sync_service.transactions, "persist", return_value=[]),
patch.object(sync_service.sync, "persist", return_value=1),
): ):
# Setup: One requisition with one account that succeeds # Setup: One requisition with one account that succeeds
mock_get_requisitions.return_value = { mock_get_requisitions.return_value = {
@@ -222,9 +214,7 @@ class TestSyncNotifications:
patch.object( patch.object(
sync_service.notifications, "_send_telegram_sync_failure" sync_service.notifications, "_send_telegram_sync_failure"
) as mock_telegram_notification, ) as mock_telegram_notification,
patch.object( patch.object(sync_service.sync, "persist", return_value=1),
sync_service.database, "persist_sync_operation", return_value=1
),
): ):
# Setup: One requisition with one account that will fail # Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = { mock_get_requisitions.return_value = {