refactor(api): Update all modified files with dependency injection changes.

This commit is contained in:
Elisiário Couto
2025-12-09 00:50:02 +00:00
committed by Elisiário Couto
parent 9dc6357905
commit 9e9b1cf15f
11 changed files with 505 additions and 330 deletions

View File

@@ -3,6 +3,12 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query
from loguru import logger
from leggen.api.dependencies import (
AccountRepo,
AnalyticsProc,
BalanceRepo,
TransactionRepo,
)
from leggen.api.models.accounts import (
AccountBalance,
AccountDetails,
@@ -10,28 +16,27 @@ from leggen.api.models.accounts import (
Transaction,
TransactionSummary,
)
from leggen.services.database_service import DatabaseService
router = APIRouter()
database_service = DatabaseService()
@router.get("/accounts")
async def get_all_accounts() -> List[AccountDetails]:
async def get_all_accounts(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[AccountDetails]:
"""Get all connected accounts from database"""
try:
accounts = []
# Get all account details from database
db_accounts = await database_service.get_accounts_from_db()
db_accounts = account_repo.get_accounts()
# Process accounts found in database
for db_account in db_accounts:
try:
# Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db(
db_account["id"]
)
balances_data = balance_repo.get_balances(db_account["id"])
# Process balances
balances = []
@@ -77,11 +82,15 @@ async def get_all_accounts() -> List[AccountDetails]:
@router.get("/accounts/{account_id}")
async def get_account_details(account_id: str) -> AccountDetails:
async def get_account_details(
account_id: str,
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> AccountDetails:
"""Get details for a specific account from database"""
try:
# Get account details from database
db_account = await database_service.get_account_details_from_db(account_id)
db_account = account_repo.get_account(account_id)
if not db_account:
raise HTTPException(
@@ -89,7 +98,7 @@ async def get_account_details(account_id: str) -> AccountDetails:
)
# Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db(account_id)
balances_data = balance_repo.get_balances(account_id)
# Process balances
balances = []
@@ -129,11 +138,14 @@ async def get_account_details(account_id: str) -> AccountDetails:
@router.get("/accounts/{account_id}/balances")
async def get_account_balances(account_id: str) -> List[AccountBalance]:
async def get_account_balances(
account_id: str,
balance_repo: BalanceRepo,
) -> List[AccountBalance]:
"""Get balances for a specific account from database"""
try:
# Get balances from database instead of GoCardless API
db_balances = await database_service.get_balances_from_db(account_id=account_id)
db_balances = balance_repo.get_balances(account_id=account_id)
balances = []
for balance in db_balances:
@@ -158,19 +170,20 @@ async def get_account_balances(account_id: str) -> List[AccountBalance]:
@router.get("/balances")
async def get_all_balances() -> List[dict]:
async def get_all_balances(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[dict]:
"""Get all balances from all accounts in database"""
try:
# Get all accounts first to iterate through them
db_accounts = await database_service.get_accounts_from_db()
db_accounts = account_repo.get_accounts()
all_balances = []
for db_account in db_accounts:
try:
# Get balances for this account
db_balances = await database_service.get_balances_from_db(
account_id=db_account["id"]
)
db_balances = balance_repo.get_balances(account_id=db_account["id"])
# Process balances and add account info
for balance in db_balances:
@@ -205,6 +218,7 @@ async def get_all_balances() -> List[dict]:
@router.get("/balances/history")
async def get_historical_balances(
analytics_proc: AnalyticsProc,
days: Optional[int] = Query(
default=365, le=1095, ge=1, description="Number of days of history to retrieve"
),
@@ -214,9 +228,12 @@ async def get_historical_balances(
) -> List[dict]:
"""Get historical balance progression calculated from transaction history"""
try:
from leggen.utils.paths import path_manager
# Get historical balances from database
historical_balances = await database_service.get_historical_balances_from_db(
account_id=account_id, days=days or 365
db_path = path_manager.get_database_path()
historical_balances = analytics_proc.calculate_historical_balances(
db_path, account_id=account_id, days=days or 365
)
return historical_balances
@@ -231,6 +248,7 @@ async def get_historical_balances(
@router.get("/accounts/{account_id}/transactions")
async def get_account_transactions(
account_id: str,
transaction_repo: TransactionRepo,
limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(
@@ -240,10 +258,10 @@ async def get_account_transactions(
"""Get transactions for a specific account from database"""
try:
# Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db(
db_transactions = transaction_repo.get_transactions(
account_id=account_id,
limit=limit,
offset=offset,
offset=offset or 0,
)
data: Union[List[TransactionSummary], List[Transaction]]
@@ -294,11 +312,15 @@ async def get_account_transactions(
@router.put("/accounts/{account_id}")
async def update_account_details(account_id: str, update_data: AccountUpdate) -> dict:
async def update_account_details(
account_id: str,
update_data: AccountUpdate,
account_repo: AccountRepo,
) -> dict:
"""Update account details (currently only display_name)"""
try:
# Get current account details
current_account = await database_service.get_account_details_from_db(account_id)
current_account = account_repo.get_account(account_id)
if not current_account:
raise HTTPException(
@@ -311,7 +333,7 @@ async def update_account_details(account_id: str, update_data: AccountUpdate) ->
updated_account_data["display_name"] = update_data.display_name
# Persist updated account details
await database_service.persist_account_details(updated_account_data)
account_repo.persist(updated_account_data)
return {"id": account_id, "display_name": update_data.display_name}

View File

@@ -198,9 +198,10 @@ async def stop_scheduler() -> dict:
async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict:
"""Get sync operations history"""
try:
operations = await sync_service.database.get_sync_operations(
limit=limit, offset=offset
)
from leggen.repositories import SyncRepository
sync_repo = SyncRepository()
operations = sync_repo.get_operations(limit=limit, offset=offset)
return {"operations": operations, "count": len(operations)}

View File

@@ -4,16 +4,16 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query
from loguru import logger
from leggen.api.dependencies import AnalyticsProc, TransactionRepo
from leggen.api.models.accounts import Transaction, TransactionSummary
from leggen.api.models.common import PaginatedResponse
from leggen.services.database_service import DatabaseService
router = APIRouter()
database_service = DatabaseService()
@router.get("/transactions")
async def get_all_transactions(
transaction_repo: TransactionRepo,
page: int = Query(default=1, ge=1, description="Page number (1-based)"),
per_page: int = Query(default=50, le=500, description="Items per page"),
summary_only: bool = Query(
@@ -43,7 +43,7 @@ async def get_all_transactions(
limit = per_page
# Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db(
db_transactions = transaction_repo.get_transactions(
account_id=account_id,
limit=limit,
offset=offset,
@@ -55,7 +55,7 @@ async def get_all_transactions(
)
# Get total count for pagination info (respecting the same filters)
total_transactions = await database_service.get_transaction_count_from_db(
total_transactions = transaction_repo.get_count(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -119,6 +119,7 @@ async def get_all_transactions(
@router.get("/transactions/stats")
async def get_transaction_stats(
transaction_repo: TransactionRepo,
days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> dict:
@@ -133,7 +134,7 @@ async def get_transaction_stats(
date_to = end_date.isoformat()
# Get transactions from database
recent_transactions = await database_service.get_transactions_from_db(
recent_transactions = transaction_repo.get_transactions(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -198,6 +199,7 @@ async def get_transaction_stats(
@router.get("/transactions/analytics")
async def get_transactions_for_analytics(
transaction_repo: TransactionRepo,
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]:
@@ -212,7 +214,7 @@ async def get_transactions_for_analytics(
date_to = end_date.isoformat()
# Get ALL transactions from database (no limit for analytics)
transactions = await database_service.get_transactions_from_db(
transactions = transaction_repo.get_transactions(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -244,11 +246,14 @@ async def get_transactions_for_analytics(
@router.get("/transactions/monthly-stats")
async def get_monthly_transaction_stats(
analytics_proc: AnalyticsProc,
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]:
"""Get monthly transaction statistics aggregated by the database"""
try:
from leggen.utils.paths import path_manager
# Date range for monthly stats
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
@@ -258,10 +263,9 @@ async def get_monthly_transaction_stats(
date_to = end_date.isoformat()
# Get monthly aggregated stats from database
monthly_stats = await database_service.get_monthly_transaction_stats_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
db_path = path_manager.get_database_path()
monthly_stats = analytics_proc.calculate_monthly_stats(
db_path, account_id=account_id, date_from=date_from, date_to=date_to
)
return monthly_stats

View File

@@ -28,10 +28,10 @@ async def lifespan(app: FastAPI):
# Run database migrations
try:
from leggen.services.database_service import DatabaseService
from leggen.api.dependencies import get_migration_repository
db_service = DatabaseService()
await db_service.run_migrations_if_needed()
migrations = get_migration_repository()
await migrations.run_all_migrations()
logger.info("Database migrations completed")
except Exception as e:
logger.error(f"Database migration failed: {e}")

View File

@@ -1,3 +1,11 @@
"""
DEPRECATED: DatabaseService is deprecated in favor of direct repository usage via dependency injection.
This module is kept for backward compatibility with existing tests.
New code should use repositories directly via leggen.api.dependencies.
"""
import warnings
from functools import wraps
from typing import Any, Dict, List, Optional
@@ -38,9 +46,19 @@ def require_sqlite(func):
class DatabaseService:
"""Simplified database service using repository pattern"""
"""
DEPRECATED: Use repositories directly via dependency injection.
This class is maintained for backward compatibility with existing tests.
For new code, inject repositories using leggen.api.dependencies.
"""
def __init__(self):
warnings.warn(
"DatabaseService is deprecated. Use repositories via dependency injection.",
DeprecationWarning,
stacklevel=2,
)
self.db_config = config.database_config
self.sqlite_enabled = self.db_config.get("sqlite", True)

View File

@@ -4,12 +4,17 @@ from typing import List
from loguru import logger
from leggen.api.models.sync import SyncResult, SyncStatus
from leggen.repositories import (
AccountRepository,
BalanceRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AccountEnricher,
BalanceTransformer,
TransactionProcessor,
)
from leggen.services.database_service import DatabaseService
from leggen.services.gocardless_service import GoCardlessService
from leggen.services.notification_service import NotificationService
@@ -20,9 +25,14 @@ EXPIRED_DAYS_LEFT = 0
class SyncService:
def __init__(self):
self.gocardless = GoCardlessService()
self.database = DatabaseService()
self.notifications = NotificationService()
# Repositories
self.accounts = AccountRepository()
self.balances = BalanceRepository()
self.transactions = TransactionRepository()
self.sync = SyncRepository()
# Data processors
self.account_enricher = AccountEnricher()
self.balance_transformer = BalanceTransformer()
@@ -104,21 +114,22 @@ class SyncService:
)
# Persist enriched account details to database
await self.database.persist_account_details(
enriched_account_details
)
self.accounts.persist(enriched_account_details)
# Merge account metadata into balances for persistence
balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances(
balances, enriched_account_details
)
await self.database.persist_balance(
balance_rows = (
self.balance_transformer.transform_to_database_format(
account_id, balances_with_account_info
)
)
self.balances.persist(account_id, balance_rows)
balances_updated += len(balances.get("balances", []))
elif account_details:
# Fallback: persist account details without currency if balances failed
await self.database.persist_account_details(account_details)
self.accounts.persist(account_details)
# Get and save transactions
transactions = await self.gocardless.get_account_transactions(
@@ -130,7 +141,7 @@ class SyncService:
account_id, account_details, transactions
)
)
new_transactions = await self.database.persist_transactions(
new_transactions = self.transactions.persist(
account_id, processed_transactions
)
transactions_added += len(new_transactions)
@@ -184,9 +195,7 @@ class SyncService:
# Persist sync operation to database
try:
operation_id = await self.database.persist_sync_operation(
sync_operation
)
operation_id = self.sync.persist(sync_operation)
logger.debug(f"Saved sync operation with ID: {operation_id}")
except Exception as e:
logger.error(f"Failed to persist sync operation: {e}")
@@ -235,9 +244,7 @@ class SyncService:
)
try:
operation_id = await self.database.persist_sync_operation(
sync_operation
)
operation_id = self.sync.persist(sync_operation)
logger.debug(f"Saved failed sync operation with ID: {operation_id}")
except Exception as persist_error:
logger.error(

View File

@@ -126,6 +126,38 @@ def api_client(fastapi_app):
return TestClient(fastapi_app)
@pytest.fixture
def mock_account_repo():
"""Create mock AccountRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_balance_repo():
"""Create mock BalanceRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_transaction_repo():
"""Create mock TransactionRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_analytics_proc():
"""Create mock AnalyticsProcessor for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_db_path(temp_db_path):
"""Mock the database path to use temporary database for testing."""

View File

@@ -1,13 +1,13 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from leggen.api.dependencies import get_transaction_repository
from leggen.commands.server import create_app
from leggen.services.database_service import DatabaseService
class TestAnalyticsFix:
@@ -19,11 +19,11 @@ class TestAnalyticsFix:
return TestClient(app)
@pytest.fixture
def mock_database_service(self):
return Mock(spec=DatabaseService)
def mock_transaction_repo(self):
return Mock()
@pytest.mark.asyncio
async def test_transaction_stats_uses_all_transactions(self, mock_database_service):
async def test_transaction_stats_uses_all_transactions(self, mock_transaction_repo):
"""Test that transaction stats endpoint uses all transactions (not limited to 100)"""
# Mock data for 600 transactions (simulating the issue)
mock_transactions = []
@@ -42,15 +42,12 @@ class TestAnalyticsFix:
}
)
mock_database_service.get_transactions_from_db = AsyncMock(
return_value=mock_transactions
)
mock_transaction_repo.get_transactions.return_value = mock_transactions
# Test that the endpoint calls get_transactions_from_db with limit=None
with patch(
"leggen.api.routes.transactions.database_service", mock_database_service
):
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365")
@@ -59,8 +56,8 @@ class TestAnalyticsFix:
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
)
@@ -88,7 +85,7 @@ class TestAnalyticsFix:
@pytest.mark.asyncio
async def test_analytics_endpoint_returns_all_transactions(
self, mock_database_service
self, mock_transaction_repo
):
"""Test that the new analytics endpoint returns all transactions without pagination"""
# Mock data for 600 transactions
@@ -108,14 +105,12 @@ class TestAnalyticsFix:
}
)
mock_database_service.get_transactions_from_db = AsyncMock(
return_value=mock_transactions
)
mock_transaction_repo.get_transactions.return_value = mock_transactions
with patch(
"leggen.api.routes.transactions.database_service", mock_database_service
):
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
@@ -124,8 +119,8 @@ class TestAnalyticsFix:
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
)

View File

@@ -4,6 +4,12 @@ from unittest.mock import patch
import pytest
from leggen.api.dependencies import (
get_account_repository,
get_balance_repository,
get_transaction_repository,
)
@pytest.mark.api
class TestAccountsAPI:
@@ -11,11 +17,14 @@ class TestAccountsAPI:
def test_get_all_accounts_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
mock_db_path,
mock_account_repo,
mock_balance_repo,
):
"""Test successful retrieval of all accounts from database."""
mock_accounts = [
@@ -45,19 +54,21 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_accounts_from_db",
return_value=mock_accounts,
),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
mock_account_repo.get_accounts.return_value = mock_accounts
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -69,11 +80,14 @@ class TestAccountsAPI:
def test_get_account_details_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
mock_db_path,
mock_account_repo,
mock_balance_repo,
):
"""Test successful retrieval of specific account details from database."""
mock_account = {
@@ -101,19 +115,21 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=mock_account,
),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
mock_account_repo.get_account.return_value = mock_account
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["id"] == "test-account-123"
@@ -121,7 +137,13 @@ class TestAccountsAPI:
assert len(data["balances"]) == 1
def test_get_account_balances_success(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_balance_repo,
):
"""Test successful retrieval of account balances from database."""
mock_balances = [
@@ -149,15 +171,17 @@ class TestAccountsAPI:
},
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/balances")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 2
@@ -167,12 +191,14 @@ class TestAccountsAPI:
def test_get_account_transactions_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
mock_db_path,
mock_transaction_repo,
):
"""Test successful retrieval of account transactions from database."""
mock_transactions = [
@@ -191,21 +217,19 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -218,12 +242,14 @@ class TestAccountsAPI:
def test_get_account_transactions_full_details(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
mock_db_path,
mock_transaction_repo,
):
"""Test retrieval of full transaction details from database."""
mock_transactions = [
@@ -242,21 +268,19 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -268,22 +292,36 @@ class TestAccountsAPI:
assert "raw_transaction" in transaction
def test_get_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
):
"""Test handling of non-existent account."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=None,
),
):
mock_account_repo.get_account.return_value = None
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/nonexistent")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404
def test_update_account_display_name_success(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
):
"""Test successful update of account display name."""
mock_account = {
@@ -297,41 +335,48 @@ class TestAccountsAPI:
"last_accessed": "2025-09-01T09:30:00Z",
}
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=mock_account,
),
patch(
"leggen.api.routes.accounts.database_service.persist_account_details",
return_value=None,
),
):
mock_account_repo.get_account.return_value = mock_account
mock_account_repo.persist.return_value = mock_account
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.put(
"/api/v1/accounts/test-account-123",
json={"display_name": "My Custom Account Name"},
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["id"] == "test-account-123"
assert data["display_name"] == "My Custom Account Name"
def test_update_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
):
"""Test updating non-existent account."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=None,
),
):
mock_account_repo.get_account.return_value = None
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.put(
"/api/v1/accounts/nonexistent",
json={"display_name": "New Name"},
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404

View File

@@ -5,13 +5,20 @@ from unittest.mock import patch
import pytest
from leggen.api.dependencies import get_transaction_repository
@pytest.mark.api
class TestTransactionsAPI:
"""Test transaction-related API endpoints."""
def test_get_all_transactions_success(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test successful retrieval of all transactions from database."""
mock_transactions = [
@@ -43,19 +50,17 @@ class TestTransactionsAPI:
},
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=2,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = len(mock_transactions)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions?summary_only=true")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
@@ -70,7 +75,12 @@ class TestTransactionsAPI:
assert transaction["account_id"] == "test-account-123"
def test_get_all_transactions_full_details(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test retrieval of full transaction details from database."""
mock_transactions = [
@@ -89,19 +99,17 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = len(mock_transactions)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions?summary_only=false")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
@@ -114,7 +122,12 @@ class TestTransactionsAPI:
assert "raw_transaction" in transaction
def test_get_transactions_with_filters(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transactions with various filters."""
mock_transactions = [
@@ -133,17 +146,14 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = 1
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/transactions?"
"account_id=test-account-123&"
@@ -156,10 +166,12 @@ class TestTransactionsAPI:
"per_page=10"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
# Verify the database service was called with correct filters
mock_get_transactions.assert_called_once_with(
# Verify the repository was called with correct filters
mock_transaction_repo.get_transactions.assert_called_once_with(
account_id="test-account-123",
limit=10,
offset=10, # (page-1) * per_page = (2-1) * 10 = 10
@@ -171,22 +183,26 @@ class TestTransactionsAPI:
)
def test_get_transactions_empty_result(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transactions when database returns empty result."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=[],
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=0,
),
):
mock_transaction_repo.get_transactions.return_value = []
mock_transaction_repo.get_count.return_value = 0
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 0
@@ -195,23 +211,37 @@ class TestTransactionsAPI:
assert data["total_pages"] == 0
def test_get_transactions_database_error(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test handling database error when getting transactions."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"),
),
):
mock_transaction_repo.get_transactions.side_effect = Exception(
"Database connection failed"
)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500
assert "Failed to get transactions" in response.json()["detail"]
def test_get_transaction_stats_success(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test successful retrieval of transaction statistics from database."""
mock_transactions = [
@@ -238,15 +268,16 @@ class TestTransactionsAPI:
},
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats?days=30")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
@@ -264,7 +295,12 @@ class TestTransactionsAPI:
assert data["average_transaction"] == expected_avg
def test_get_transaction_stats_with_account_filter(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transaction stats filtered by account."""
mock_transactions = [
@@ -277,37 +313,46 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/transactions/stats?account_id=test-account-123"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
# Verify the database service was called with account filter
mock_get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs
# Verify the repository was called with account filter
mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert call_kwargs["account_id"] == "test-account-123"
def test_get_transaction_stats_empty_result(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting stats when no transactions match criteria."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=[],
),
):
mock_transaction_repo.get_transactions.return_value = []
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
@@ -319,23 +364,37 @@ class TestTransactionsAPI:
assert data["accounts_included"] == 0
def test_get_transaction_stats_database_error(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test handling database error when getting stats."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"),
),
):
mock_transaction_repo.get_transactions.side_effect = Exception(
"Database connection failed"
)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500
assert "Failed to get transaction stats" in response.json()["detail"]
def test_get_transaction_stats_custom_period(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transaction stats for custom time period."""
mock_transactions = [
@@ -348,21 +407,23 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats?days=7")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["period_days"] == 7
# Verify the date range was calculated correctly for 7 days
mock_get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs
mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert "date_from" in call_kwargs
assert "date_to" in call_kwargs

View File

@@ -27,9 +27,7 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {
@@ -69,9 +67,7 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "send_expiry_notification"
) as mock_send_expiry,
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One expired requisition
mock_get_requisitions.return_value = {
@@ -112,9 +108,7 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with two accounts that will fail
mock_get_requisitions.return_value = {
@@ -160,17 +154,15 @@ class TestSyncNotifications:
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.notifications, "send_transaction_notifications"),
patch.object(sync_service.database, "persist_account_details"),
patch.object(sync_service.database, "persist_balance"),
patch.object(sync_service.accounts, "persist"),
patch.object(sync_service.balances, "persist"),
patch.object(
sync_service.database, "process_transactions", return_value=[]
),
patch.object(
sync_service.database, "persist_transactions", return_value=[]
),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
sync_service.transaction_processor,
"process_transactions",
return_value=[],
),
patch.object(sync_service.transactions, "persist", return_value=[]),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that succeeds
mock_get_requisitions.return_value = {
@@ -222,9 +214,7 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "_send_telegram_sync_failure"
) as mock_telegram_notification,
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {