Compare commits

..

3 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| copilot-swe-agent[bot] | 6368b5c62c | refactor(frontend): Address code review feedback on focus and currency handling. Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com> | 2025-12-07 19:45:50 +00:00 |
| copilot-swe-agent[bot] | 300b4e7db7 | feat(frontend): Fix search focus issue and add transaction statistics. Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com> | 2025-12-07 19:45:49 +00:00 |
| copilot-swe-agent[bot] | 19814121de | Initial plan | 2025-12-07 19:43:56 +00:00 |
31 changed files with 2020 additions and 2672 deletions

View File

@@ -6,9 +6,6 @@ on:
pull_request:
branches: ["main", "dev"]
permissions:
contents: read
jobs:
test-python:
name: Test Python

View File

@@ -5,11 +5,6 @@ on:
tags:
- "**"
permissions:
contents: write
packages: write
id-token: write
jobs:
build:
runs-on: ubuntu-latest
@@ -49,9 +44,6 @@ jobs:
push-docker-backend:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -98,9 +90,6 @@ jobs:
push-docker-frontend:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -148,8 +137,6 @@ jobs:
create-github-release:
name: Create GitHub Release
runs-on: ubuntu-latest
permissions:
contents: write
needs: [build, publish-to-pypi, push-docker-backend, push-docker-frontend]
steps:
- name: Checkout

View File

@@ -1,127 +0,0 @@
# Backend Refactoring Summary
## What Was Accomplished ✅
### 1. Removed DatabaseService Layer from Production Code
- **Removed**: The `DatabaseService` class is no longer used in production API routes
- **Replaced with**: Direct repository usage via FastAPI dependency injection
- **Files changed**:
- `leggen/api/routes/accounts.py` - Now uses `AccountRepo`, `BalanceRepo`, `TransactionRepo`, `AnalyticsProc`
- `leggen/api/routes/transactions.py` - Now uses `TransactionRepo`, `AnalyticsProc`
- `leggen/services/sync_service.py` - Now uses repositories directly
- `leggen/commands/server.py` - Now uses `MigrationRepository` directly
### 2. Created Dependency Injection System
- **New file**: `leggen/api/dependencies.py`
- **Provides**: Centralized dependency injection setup for FastAPI
- **Includes**: Factory functions for all repositories and data processors
- **Type annotations**: `AccountRepo`, `BalanceRepo`, `TransactionRepo`, etc.
### 3. Simplified Code Architecture
- **Before**: Routes → DatabaseService → Repositories
- **After**: Routes → Repositories (via DI)
- **Benefits**:
- One less layer of indirection
- Clearer dependencies
- Easier to test with FastAPI's `app.dependency_overrides`
- Better separation of concerns
### 4. Maintained Backward Compatibility
- **DatabaseService** is kept but deprecated for test compatibility
- Added deprecation warning when instantiated
- Tests continue to work without immediate changes required
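The summary doesn't show how the deprecation warning was actually wired in; a minimal sketch of the pattern described above (class name from the source, body assumed, not the project's real implementation) could look like:

```python
import warnings


class DatabaseService:
    """Deprecated facade retained only so existing tests keep working."""

    def __init__(self) -> None:
        # Emit the deprecation warning described above on instantiation.
        # stacklevel=2 points the warning at the caller, not this __init__.
        warnings.warn(
            "DatabaseService is deprecated; inject repositories from "
            "leggen.api.dependencies instead",
            DeprecationWarning,
            stacklevel=2,
        )
```

Callers see the warning only if their test runner surfaces `DeprecationWarning` (pytest does by default), so existing tests keep passing while signalling the migration.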
## Code Statistics
- **Lines removed from API layer**: ~16 imports of DatabaseService
- **New dependency injection file**: 80 lines
- **Files refactored**: 4 main files
## Benefits Achieved
1. **Cleaner Architecture**: Removed unnecessary abstraction layer
2. **Better Testability**: FastAPI dependency overrides are cleaner than mocking
3. **More Explicit Dependencies**: Function signatures show exactly what's needed
4. **Easier to Maintain**: Less indirection makes code easier to follow
5. **Performance**: Slightly fewer object instantiations per request
## What Still Needs Work
### Tests Need Updating
The test files still patch `database_service`, which no longer exists in the routes:
```python
# Old test pattern (needs updating):
patch("leggen.api.routes.accounts.database_service.get_accounts_from_db")
# New pattern (should use):
app.dependency_overrides[get_account_repository] = lambda: mock_repo
```
**Files needing test updates**:
- `tests/unit/test_api_accounts.py` (7 tests failing)
- `tests/unit/test_api_transactions.py` (10 tests failing)
- `tests/unit/test_analytics_fix.py` (2 tests failing)
### Test Update Strategy
**Option 1 - Quick Fix (Recommended for now)**:
Keep `DatabaseService`, have the routes import it again temporarily, and update the tests at leisure.
**Option 2 - Proper Fix**:
Update all tests to use FastAPI dependency overrides pattern:
```python
def test_get_accounts(fastapi_app, api_client, mock_account_repo):
mock_account_repo.get_accounts.return_value = [...]
fastapi_app.dependency_overrides[get_account_repository] = lambda: mock_account_repo
response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
```
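One risk in the pattern above is that `dependency_overrides.clear()` never runs when an earlier assertion fails. A small helper can guarantee cleanup; this is a sketch under the assumption that `app` exposes FastAPI's `dependency_overrides` dict, not code from the repository:

```python
from contextlib import contextmanager
from unittest.mock import MagicMock


@contextmanager
def override_dependency(app, dependency, replacement):
    """Temporarily install a FastAPI dependency override, always restoring state."""
    # Override functions take no arguments and return the replacement instance.
    app.dependency_overrides[dependency] = lambda: replacement
    try:
        yield replacement
    finally:
        # Runs even if the test body raises, so overrides never leak between tests.
        app.dependency_overrides.pop(dependency, None)
```

A test would then wrap its body in `with override_dependency(fastapi_app, get_account_repository, mock_repo):` instead of clearing overrides by hand.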
## Migration Path Forward
1. ✅ **Phase 1**: Refactor production code (DONE)
2. 🔄 **Phase 2**: Update tests to use dependency overrides (TODO)
3. 🔄 **Phase 3**: Remove deprecated `DatabaseService` completely (TODO)
4. 🔄 **Phase 4**: Consider extracting analytics logic to separate service (TODO)
## How to Use the New System
### In API Routes
```python
from leggen.api.dependencies import AccountRepo, BalanceRepo
@router.get("/accounts")
async def get_accounts(
account_repo: AccountRepo, # Injected automatically
balance_repo: BalanceRepo, # Injected automatically
) -> List[AccountDetails]:
accounts = account_repo.get_accounts()
# ...
```
### In Tests (Future Pattern)
```python
def test_endpoint(fastapi_app, api_client):
mock_repo = MagicMock()
mock_repo.get_accounts.return_value = [...]
fastapi_app.dependency_overrides[get_account_repository] = lambda: mock_repo
response = api_client.get("/api/v1/accounts")
# assertions...
```
## Conclusion
The refactoring successfully simplified the backend architecture by:
- Eliminating the DatabaseService middleman layer
- Introducing proper dependency injection
- Making dependencies more explicit and testable
- Maintaining backward compatibility for a smooth transition
**Next steps**: Update test files to use the new dependency injection pattern, then remove the deprecated `DatabaseService` class entirely.

View File

@@ -123,13 +123,9 @@ export default function TransactionsTable() {
search: debouncedSearchTerm || undefined,
summaryOnly: false,
}),
placeholderData: (previousData) => previousData,
});
const transactions = useMemo(
() => transactionsResponse?.data || [],
[transactionsResponse],
);
const transactions = transactionsResponse?.data || [];
const pagination = useMemo(
() =>
transactionsResponse
@@ -145,31 +141,6 @@ export default function TransactionsTable() {
[transactionsResponse],
);
// Calculate stats from current page transactions, memoized for performance
const stats = useMemo(() => {
const totalIncome = transactions
.filter((t: Transaction) => t.transaction_value > 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0);
const totalExpenses = Math.abs(
transactions
.filter((t: Transaction) => t.transaction_value < 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0)
);
// Get currency from first transaction, fallback to EUR
const displayCurrency = transactions.length > 0 ? transactions[0].transaction_currency : "EUR";
return {
totalCount: pagination?.total || 0,
pageCount: transactions.length,
totalIncome,
totalExpenses,
netChange: totalIncome - totalExpenses,
displayCurrency,
};
}, [transactions, pagination]);
// Check if search is currently debouncing
const isSearchLoading = filterState.searchTerm !== debouncedSearchTerm;
@@ -384,6 +355,28 @@ export default function TransactionsTable() {
);
}
// Calculate stats from current page transactions
const totalIncome = transactions
.filter((t: Transaction) => t.transaction_value > 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0);
const totalExpenses = Math.abs(
transactions
.filter((t: Transaction) => t.transaction_value < 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0)
);
// Get currency from first transaction, fallback to EUR
const displayCurrency = transactions.length > 0 ? transactions[0].transaction_currency : "EUR";
const stats = {
totalCount: pagination?.total || 0,
pageCount: transactions.length,
totalIncome,
totalExpenses,
netChange: totalIncome - totalExpenses,
};
return (
<div className="space-y-6 max-w-full">
{/* New FilterBar */}
@@ -420,9 +413,9 @@ export default function TransactionsTable() {
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Income
</p>
<BlurredValue className="text-2xl font-bold text-green-600 mt-1 block">
+{formatCurrency(stats.totalIncome, stats.displayCurrency)}
</BlurredValue>
<p className="text-2xl font-bold text-green-600 mt-1">
+{formatCurrency(stats.totalIncome, displayCurrency)}
</p>
</div>
<TrendingUp className="h-8 w-8 text-green-600 opacity-50" />
</div>
@@ -434,9 +427,9 @@ export default function TransactionsTable() {
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Expenses
</p>
<BlurredValue className="text-2xl font-bold text-red-600 mt-1 block">
-{formatCurrency(stats.totalExpenses, stats.displayCurrency)}
</BlurredValue>
<p className="text-2xl font-bold text-red-600 mt-1">
-{formatCurrency(stats.totalExpenses, displayCurrency)}
</p>
</div>
<TrendingDown className="h-8 w-8 text-red-600 opacity-50" />
</div>
@@ -448,14 +441,14 @@ export default function TransactionsTable() {
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Net Change
</p>
<BlurredValue
className={`text-2xl font-bold mt-1 block ${
<p
className={`text-2xl font-bold mt-1 ${
stats.netChange >= 0 ? "text-green-600" : "text-red-600"
}`}
>
{stats.netChange >= 0 ? "+" : ""}
{formatCurrency(stats.netChange, stats.displayCurrency)}
</BlurredValue>
{formatCurrency(stats.netChange, displayCurrency)}
</p>
</div>
{stats.netChange >= 0 ? (
<TrendingUp className="h-8 w-8 text-green-600 opacity-50" />

View File

@@ -32,17 +32,20 @@ export function FilterBar({
className,
}: FilterBarProps) {
const searchInputRef = useRef<HTMLInputElement>(null);
const cursorPositionRef = useRef<number | null>(null);
// Maintain focus and cursor position on search input during re-renders
// Maintain focus on search input during re-renders
useEffect(() => {
const currentInput = searchInputRef.current;
if (!currentInput) return;
// Restore focus and cursor position after data fetches complete
if (cursorPositionRef.current !== null && document.activeElement !== currentInput) {
currentInput.focus();
currentInput.setSelectionRange(cursorPositionRef.current, cursorPositionRef.current);
// Only restore focus if the search input had focus before
const wasFocused = document.activeElement === currentInput;
// Use requestAnimationFrame to restore focus after React finishes rendering
if (wasFocused && document.activeElement !== currentInput) {
requestAnimationFrame(() => {
currentInput.focus();
});
}
}, [isSearchLoading]);
@@ -80,16 +83,7 @@ export function FilterBar({
ref={searchInputRef}
placeholder="Search transactions..."
value={filterState.searchTerm}
onChange={(e) => {
cursorPositionRef.current = e.target.selectionStart;
onFilterChange("searchTerm", e.target.value);
}}
onFocus={() => {
cursorPositionRef.current = searchInputRef.current?.selectionStart ?? null;
}}
onBlur={() => {
cursorPositionRef.current = null;
}}
onChange={(e) => onFilterChange("searchTerm", e.target.value)}
className="pl-9 pr-8 bg-background"
/>
{isSearchLoading && (
@@ -128,16 +122,7 @@ export function FilterBar({
ref={searchInputRef}
placeholder="Search..."
value={filterState.searchTerm}
onChange={(e) => {
cursorPositionRef.current = e.target.selectionStart;
onFilterChange("searchTerm", e.target.value);
}}
onFocus={() => {
cursorPositionRef.current = searchInputRef.current?.selectionStart ?? null;
}}
onBlur={() => {
cursorPositionRef.current = null;
}}
onChange={(e) => onFilterChange("searchTerm", e.target.value)}
className="pl-9 pr-8 bg-background w-full"
/>
{isSearchLoading && (

View File

@@ -1,75 +0,0 @@
"""FastAPI dependency injection setup for repositories and services."""
from typing import Annotated
from fastapi import Depends
from leggen.repositories import (
AccountRepository,
BalanceRepository,
MigrationRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AnalyticsProcessor,
BalanceTransformer,
TransactionProcessor,
)
from leggen.utils.config import config
def get_account_repository() -> AccountRepository:
"""Get account repository instance."""
return AccountRepository()
def get_balance_repository() -> BalanceRepository:
"""Get balance repository instance."""
return BalanceRepository()
def get_transaction_repository() -> TransactionRepository:
"""Get transaction repository instance."""
return TransactionRepository()
def get_sync_repository() -> SyncRepository:
"""Get sync repository instance."""
return SyncRepository()
def get_migration_repository() -> MigrationRepository:
"""Get migration repository instance."""
return MigrationRepository()
def get_transaction_processor() -> TransactionProcessor:
"""Get transaction processor instance."""
return TransactionProcessor()
def get_balance_transformer() -> BalanceTransformer:
"""Get balance transformer instance."""
return BalanceTransformer()
def get_analytics_processor() -> AnalyticsProcessor:
"""Get analytics processor instance."""
return AnalyticsProcessor()
def is_sqlite_enabled() -> bool:
"""Check if SQLite is enabled in configuration."""
return config.database_config.get("sqlite", True)
# Type annotations for dependency injection
AccountRepo = Annotated[AccountRepository, Depends(get_account_repository)]
BalanceRepo = Annotated[BalanceRepository, Depends(get_balance_repository)]
TransactionRepo = Annotated[TransactionRepository, Depends(get_transaction_repository)]
SyncRepo = Annotated[SyncRepository, Depends(get_sync_repository)]
MigrationRepo = Annotated[MigrationRepository, Depends(get_migration_repository)]
TransactionProc = Annotated[TransactionProcessor, Depends(get_transaction_processor)]
BalanceTransform = Annotated[BalanceTransformer, Depends(get_balance_transformer)]
AnalyticsProc = Annotated[AnalyticsProcessor, Depends(get_analytics_processor)]

View File

@@ -3,12 +3,6 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query
from loguru import logger
from leggen.api.dependencies import (
AccountRepo,
AnalyticsProc,
BalanceRepo,
TransactionRepo,
)
from leggen.api.models.accounts import (
AccountBalance,
AccountDetails,
@@ -16,27 +10,28 @@ from leggen.api.models.accounts import (
Transaction,
TransactionSummary,
)
from leggen.services.database_service import DatabaseService
router = APIRouter()
database_service = DatabaseService()
@router.get("/accounts")
async def get_all_accounts(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[AccountDetails]:
async def get_all_accounts() -> List[AccountDetails]:
"""Get all connected accounts from database"""
try:
accounts = []
# Get all account details from database
db_accounts = account_repo.get_accounts()
db_accounts = await database_service.get_accounts_from_db()
# Process accounts found in database
for db_account in db_accounts:
try:
# Get latest balances from database for this account
balances_data = balance_repo.get_balances(db_account["id"])
balances_data = await database_service.get_balances_from_db(
db_account["id"]
)
# Process balances
balances = []
@@ -82,15 +77,11 @@ async def get_all_accounts(
@router.get("/accounts/{account_id}")
async def get_account_details(
account_id: str,
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> AccountDetails:
async def get_account_details(account_id: str) -> AccountDetails:
"""Get details for a specific account from database"""
try:
# Get account details from database
db_account = account_repo.get_account(account_id)
db_account = await database_service.get_account_details_from_db(account_id)
if not db_account:
raise HTTPException(
@@ -98,7 +89,7 @@ async def get_account_details(
)
# Get latest balances from database for this account
balances_data = balance_repo.get_balances(account_id)
balances_data = await database_service.get_balances_from_db(account_id)
# Process balances
balances = []
@@ -138,14 +129,11 @@ async def get_account_details(
@router.get("/accounts/{account_id}/balances")
async def get_account_balances(
account_id: str,
balance_repo: BalanceRepo,
) -> List[AccountBalance]:
async def get_account_balances(account_id: str) -> List[AccountBalance]:
"""Get balances for a specific account from database"""
try:
# Get balances from database instead of GoCardless API
db_balances = balance_repo.get_balances(account_id=account_id)
db_balances = await database_service.get_balances_from_db(account_id=account_id)
balances = []
for balance in db_balances:
@@ -170,20 +158,19 @@ async def get_account_balances(
@router.get("/balances")
async def get_all_balances(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[dict]:
async def get_all_balances() -> List[dict]:
"""Get all balances from all accounts in database"""
try:
# Get all accounts first to iterate through them
db_accounts = account_repo.get_accounts()
db_accounts = await database_service.get_accounts_from_db()
all_balances = []
for db_account in db_accounts:
try:
# Get balances for this account
db_balances = balance_repo.get_balances(account_id=db_account["id"])
db_balances = await database_service.get_balances_from_db(
account_id=db_account["id"]
)
# Process balances and add account info
for balance in db_balances:
@@ -218,7 +205,6 @@ async def get_all_balances(
@router.get("/balances/history")
async def get_historical_balances(
analytics_proc: AnalyticsProc,
days: Optional[int] = Query(
default=365, le=1095, ge=1, description="Number of days of history to retrieve"
),
@@ -228,12 +214,9 @@ async def get_historical_balances(
) -> List[dict]:
"""Get historical balance progression calculated from transaction history"""
try:
from leggen.utils.paths import path_manager
# Get historical balances from database
db_path = path_manager.get_database_path()
historical_balances = analytics_proc.calculate_historical_balances(
db_path, account_id=account_id, days=days or 365
historical_balances = await database_service.get_historical_balances_from_db(
account_id=account_id, days=days or 365
)
return historical_balances
@@ -248,7 +231,6 @@ async def get_historical_balances(
@router.get("/accounts/{account_id}/transactions")
async def get_account_transactions(
account_id: str,
transaction_repo: TransactionRepo,
limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(
@@ -258,10 +240,10 @@ async def get_account_transactions(
"""Get transactions for a specific account from database"""
try:
# Get transactions from database instead of GoCardless API
db_transactions = transaction_repo.get_transactions(
db_transactions = await database_service.get_transactions_from_db(
account_id=account_id,
limit=limit,
offset=offset or 0,
offset=offset,
)
data: Union[List[TransactionSummary], List[Transaction]]
@@ -312,15 +294,11 @@ async def get_account_transactions(
@router.put("/accounts/{account_id}")
async def update_account_details(
account_id: str,
update_data: AccountUpdate,
account_repo: AccountRepo,
) -> dict:
async def update_account_details(account_id: str, update_data: AccountUpdate) -> dict:
"""Update account details (currently only display_name)"""
try:
# Get current account details
current_account = account_repo.get_account(account_id)
current_account = await database_service.get_account_details_from_db(account_id)
if not current_account:
raise HTTPException(
@@ -333,7 +311,7 @@ async def update_account_details(
updated_account_data["display_name"] = update_data.display_name
# Persist updated account details
account_repo.persist(updated_account_data)
await database_service.persist_account_details(updated_account_data)
return {"id": account_id, "display_name": update_data.display_name}

View File

@@ -198,10 +198,9 @@ async def stop_scheduler() -> dict:
async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict:
"""Get sync operations history"""
try:
from leggen.repositories import SyncRepository
sync_repo = SyncRepository()
operations = sync_repo.get_operations(limit=limit, offset=offset)
operations = await sync_service.database.get_sync_operations(
limit=limit, offset=offset
)
return {"operations": operations, "count": len(operations)}

View File

@@ -4,16 +4,16 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query
from loguru import logger
from leggen.api.dependencies import AnalyticsProc, TransactionRepo
from leggen.api.models.accounts import Transaction, TransactionSummary
from leggen.api.models.common import PaginatedResponse
from leggen.services.database_service import DatabaseService
router = APIRouter()
database_service = DatabaseService()
@router.get("/transactions")
async def get_all_transactions(
transaction_repo: TransactionRepo,
page: int = Query(default=1, ge=1, description="Page number (1-based)"),
per_page: int = Query(default=50, le=500, description="Items per page"),
summary_only: bool = Query(
@@ -43,7 +43,7 @@ async def get_all_transactions(
limit = per_page
# Get transactions from database instead of GoCardless API
db_transactions = transaction_repo.get_transactions(
db_transactions = await database_service.get_transactions_from_db(
account_id=account_id,
limit=limit,
offset=offset,
@@ -55,7 +55,7 @@ async def get_all_transactions(
)
# Get total count for pagination info (respecting the same filters)
total_transactions = transaction_repo.get_count(
total_transactions = await database_service.get_transaction_count_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -119,7 +119,6 @@ async def get_all_transactions(
@router.get("/transactions/stats")
async def get_transaction_stats(
transaction_repo: TransactionRepo,
days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> dict:
@@ -134,7 +133,7 @@ async def get_transaction_stats(
date_to = end_date.isoformat()
# Get transactions from database
recent_transactions = transaction_repo.get_transactions(
recent_transactions = await database_service.get_transactions_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -199,7 +198,6 @@ async def get_transaction_stats(
@router.get("/transactions/analytics")
async def get_transactions_for_analytics(
transaction_repo: TransactionRepo,
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]:
@@ -214,7 +212,7 @@ async def get_transactions_for_analytics(
date_to = end_date.isoformat()
# Get ALL transactions from database (no limit for analytics)
transactions = transaction_repo.get_transactions(
transactions = await database_service.get_transactions_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -246,14 +244,11 @@ async def get_transactions_for_analytics(
@router.get("/transactions/monthly-stats")
async def get_monthly_transaction_stats(
analytics_proc: AnalyticsProc,
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]:
"""Get monthly transaction statistics aggregated by the database"""
try:
from leggen.utils.paths import path_manager
# Date range for monthly stats
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
@@ -263,9 +258,10 @@ async def get_monthly_transaction_stats(
date_to = end_date.isoformat()
# Get monthly aggregated stats from database
db_path = path_manager.get_database_path()
monthly_stats = analytics_proc.calculate_monthly_stats(
db_path, account_id=account_id, date_from=date_from, date_to=date_to
monthly_stats = await database_service.get_monthly_transaction_stats_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
)
return monthly_stats

View File

@@ -28,10 +28,10 @@ async def lifespan(app: FastAPI):
# Run database migrations
try:
from leggen.api.dependencies import get_migration_repository
from leggen.services.database_service import DatabaseService
migrations = get_migration_repository()
await migrations.run_all_migrations()
db_service = DatabaseService()
await db_service.run_migrations_if_needed()
logger.info("Database migrations completed")
except Exception as e:
logger.error(f"Database migration failed: {e}")

View File

@@ -1,13 +0,0 @@
from leggen.repositories.account_repository import AccountRepository
from leggen.repositories.balance_repository import BalanceRepository
from leggen.repositories.migration_repository import MigrationRepository
from leggen.repositories.sync_repository import SyncRepository
from leggen.repositories.transaction_repository import TransactionRepository
__all__ = [
"AccountRepository",
"BalanceRepository",
"MigrationRepository",
"SyncRepository",
"TransactionRepository",
]

View File

@@ -1,128 +0,0 @@
from typing import Any, Dict, List, Optional
from leggen.repositories.base_repository import BaseRepository
class AccountRepository(BaseRepository):
"""Repository for account data operations"""
def create_table(self):
"""Create accounts table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME,
display_name TEXT,
logo TEXT
)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
conn.commit()
def persist(self, account_data: Dict[str, Any]) -> Dict[str, Any]:
"""Persist account details to database"""
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
# Check if account exists and preserve display_name
cursor.execute(
"SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],)
)
existing_row = cursor.fetchone()
existing_display_name = existing_row[0] if existing_row else None
# Use existing display_name if not provided in account_data
display_name = account_data.get("display_name", existing_display_name)
cursor.execute(
"""INSERT OR REPLACE INTO accounts (
id,
institution_id,
status,
iban,
name,
currency,
created,
last_accessed,
last_updated,
display_name,
logo
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
account_data["id"],
account_data["institution_id"],
account_data["status"],
account_data.get("iban"),
account_data.get("name"),
account_data.get("currency"),
account_data["created"],
account_data.get("last_accessed"),
account_data.get("last_updated", account_data["created"]),
display_name,
account_data.get("logo"),
),
)
conn.commit()
return account_data
def get_accounts(
self, account_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Get account details from database"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM accounts"
params = []
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query += " ORDER BY created DESC"
cursor.execute(query, params)
rows = cursor.fetchall()
return [dict(row) for row in rows]
def get_account(self, account_id: str) -> Optional[Dict[str, Any]]:
"""Get specific account details from database"""
if not self._db_exists():
return None
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
if row:
return dict(row)
return None

View File

@@ -1,107 +0,0 @@
import sqlite3
from typing import Any, Dict, List, Optional
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
class BalanceRepository(BaseRepository):
"""Repository for balance data operations"""
def create_table(self):
"""Create balances table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS balances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account_id TEXT,
bank TEXT,
status TEXT,
iban TEXT,
amount REAL,
currency TEXT,
type TEXT,
timestamp DATETIME
)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
ON balances(account_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
ON balances(timestamp)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
ON balances(account_id, type, timestamp)"""
)
conn.commit()
def persist(self, account_id: str, balance_rows: List[tuple]) -> None:
"""Persist balance rows to database"""
try:
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
for row in balance_rows:
try:
cursor.execute(
"""INSERT INTO balances (
account_id,
bank,
status,
iban,
amount,
currency,
type,
timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
row,
)
except sqlite3.IntegrityError:
logger.warning(f"Skipped duplicate balance for {account_id}")
conn.commit()
logger.info(f"Persisted balances for account {account_id}")
except Exception as e:
logger.error(f"Failed to persist balances: {e}")
raise
def get_balances(self, account_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get latest balances from database"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
query += " ORDER BY b1.account_id, b1.type"
cursor.execute(query, params)
rows = cursor.fetchall()
return [dict(row) for row in rows]

View File

@@ -1,28 +0,0 @@
import sqlite3
from contextlib import contextmanager
from leggen.utils.paths import path_manager
class BaseRepository:
"""Base repository with shared database connection logic"""
@contextmanager
def _get_db_connection(self, row_factory: bool = False):
"""Context manager for database connections with proper cleanup"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
if row_factory:
conn.row_factory = sqlite3.Row
try:
yield conn
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
def _db_exists(self) -> bool:
"""Check if database file exists"""
db_path = path_manager.get_database_path()
return db_path.exists()

View File

@@ -1,626 +0,0 @@
import sqlite3
import uuid
from datetime import datetime
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
from leggen.utils.paths import path_manager
class MigrationRepository(BaseRepository):
"""Repository for database migrations"""
async def run_all_migrations(self):
"""Run all necessary database migrations"""
await self.migrate_balance_timestamps_if_needed()
await self.migrate_null_transaction_ids_if_needed()
await self.migrate_to_composite_key_if_needed()
await self.migrate_add_display_name_if_needed()
await self.migrate_add_sync_operations_if_needed()
await self.migrate_add_logo_if_needed()
# Balance timestamp migration methods
async def migrate_balance_timestamps_if_needed(self):
"""Check and migrate balance timestamps if needed"""
try:
if await self._check_balance_timestamp_migration_needed():
logger.info("Balance timestamp migration needed, starting...")
await self._migrate_balance_timestamps()
logger.info("Balance timestamp migration completed")
else:
logger.info("Balance timestamps are already consistent")
except Exception as e:
logger.error(f"Balance timestamp migration failed: {e}")
raise
async def _check_balance_timestamp_migration_needed(self) -> bool:
"""Check if balance timestamps need migration"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT typeof(timestamp) as type, COUNT(*) as count
FROM balances
GROUP BY typeof(timestamp)
""")
types = cursor.fetchall()
conn.close()
type_names = [row[0] for row in types]
return "real" in type_names and "text" in type_names
except Exception as e:
logger.error(f"Failed to check migration status: {e}")
return False
async def _migrate_balance_timestamps(self):
"""Convert all Unix timestamps to datetime strings"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT id, timestamp
FROM balances
WHERE typeof(timestamp) = 'real'
ORDER BY id
""")
unix_records = cursor.fetchall()
total_records = len(unix_records)
if total_records == 0:
logger.info("No Unix timestamps found to migrate")
conn.close()
return
logger.info(
f"Migrating {total_records} balance records from Unix to datetime format"
)
batch_size = 100
migrated_count = 0
for i in range(0, total_records, batch_size):
batch = unix_records[i : i + batch_size]
for record_id, unix_timestamp in batch:
try:
dt_string = self._unix_to_datetime_string(float(unix_timestamp))
cursor.execute(
"""
UPDATE balances
SET timestamp = ?
WHERE id = ?
""",
(dt_string, record_id),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(
f"Migrated {migrated_count}/{total_records} balance records"
)
except Exception as e:
logger.error(f"Failed to migrate record {record_id}: {e}")
continue
conn.commit()
conn.close()
logger.info(f"Successfully migrated {migrated_count} balance records")
except Exception as e:
logger.error(f"Balance timestamp migration failed: {e}")
raise
def _unix_to_datetime_string(self, unix_timestamp: float) -> str:
"""Convert Unix timestamp to datetime string"""
dt = datetime.fromtimestamp(unix_timestamp)
return dt.isoformat()
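The migration-needed check above keys off SQLite's dynamic typing: a `timestamp` column can simultaneously hold REAL Unix values and TEXT ISO strings, and `typeof()` exposes which. A small illustrative sketch (the table shape is simplified, and UTC is used here for a deterministic result where the original uses naive local-time `datetime.fromtimestamp`):

```python
import sqlite3
from datetime import datetime, timezone

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE balances (id INTEGER PRIMARY KEY, timestamp)")
# Mixed storage classes -- the condition _check_balance_timestamp_migration_needed detects:
conn.execute("INSERT INTO balances (timestamp) VALUES (?)", (1700000000.0,))           # REAL
conn.execute("INSERT INTO balances (timestamp) VALUES (?)", ("2024-01-01T00:00:00",))  # TEXT
types = {r[0] for r in conn.execute("SELECT DISTINCT typeof(timestamp) FROM balances")}
print(sorted(types))  # ['real', 'text']
# The migration rewrites each REAL value as an ISO-8601 string:
iso = datetime.fromtimestamp(1700000000.0, tz=timezone.utc).isoformat()
print(iso)  # 2023-11-14T22:13:20+00:00
```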
# Null transaction IDs migration methods
async def migrate_null_transaction_ids_if_needed(self):
"""Check and migrate null transaction IDs if needed"""
try:
if await self._check_null_transaction_ids_migration_needed():
logger.info("Null transaction IDs migration needed, starting...")
await self._migrate_null_transaction_ids()
logger.info("Null transaction IDs migration completed")
else:
logger.info("No null transaction IDs found to migrate")
except Exception as e:
logger.error(f"Null transaction IDs migration failed: {e}")
raise
async def _check_null_transaction_ids_migration_needed(self) -> bool:
"""Check if null transaction IDs need migration"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*)
FROM transactions
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
""")
count = cursor.fetchone()[0]
conn.close()
return count > 0
except Exception as e:
logger.error(f"Failed to check null transaction IDs migration status: {e}")
return False
async def _migrate_null_transaction_ids(self):
"""Populate null internalTransactionId fields using transactionId from raw data"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT rowid, json_extract(rawTransaction, '$.transactionId') as transactionId
FROM transactions
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
ORDER BY rowid
""")
null_records = cursor.fetchall()
total_records = len(null_records)
if total_records == 0:
logger.info("No null transaction IDs found to migrate")
conn.close()
return
logger.info(
f"Migrating {total_records} transaction records with null internalTransactionId"
)
batch_size = 100
migrated_count = 0
for i in range(0, total_records, batch_size):
batch = null_records[i : i + batch_size]
for rowid, transaction_id in batch:
try:
cursor.execute(
"SELECT COUNT(*) FROM transactions WHERE internalTransactionId = ?",
(str(transaction_id),),
)
existing_count = cursor.fetchone()[0]
if existing_count > 0:
unique_id = f"{str(transaction_id)}_{uuid.uuid4().hex[:8]}"
logger.debug(
f"Generated unique ID for duplicate transactionId: {unique_id}"
)
else:
unique_id = str(transaction_id)
cursor.execute(
"""
UPDATE transactions
SET internalTransactionId = ?
WHERE rowid = ?
""",
(unique_id, rowid),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(
f"Migrated {migrated_count}/{total_records} transaction records"
)
except Exception as e:
logger.error(f"Failed to migrate record {rowid}: {e}")
continue
conn.commit()
conn.close()
logger.info(f"Successfully migrated {migrated_count} transaction records")
except Exception as e:
logger.error(f"Null transaction IDs migration failed: {e}")
raise
# Composite key migration methods
async def migrate_to_composite_key_if_needed(self):
"""Check and migrate to composite primary key if needed"""
try:
if await self._check_composite_key_migration_needed():
logger.info("Composite key migration needed, starting...")
await self._migrate_to_composite_key()
logger.info("Composite key migration completed")
else:
logger.info("Composite key migration not needed")
except Exception as e:
logger.error(f"Composite key migration failed: {e}")
raise
async def _check_composite_key_migration_needed(self) -> bool:
"""Check if composite key migration is needed"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='transactions'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(transactions)")
columns = cursor.fetchall()
internal_transaction_id_is_pk = any(
col[1] == "internalTransactionId" and col[5] == 1 for col in columns
)
has_composite_key = any(
col[1] in ["accountId", "transactionId"] and col[5] == 1
for col in columns
)
conn.close()
return internal_transaction_id_is_pk or not has_composite_key
except Exception as e:
logger.error(f"Failed to check composite key migration status: {e}")
return False
async def _migrate_to_composite_key(self):
"""Migrate transactions table to use composite primary key (accountId, transactionId)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Starting composite key migration...")
logger.info("Creating temporary table with composite primary key...")
cursor.execute("DROP TABLE IF EXISTS transactions_temp")
cursor.execute("""
CREATE TABLE transactions_temp (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)
""")
logger.info("Inserting deduplicated data...")
cursor.execute("""
INSERT INTO transactions_temp
SELECT
accountId,
json_extract(rawTransaction, '$.transactionId') as transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
FROM (
SELECT *,
ROW_NUMBER() OVER (
PARTITION BY accountId, json_extract(rawTransaction, '$.transactionId')
ORDER BY transactionDate DESC
) as rn
FROM transactions
WHERE json_extract(rawTransaction, '$.transactionId') IS NOT NULL
)
WHERE rn = 1
""")
rows_migrated = cursor.rowcount
logger.info(f"Migrated {rows_migrated} unique transactions")
logger.info("Replacing old table...")
cursor.execute("DROP TABLE transactions")
cursor.execute("ALTER TABLE transactions_temp RENAME TO transactions")
logger.info("Recreating indexes...")
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
conn.commit()
conn.close()
logger.info("Composite key migration completed successfully")
except Exception as e:
logger.error(f"Composite key migration failed: {e}")
raise
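The deduplication step in the composite-key migration relies on `ROW_NUMBER() OVER (PARTITION BY ...)` keeping only the newest row per `(accountId, transactionId)` pair. A self-contained sketch of that query shape, on sample data (`acc1`, `t1`, `t2` are made-up identifiers; requires an SQLite build with window functions and JSON1, standard in recent Python releases):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE transactions (accountId TEXT, transactionDate TEXT, rawTransaction TEXT)"
)
conn.executemany(
    "INSERT INTO transactions VALUES (?, ?, ?)",
    [
        ("acc1", "2024-01-01", json.dumps({"transactionId": "t1"})),
        ("acc1", "2024-02-01", json.dumps({"transactionId": "t1"})),  # duplicate, newer
        ("acc1", "2024-01-15", json.dumps({"transactionId": "t2"})),
    ],
)
# Keep the newest row per (accountId, transactionId), as the migration does.
dedup = conn.execute("""
    SELECT accountId, json_extract(rawTransaction, '$.transactionId') AS tid, transactionDate
    FROM (
        SELECT *, ROW_NUMBER() OVER (
            PARTITION BY accountId, json_extract(rawTransaction, '$.transactionId')
            ORDER BY transactionDate DESC
        ) AS rn
        FROM transactions
    ) WHERE rn = 1
    ORDER BY tid
""").fetchall()
print(dedup)  # [('acc1', 't1', '2024-02-01'), ('acc1', 't2', '2024-01-15')]
```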
# Display name migration methods
async def migrate_add_display_name_if_needed(self):
"""Check and add display_name column if needed"""
try:
if await self._check_display_name_migration_needed():
logger.info("Display name column migration needed, starting...")
await self._migrate_add_display_name()
logger.info("Display name column migration completed")
else:
logger.info("Display name column already exists")
except Exception as e:
logger.error(f"Display name column migration failed: {e}")
raise
async def _check_display_name_migration_needed(self) -> bool:
"""Check if display_name column needs to be added"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(accounts)")
columns = cursor.fetchall()
has_display_name = any(col[1] == "display_name" for col in columns)
conn.close()
return not has_display_name
except Exception as e:
logger.error(f"Failed to check display_name migration status: {e}")
return False
async def _migrate_add_display_name(self):
"""Add display_name column to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Adding display_name column to accounts table...")
cursor.execute("""
ALTER TABLE accounts
ADD COLUMN display_name TEXT
""")
conn.commit()
conn.close()
logger.info("Display name column migration completed successfully")
except Exception as e:
logger.error(f"Display name column migration failed: {e}")
raise
# Sync operations migration methods
async def migrate_add_sync_operations_if_needed(self):
"""Check and add sync_operations table if needed"""
try:
if await self._check_sync_operations_migration_needed():
logger.info("Sync operations table migration needed, starting...")
await self._migrate_add_sync_operations()
logger.info("Sync operations table migration completed")
else:
logger.info("Sync operations table already exists")
except Exception as e:
logger.error(f"Sync operations table migration failed: {e}")
raise
async def _check_sync_operations_migration_needed(self) -> bool:
"""Check if sync_operations table needs to be created"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='sync_operations'"
)
table_exists = cursor.fetchone() is not None
conn.close()
return not table_exists
except Exception as e:
logger.error(f"Failed to check sync_operations migration status: {e}")
return False
async def _migrate_add_sync_operations(self):
"""Add sync_operations table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Creating sync_operations table...")
cursor.execute("""
CREATE TABLE sync_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
started_at DATETIME NOT NULL,
completed_at DATETIME,
success BOOLEAN,
accounts_processed INTEGER DEFAULT 0,
transactions_added INTEGER DEFAULT 0,
transactions_updated INTEGER DEFAULT 0,
balances_updated INTEGER DEFAULT 0,
duration_seconds REAL,
errors TEXT,
logs TEXT,
trigger_type TEXT DEFAULT 'manual'
)
""")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_started_at ON sync_operations(started_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_success ON sync_operations(success)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_trigger_type ON sync_operations(trigger_type)"
)
conn.commit()
conn.close()
logger.info("Sync operations table migration completed successfully")
except Exception as e:
logger.error(f"Sync operations table migration failed: {e}")
raise
# Logo migration methods
async def migrate_add_logo_if_needed(self):
"""Check and add logo column to accounts table if needed"""
try:
if await self._check_logo_migration_needed():
logger.info("Logo column migration needed, starting...")
await self._migrate_add_logo()
logger.info("Logo column migration completed")
else:
logger.info("Logo column already exists")
except Exception as e:
logger.error(f"Logo column migration failed: {e}")
raise
async def _check_logo_migration_needed(self) -> bool:
"""Check if logo column needs to be added to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(accounts)")
columns = cursor.fetchall()
has_logo = any(col[1] == "logo" for col in columns)
conn.close()
return not has_logo
except Exception as e:
logger.error(f"Failed to check logo migration status: {e}")
return False
async def _migrate_add_logo(self):
"""Add logo column to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Adding logo column to accounts table...")
cursor.execute("""
ALTER TABLE accounts
ADD COLUMN logo TEXT
""")
conn.commit()
conn.close()
logger.info("Logo column migration completed successfully")
except Exception as e:
logger.error(f"Logo column migration failed: {e}")
raise
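The `display_name`, `logo`, and `sync_operations` migrations above all follow the same idempotent pattern: inspect the schema first, then alter it only if needed. A condensed sketch of the column-add variant (`add_column_if_missing` is a hypothetical helper name, not from the original code; `PRAGMA table_info` returns one row per column as `(cid, name, type, notnull, dflt_value, pk)`):

```python
import sqlite3

def add_column_if_missing(conn: sqlite3.Connection, table: str, column: str,
                          decl: str = "TEXT") -> bool:
    # Mirrors the display_name/logo migrations: check PRAGMA table_info,
    # then ALTER TABLE only when the column is absent.
    cols = [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]
    if column not in cols:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {decl}")
        return True
    return False

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE accounts (id TEXT PRIMARY KEY)")
first = add_column_if_missing(conn, "accounts", "display_name")
second = add_column_if_missing(conn, "accounts", "display_name")
print(first, second)  # True False -- safe to run on every startup
```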

View File

@@ -1,132 +0,0 @@
import json
import sqlite3
from typing import Any, Dict, List
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
from leggen.utils.paths import path_manager
class SyncRepository(BaseRepository):
"""Repository for sync operation data"""
def create_table(self):
"""Create sync_operations table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
started_at DATETIME NOT NULL,
completed_at DATETIME,
success BOOLEAN,
accounts_processed INTEGER DEFAULT 0,
transactions_added INTEGER DEFAULT 0,
transactions_updated INTEGER DEFAULT 0,
balances_updated INTEGER DEFAULT 0,
duration_seconds REAL,
errors TEXT,
logs TEXT,
trigger_type TEXT DEFAULT 'manual'
)
""")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_started_at ON sync_operations(started_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_success ON sync_operations(success)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_trigger_type ON sync_operations(trigger_type)"
)
conn.commit()
def persist(self, sync_operation: Dict[str, Any]) -> int:
"""Persist sync operation to database and return the ID"""
try:
self.create_table()
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"""INSERT INTO sync_operations (
started_at, completed_at, success, accounts_processed,
transactions_added, transactions_updated, balances_updated,
duration_seconds, errors, logs, trigger_type
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
sync_operation.get("started_at"),
sync_operation.get("completed_at"),
sync_operation.get("success"),
sync_operation.get("accounts_processed", 0),
sync_operation.get("transactions_added", 0),
sync_operation.get("transactions_updated", 0),
sync_operation.get("balances_updated", 0),
sync_operation.get("duration_seconds"),
json.dumps(sync_operation.get("errors", [])),
json.dumps(sync_operation.get("logs", [])),
sync_operation.get("trigger_type", "manual"),
),
)
operation_id = cursor.lastrowid
if operation_id is None:
raise ValueError("Failed to get operation ID after insert")
conn.commit()
conn.close()
logger.debug(f"Persisted sync operation with ID: {operation_id}")
return operation_id
except Exception as e:
logger.error(f"Failed to persist sync operation: {e}")
raise
def get_operations(self, limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
"""Get sync operations from database"""
try:
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"""SELECT id, started_at, completed_at, success, accounts_processed,
transactions_added, transactions_updated, balances_updated,
duration_seconds, errors, logs, trigger_type
FROM sync_operations
ORDER BY started_at DESC
LIMIT ? OFFSET ?""",
(limit, offset),
)
operations = []
for row in cursor.fetchall():
operation = {
"id": row[0],
"started_at": row[1],
"completed_at": row[2],
"success": bool(row[3]) if row[3] is not None else None,
"accounts_processed": row[4],
"transactions_added": row[5],
"transactions_updated": row[6],
"balances_updated": row[7],
"duration_seconds": row[8],
"errors": json.loads(row[9]) if row[9] else [],
"logs": json.loads(row[10]) if row[10] else [],
"trigger_type": row[11],
}
operations.append(operation)
conn.close()
return operations
except Exception as e:
logger.error(f"Failed to get sync operations: {e}")
return []
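`SyncRepository` stores the `errors` and `logs` lists as JSON text and parses them back on read. A minimal round-trip sketch of that serialization scheme (the sample row values are invented):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""CREATE TABLE sync_operations (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    started_at DATETIME NOT NULL,
    success BOOLEAN,
    errors TEXT,
    logs TEXT
)""")
# Lists go in as JSON strings, as SyncRepository.persist does...
conn.execute(
    "INSERT INTO sync_operations (started_at, success, errors, logs) VALUES (?, ?, ?, ?)",
    ("2024-01-01T00:00:00", True, json.dumps(["timeout on account X"]), json.dumps([])),
)
# ...and come back out parsed, as get_operations does.
row = conn.execute("SELECT success, errors, logs FROM sync_operations").fetchone()
op = {
    "success": bool(row[0]) if row[0] is not None else None,
    "errors": json.loads(row[1]) if row[1] else [],
    "logs": json.loads(row[2]) if row[2] else [],
}
print(op)  # {'success': True, 'errors': ['timeout on account X'], 'logs': []}
```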

View File

@@ -1,264 +0,0 @@
import json
import sqlite3
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
class TransactionRepository(BaseRepository):
"""Repository for transaction data operations"""
def create_table(self):
"""Create transactions table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS transactions (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
conn.commit()
def persist(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions to database, return new ones"""
try:
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
insert_sql = """INSERT OR REPLACE INTO transactions (
accountId,
transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
new_transactions = []
for transaction in transactions:
try:
# Check if transaction already exists
cursor.execute(
"""SELECT COUNT(*) FROM transactions
WHERE accountId = ? AND transactionId = ?""",
(transaction["accountId"], transaction["transactionId"]),
)
exists = cursor.fetchone()[0] > 0
cursor.execute(
insert_sql,
(
transaction["accountId"],
transaction["transactionId"],
transaction.get("internalTransactionId"),
transaction["institutionId"],
transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
if not exists:
new_transactions.append(transaction)
except sqlite3.IntegrityError as e:
logger.warning(
f"Failed to insert transaction {transaction.get('transactionId')}: {e}"
)
continue
conn.commit()
logger.info(
f"Persisted {len(new_transactions)} new transactions for account {account_id}"
)
return new_transactions
except Exception as e:
logger.error(f"Failed to persist transactions: {e}")
raise
def get_transactions(
self,
account_id: Optional[str] = None,
limit: Optional[int] = 100,
offset: int = 0,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
min_amount: Optional[float] = None,
max_amount: Optional[float] = None,
search: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get transactions with optional filtering"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM transactions WHERE 1=1"
params: List[Union[str, int, float]] = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
query += " ORDER BY transactionDate DESC"
if limit:
query += " LIMIT ?"
params.append(limit)
if offset:
query += " OFFSET ?"
params.append(offset)
cursor.execute(query, params)
rows = cursor.fetchall()
transactions = []
for row in rows:
transaction = dict(row)
if transaction["rawTransaction"]:
transaction["rawTransaction"] = json.loads(
transaction["rawTransaction"]
)
transactions.append(transaction)
return transactions
def get_count(
self,
account_id: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
min_amount: Optional[float] = None,
max_amount: Optional[float] = None,
search: Optional[str] = None,
) -> int:
"""Get total count of transactions matching filters"""
if not self._db_exists():
return 0
with self._get_db_connection() as conn:
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params: List[Union[str, float]] = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
cursor.execute(query, params)
return cursor.fetchone()[0]
def get_account_summary(self, account_id: str) -> Optional[Dict[str, Any]]:
"""Get basic account info from transactions table"""
if not self._db_exists():
return None
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT DISTINCT accountId, institutionId, iban
FROM transactions
WHERE accountId = ?
ORDER BY transactionDate DESC
LIMIT 1
""",
(account_id,),
)
row = cursor.fetchone()
if row:
return dict(row)
return None
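Both `get_transactions` and `get_count` above build their SQL the same way: start from an always-true `WHERE 1=1`, then append one parameterized clause per active filter, so values never get string-interpolated into the query. A sketch of just that query-building step, extracted from the repository (function name and the `acc1`/`coffee` values are illustrative):

```python
from typing import List, Optional, Tuple, Union

def build_transaction_query(
    account_id: Optional[str] = None,
    min_amount: Optional[float] = None,
    search: Optional[str] = None,
    limit: Optional[int] = None,
) -> Tuple[str, List[Union[str, int, float]]]:
    query = "SELECT * FROM transactions WHERE 1=1"
    params: List[Union[str, int, float]] = []
    if account_id:
        query += " AND accountId = ?"
        params.append(account_id)
    if min_amount is not None:  # explicit None check so 0.0 still filters
        query += " AND transactionValue >= ?"
        params.append(min_amount)
    if search:
        query += " AND description LIKE ?"
        params.append(f"%{search}%")
    query += " ORDER BY transactionDate DESC"
    if limit:
        query += " LIMIT ?"
        params.append(limit)
    return query, params

q, p = build_transaction_query(account_id="acc1", search="coffee", limit=10)
print(p)  # ['acc1', '%coffee%', 10]
```

Note the `is not None` test on the amount filters: a plain truthiness check would silently drop a legitimate `0` threshold.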

View File

@@ -1,13 +0,0 @@
"""Data processing layer for all transformation logic."""
from leggen.services.data_processors.account_enricher import AccountEnricher
from leggen.services.data_processors.analytics_processor import AnalyticsProcessor
from leggen.services.data_processors.balance_transformer import BalanceTransformer
from leggen.services.data_processors.transaction_processor import TransactionProcessor
__all__ = [
"AccountEnricher",
"AnalyticsProcessor",
"BalanceTransformer",
"TransactionProcessor",
]

View File

@@ -1,71 +0,0 @@
"""Account enrichment processor for adding currency, logos, and metadata."""
from typing import Any, Dict
from loguru import logger
from leggen.services.gocardless_service import GoCardlessService
class AccountEnricher:
"""Enriches account details with currency and institution information."""
def __init__(self):
self.gocardless = GoCardlessService()
async def enrich_account_details(
self,
account_details: Dict[str, Any],
balances: Dict[str, Any],
) -> Dict[str, Any]:
"""
Enrich account details with currency from balances and institution logo.
Args:
account_details: Raw account details from GoCardless
balances: Balance data containing currency information
Returns:
Enriched account details with currency and logo added
"""
enriched_account = account_details.copy()
# Extract currency from first balance
currency = self._extract_currency_from_balances(balances)
if currency:
enriched_account["currency"] = currency
# Fetch and add institution logo
institution_id = enriched_account.get("institution_id")
if institution_id:
logo = await self._fetch_institution_logo(institution_id)
if logo:
enriched_account["logo"] = logo
return enriched_account
def _extract_currency_from_balances(self, balances: Dict[str, Any]) -> str | None:
"""Extract currency from the first balance in the balances data."""
balances_list = balances.get("balances", [])
if not balances_list:
return None
first_balance = balances_list[0]
balance_amount = first_balance.get("balanceAmount", {})
return balance_amount.get("currency")
async def _fetch_institution_logo(self, institution_id: str) -> str | None:
"""Fetch institution logo from GoCardless API."""
try:
institution_details = await self.gocardless.get_institution_details(
institution_id
)
logo = institution_details.get("logo", "")
if logo:
logger.info(f"Fetched logo for institution {institution_id}: {logo}")
return logo
except Exception as e:
logger.warning(
f"Failed to fetch institution details for {institution_id}: {e}"
)
return None
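The currency extraction above walks the nested GoCardless balance structure defensively, returning `None` at any missing level. A standalone sketch of the same logic (detached from the class for illustration):

```python
from typing import Any, Dict, Optional

def extract_currency(balances: Dict[str, Any]) -> Optional[str]:
    # Same logic as AccountEnricher._extract_currency_from_balances:
    # take the currency of the first balance, or None if there are none.
    balances_list = balances.get("balances", [])
    if not balances_list:
        return None
    return balances_list[0].get("balanceAmount", {}).get("currency")

print(extract_currency({"balances": [{"balanceAmount": {"currency": "EUR"}}]}))  # EUR
print(extract_currency({"balances": []}))  # None
```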

View File

@@ -1,201 +0,0 @@
"""Analytics processor for calculating historical balances and statistics."""
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
from loguru import logger
class AnalyticsProcessor:
"""Calculates historical balances and transaction statistics from database data."""
def calculate_historical_balances(
self,
db_path: Path,
account_id: Optional[str] = None,
days: int = 365,
) -> List[Dict[str, Any]]:
"""
Generate historical balance progression based on transaction history.
Uses current balances and subtracts future transactions to calculate
balance at each historical point in time.
Args:
db_path: Path to SQLite database
account_id: Optional account ID to filter by
days: Number of days to look back (default 365)
Returns:
List of historical balance data points
"""
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
cutoff_date = (datetime.now() - timedelta(days=days)).date().isoformat()
today_date = datetime.now().date().isoformat()
# Single SQL query to generate historical balances using window functions
query = """
WITH RECURSIVE date_series AS (
-- Generate weekly dates from cutoff_date to today
SELECT date(?) as ref_date
UNION ALL
SELECT date(ref_date, '+7 days')
FROM date_series
WHERE ref_date < date(?)
),
current_balances AS (
-- Get current balance for each account/type
SELECT account_id, type, amount, currency
FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
{account_filter}
AND b1.type = 'closingBooked' -- Focus on closingBooked for charts
),
historical_points AS (
-- Calculate balance at each weekly point by subtracting future transactions
SELECT
cb.account_id,
cb.type as balance_type,
cb.currency,
ds.ref_date,
cb.amount - COALESCE(
(SELECT SUM(t.transactionValue)
FROM transactions t
WHERE t.accountId = cb.account_id
AND date(t.transactionDate) > ds.ref_date), 0
) as balance_amount
FROM current_balances cb
CROSS JOIN date_series ds
)
SELECT
account_id || '_' || balance_type || '_' || ref_date as id,
account_id,
balance_amount,
balance_type,
currency,
ref_date as reference_date
FROM historical_points
ORDER BY account_id, ref_date
"""
# Build parameters and account filter
params = [cutoff_date, today_date]
if account_id:
account_filter = "AND b1.account_id = ?"
params.append(account_id)
else:
account_filter = ""
# Format the query with conditional filter
formatted_query = query.format(account_filter=account_filter)
cursor.execute(formatted_query, params)
rows = cursor.fetchall()
conn.close()
return [dict(row) for row in rows]
except Exception as e:
conn.close()
logger.error(f"Failed to calculate historical balances: {e}")
raise
def calculate_monthly_stats(
self,
db_path: Path,
account_id: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Calculate monthly transaction statistics aggregated from database.
Sums transactions by month and calculates income, expenses, and net values.
Args:
db_path: Path to SQLite database
account_id: Optional account ID to filter by
date_from: Optional start date (ISO format)
date_to: Optional end date (ISO format)
Returns:
List of monthly statistics with income, expenses, and net totals
"""
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# SQL query to aggregate transactions by month
query = """
SELECT
strftime('%Y-%m', transactionDate) as month,
COALESCE(SUM(CASE WHEN transactionValue > 0 THEN transactionValue ELSE 0 END), 0) as income,
COALESCE(SUM(CASE WHEN transactionValue < 0 THEN ABS(transactionValue) ELSE 0 END), 0) as expenses,
COALESCE(SUM(transactionValue), 0) as net
FROM transactions
WHERE 1=1
"""
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
query += """
GROUP BY strftime('%Y-%m', transactionDate)
ORDER BY month ASC
"""
cursor.execute(query, params)
rows = cursor.fetchall()
# Convert to desired format with proper month display
monthly_stats = []
for row in rows:
# Convert YYYY-MM to display format like "Mar 2024"
year, month_num = row["month"].split("-")
month_date = datetime.strptime(f"{year}-{month_num}-01", "%Y-%m-%d")
display_month = month_date.strftime("%b %Y")
monthly_stats.append(
{
"month": display_month,
"income": round(row["income"], 2),
"expenses": round(row["expenses"], 2),
"net": round(row["net"], 2),
}
)
conn.close()
return monthly_stats
except Exception as e:
conn.close()
logger.error(f"Failed to calculate monthly stats: {e}")
raise
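The historical-balance query above hinges on a recursive CTE that generates one reference date per week between the cutoff and today. That piece can be run in isolation to see its shape (fixed dates used here so the output is deterministic; note the final generated date can overshoot the end bound by up to a week, since the guard is checked before the step):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# The weekly date_series CTE from calculate_historical_balances, standalone:
rows = conn.execute("""
    WITH RECURSIVE date_series AS (
        SELECT date('2024-01-01') AS ref_date
        UNION ALL
        SELECT date(ref_date, '+7 days') FROM date_series
        WHERE ref_date < date('2024-02-01')
    )
    SELECT ref_date FROM date_series
""").fetchall()
dates = [r[0] for r in rows]
print(dates)
# ['2024-01-01', '2024-01-08', '2024-01-15', '2024-01-22', '2024-01-29', '2024-02-05']
```

Each weekly point is then priced by taking the current balance and subtracting every transaction dated after that point, which avoids needing stored balance history.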

View File

@@ -1,69 +0,0 @@
"""Balance data transformation processor for format conversions."""
from datetime import datetime
from typing import Any, Dict, List, Tuple
class BalanceTransformer:
"""Transforms balance data between GoCardless and internal database formats."""
def merge_account_metadata_into_balances(
self,
balances: Dict[str, Any],
account_details: Dict[str, Any],
) -> Dict[str, Any]:
"""
Merge account metadata into balance data for proper persistence.
This adds institution_id, iban, and account_status to the balances
so they can be persisted alongside the balance data.
Args:
balances: Raw balance data from GoCardless
account_details: Enriched account details containing metadata
Returns:
Balance data with account metadata merged in
"""
balances_with_metadata = balances.copy()
balances_with_metadata["institution_id"] = account_details.get("institution_id")
balances_with_metadata["iban"] = account_details.get("iban")
balances_with_metadata["account_status"] = account_details.get("status")
return balances_with_metadata
def transform_to_database_format(
self,
account_id: str,
balance_data: Dict[str, Any],
) -> List[Tuple[Any, ...]]:
"""
Transform GoCardless balance format to database row format.
Converts nested GoCardless balance structure into flat tuples
ready for SQLite insertion.
Args:
account_id: The account ID
balance_data: Balance data with merged account metadata
Returns:
List of tuples in database row format (account_id, bank, status, ...)
"""
rows = []
for balance in balance_data.get("balances", []):
balance_amount = balance.get("balanceAmount", {})
row = (
account_id,
balance_data.get("institution_id", "unknown"),
balance_data.get("account_status"),
balance_data.get("iban", "N/A"),
float(balance_amount.get("amount", 0)),
balance_amount.get("currency"),
balance.get("balanceType"),
datetime.now().isoformat(),
)
rows.append(row)
return rows
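The deleted `transform_to_database_format` flattens GoCardless's nested balance payload into positional tuples for SQLite insertion. A minimal standalone sketch of that logic (the payload shape follows the diff; the sample values are invented):

```python
from datetime import datetime

def transform_to_database_format(account_id, balance_data):
    """Flatten nested balance entries into SQLite-ready row tuples."""
    rows = []
    for balance in balance_data.get("balances", []):
        amount = balance.get("balanceAmount", {})
        rows.append((
            account_id,
            balance_data.get("institution_id", "unknown"),
            balance_data.get("account_status"),
            balance_data.get("iban", "N/A"),
            float(amount.get("amount", 0)),  # GoCardless sends amounts as strings
            amount.get("currency"),
            balance.get("balanceType"),
            datetime.now().isoformat(),
        ))
    return rows

# Invented sample payload in the shape the transformer expects.
payload = {
    "institution_id": "BANK_X",
    "account_status": "READY",
    "iban": "PT50000201231234567890154",
    "balances": [
        {
            "balanceAmount": {"amount": "123.45", "currency": "EUR"},
            "balanceType": "interimAvailable",
        },
    ],
}
rows = transform_to_database_format("acc-1", payload)
print(rows[0][:7])
```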

File diff suppressed because it is too large


@@ -4,17 +4,7 @@ from typing import List
from loguru import logger
from leggen.api.models.sync import SyncResult, SyncStatus
from leggen.repositories import (
AccountRepository,
BalanceRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AccountEnricher,
BalanceTransformer,
TransactionProcessor,
)
from leggen.services.database_service import DatabaseService
from leggen.services.gocardless_service import GoCardlessService
from leggen.services.notification_service import NotificationService
@@ -25,20 +15,10 @@ EXPIRED_DAYS_LEFT = 0
class SyncService:
def __init__(self):
self.gocardless = GoCardlessService()
self.database = DatabaseService()
self.notifications = NotificationService()
# Repositories
self.accounts = AccountRepository()
self.balances = BalanceRepository()
self.transactions = TransactionRepository()
self.sync = SyncRepository()
# Data processors
self.account_enricher = AccountEnricher()
self.balance_transformer = BalanceTransformer()
self.transaction_processor = TransactionProcessor()
self._sync_status = SyncStatus(is_running=False)
self._institution_logos = {} # Cache for institution logos
async def get_sync_status(self) -> SyncStatus:
"""Get current sync status"""
@@ -104,44 +84,72 @@ class SyncService:
# Get balances to extract currency information
balances = await self.gocardless.get_account_balances(account_id)
# Enrich and persist account details
# Enrich account details with currency and institution logo
if account_details and balances:
# Enrich account with currency and institution logo
enriched_account_details = (
await self.account_enricher.enrich_account_details(
account_details, balances
)
)
enriched_account_details = account_details.copy()
# Extract currency from first balance
balances_list = balances.get("balances", [])
if balances_list:
first_balance = balances_list[0]
balance_amount = first_balance.get("balanceAmount", {})
currency = balance_amount.get("currency")
if currency:
enriched_account_details["currency"] = currency
# Get institution details to fetch logo
institution_id = enriched_account_details.get("institution_id")
if institution_id:
try:
institution_details = (
await self.gocardless.get_institution_details(
institution_id
)
)
enriched_account_details["logo"] = (
institution_details.get("logo", "")
)
logger.info(
f"Fetched logo for institution {institution_id}: {enriched_account_details.get('logo', 'No logo')}"
)
except Exception as e:
logger.warning(
f"Failed to fetch institution details for {institution_id}: {e}"
)
# Persist enriched account details to database
self.accounts.persist(enriched_account_details)
await self.database.persist_account_details(
enriched_account_details
)
# Merge account metadata into balances for persistence
balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances(
balances, enriched_account_details
# Merge account details into balances data for proper persistence
balances_with_account_info = balances.copy()
balances_with_account_info["institution_id"] = (
enriched_account_details.get("institution_id")
)
balance_rows = (
self.balance_transformer.transform_to_database_format(
account_id, balances_with_account_info
)
balances_with_account_info["iban"] = (
enriched_account_details.get("iban")
)
balances_with_account_info["account_status"] = (
enriched_account_details.get("status")
)
await self.database.persist_balance(
account_id, balances_with_account_info
)
self.balances.persist(account_id, balance_rows)
balances_updated += len(balances.get("balances", []))
elif account_details:
# Fallback: persist account details without currency if balances failed
self.accounts.persist(account_details)
await self.database.persist_account_details(account_details)
# Get and save transactions
transactions = await self.gocardless.get_account_transactions(
account_id
)
if transactions:
processed_transactions = (
self.transaction_processor.process_transactions(
account_id, account_details, transactions
)
processed_transactions = self.database.process_transactions(
account_id, account_details, transactions
)
new_transactions = self.transactions.persist(
new_transactions = await self.database.persist_transactions(
account_id, processed_transactions
)
transactions_added += len(new_transactions)
@@ -195,7 +203,9 @@ class SyncService:
# Persist sync operation to database
try:
operation_id = self.sync.persist(sync_operation)
operation_id = await self.database.persist_sync_operation(
sync_operation
)
logger.debug(f"Saved sync operation with ID: {operation_id}")
except Exception as e:
logger.error(f"Failed to persist sync operation: {e}")
@@ -244,7 +254,9 @@ class SyncService:
)
try:
operation_id = self.sync.persist(sync_operation)
operation_id = await self.database.persist_sync_operation(
sync_operation
)
logger.debug(f"Saved failed sync operation with ID: {operation_id}")
except Exception as persist_error:
logger.error(


@@ -126,38 +126,6 @@ def api_client(fastapi_app):
return TestClient(fastapi_app)
@pytest.fixture
def mock_account_repo():
"""Create mock AccountRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_balance_repo():
"""Create mock BalanceRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_transaction_repo():
"""Create mock TransactionRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_analytics_proc():
"""Create mock AnalyticsProcessor for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_db_path(temp_db_path):
"""Mock the database path to use temporary database for testing."""


@@ -1,13 +1,13 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics."""
from datetime import datetime, timedelta
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
from leggen.api.dependencies import get_transaction_repository
from leggen.commands.server import create_app
from leggen.services.database_service import DatabaseService
class TestAnalyticsFix:
@@ -19,11 +19,11 @@ class TestAnalyticsFix:
return TestClient(app)
@pytest.fixture
def mock_transaction_repo(self):
return Mock()
def mock_database_service(self):
return Mock(spec=DatabaseService)
@pytest.mark.asyncio
async def test_transaction_stats_uses_all_transactions(self, mock_transaction_repo):
async def test_transaction_stats_uses_all_transactions(self, mock_database_service):
"""Test that transaction stats endpoint uses all transactions (not limited to 100)"""
# Mock data for 600 transactions (simulating the issue)
mock_transactions = []
@@ -42,50 +42,53 @@ class TestAnalyticsFix:
}
)
mock_transaction_repo.get_transactions.return_value = mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
mock_database_service.get_transactions_from_db = AsyncMock(
return_value=mock_transactions
)
# Verify that the response contains stats for all 600 transactions
stats = data
assert stats["total_transactions"] == 600, (
"Should process all 600 transactions, not just 100"
)
# Test that the endpoint calls get_transactions_from_db with limit=None
with patch(
"leggen.api.routes.transactions.database_service", mock_database_service
):
app = create_app()
client = TestClient(app)
# Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
response = client.get("/api/v1/transactions/stats?days=365")
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
)
# Verify that the response contains stats for all 600 transactions
stats = data
assert stats["total_transactions"] == 600, (
"Should process all 600 transactions, not just 100"
)
# Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
@pytest.mark.asyncio
async def test_analytics_endpoint_returns_all_transactions(
self, mock_transaction_repo
self, mock_database_service
):
"""Test that the new analytics endpoint returns all transactions without pagination"""
# Mock data for 600 transactions
@@ -105,28 +108,30 @@ class TestAnalyticsFix:
}
)
mock_transaction_repo.get_transactions.return_value = mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
mock_database_service.get_transactions_from_db = AsyncMock(
return_value=mock_transactions
)
# Verify that all 600 transactions are returned
transactions_data = data
assert len(transactions_data) == 600, (
"Analytics endpoint should return all 600 transactions"
)
with patch(
"leggen.api.routes.transactions.database_service", mock_database_service
):
app = create_app()
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
)
# Verify that all 600 transactions are returned
transactions_data = data
assert len(transactions_data) == 600, (
"Analytics endpoint should return all 600 transactions"
)


@@ -4,12 +4,6 @@ from unittest.mock import patch
import pytest
from leggen.api.dependencies import (
get_account_repository,
get_balance_repository,
get_transaction_repository,
)
@pytest.mark.api
class TestAccountsAPI:
@@ -17,14 +11,11 @@ class TestAccountsAPI:
def test_get_all_accounts_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
mock_db_path,
mock_account_repo,
mock_balance_repo,
):
"""Test successful retrieval of all accounts from database."""
mock_accounts = [
@@ -54,21 +45,19 @@ class TestAccountsAPI:
}
]
mock_account_repo.get_accounts.return_value = mock_accounts
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_accounts_from_db",
return_value=mock_accounts,
),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -80,14 +69,11 @@ class TestAccountsAPI:
def test_get_account_details_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
mock_db_path,
mock_account_repo,
mock_balance_repo,
):
"""Test successful retrieval of specific account details from database."""
mock_account = {
@@ -115,21 +101,19 @@ class TestAccountsAPI:
}
]
mock_account_repo.get_account.return_value = mock_account
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=mock_account,
),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
response = api_client.get("/api/v1/accounts/test-account-123")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["id"] == "test-account-123"
@@ -137,13 +121,7 @@ class TestAccountsAPI:
assert len(data["balances"]) == 1
def test_get_account_balances_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_balance_repo,
self, api_client, mock_config, mock_auth_token, mock_db_path
):
"""Test successful retrieval of account balances from database."""
mock_balances = [
@@ -171,17 +149,15 @@ class TestAccountsAPI:
},
]
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
response = api_client.get("/api/v1/accounts/test-account-123/balances")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 2
@@ -191,14 +167,12 @@ class TestAccountsAPI:
def test_get_account_transactions_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
mock_db_path,
mock_transaction_repo,
):
"""Test successful retrieval of account transactions from database."""
mock_transactions = [
@@ -217,19 +191,21 @@ class TestAccountsAPI:
}
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -242,14 +218,12 @@ class TestAccountsAPI:
def test_get_account_transactions_full_details(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
mock_db_path,
mock_transaction_repo,
):
"""Test retrieval of full transaction details from database."""
mock_transactions = [
@@ -268,19 +242,21 @@ class TestAccountsAPI:
}
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -292,36 +268,22 @@ class TestAccountsAPI:
assert "raw_transaction" in transaction
def test_get_account_not_found(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
self, api_client, mock_config, mock_auth_token, mock_db_path
):
"""Test handling of non-existent account."""
mock_account_repo.get_account.return_value = None
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=None,
),
):
response = api_client.get("/api/v1/accounts/nonexistent")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404
def test_update_account_display_name_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
self, api_client, mock_config, mock_auth_token, mock_db_path
):
"""Test successful update of account display name."""
mock_account = {
@@ -335,48 +297,41 @@ class TestAccountsAPI:
"last_accessed": "2025-09-01T09:30:00Z",
}
mock_account_repo.get_account.return_value = mock_account
mock_account_repo.persist.return_value = mock_account
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=mock_account,
),
patch(
"leggen.api.routes.accounts.database_service.persist_account_details",
return_value=None,
),
):
response = api_client.put(
"/api/v1/accounts/test-account-123",
json={"display_name": "My Custom Account Name"},
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["id"] == "test-account-123"
assert data["display_name"] == "My Custom Account Name"
def test_update_account_not_found(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
self, api_client, mock_config, mock_auth_token, mock_db_path
):
"""Test updating non-existent account."""
mock_account_repo.get_account.return_value = None
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=None,
),
):
response = api_client.put(
"/api/v1/accounts/nonexistent",
json={"display_name": "New Name"},
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404


@@ -5,20 +5,13 @@ from unittest.mock import patch
import pytest
from leggen.api.dependencies import get_transaction_repository
@pytest.mark.api
class TestTransactionsAPI:
"""Test transaction-related API endpoints."""
def test_get_all_transactions_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test successful retrieval of all transactions from database."""
mock_transactions = [
@@ -50,17 +43,19 @@ class TestTransactionsAPI:
},
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = len(mock_transactions)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=2,
),
):
response = api_client.get("/api/v1/transactions?summary_only=true")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
@@ -75,12 +70,7 @@ class TestTransactionsAPI:
assert transaction["account_id"] == "test-account-123"
def test_get_all_transactions_full_details(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test retrieval of full transaction details from database."""
mock_transactions = [
@@ -99,17 +89,19 @@ class TestTransactionsAPI:
}
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = len(mock_transactions)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get("/api/v1/transactions?summary_only=false")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
@@ -122,12 +114,7 @@ class TestTransactionsAPI:
assert "raw_transaction" in transaction
def test_get_transactions_with_filters(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test getting transactions with various filters."""
mock_transactions = [
@@ -146,14 +133,17 @@ class TestTransactionsAPI:
}
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = 1
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get(
"/api/v1/transactions?"
"account_id=test-account-123&"
@@ -166,12 +156,10 @@ class TestTransactionsAPI:
"per_page=10"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
# Verify the repository was called with correct filters
mock_transaction_repo.get_transactions.assert_called_once_with(
# Verify the database service was called with correct filters
mock_get_transactions.assert_called_once_with(
account_id="test-account-123",
limit=10,
offset=10, # (page-1) * per_page = (2-1) * 10 = 10
@@ -183,26 +171,22 @@ class TestTransactionsAPI:
)
def test_get_transactions_empty_result(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test getting transactions when database returns empty result."""
mock_transaction_repo.get_transactions.return_value = []
mock_transaction_repo.get_count.return_value = 0
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=[],
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=0,
),
):
response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 0
@@ -211,37 +195,23 @@ class TestTransactionsAPI:
assert data["total_pages"] == 0
def test_get_transactions_database_error(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test handling database error when getting transactions."""
mock_transaction_repo.get_transactions.side_effect = Exception(
"Database connection failed"
)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"),
),
):
response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500
assert "Failed to get transactions" in response.json()["detail"]
def test_get_transaction_stats_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test successful retrieval of transaction statistics from database."""
mock_transactions = [
@@ -268,16 +238,15 @@ class TestTransactionsAPI:
},
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
):
response = api_client.get("/api/v1/transactions/stats?days=30")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
@@ -295,12 +264,7 @@ class TestTransactionsAPI:
assert data["average_transaction"] == expected_avg
def test_get_transaction_stats_with_account_filter(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test getting transaction stats filtered by account."""
mock_transactions = [
@@ -313,46 +277,37 @@ class TestTransactionsAPI:
}
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
):
response = api_client.get(
"/api/v1/transactions/stats?account_id=test-account-123"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
# Verify the repository was called with account filter
mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
# Verify the database service was called with account filter
mock_get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs
assert call_kwargs["account_id"] == "test-account-123"
def test_get_transaction_stats_empty_result(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test getting stats when no transactions match criteria."""
mock_transaction_repo.get_transactions.return_value = []
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=[],
),
):
response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
@@ -364,37 +319,23 @@ class TestTransactionsAPI:
assert data["accounts_included"] == 0
def test_get_transaction_stats_database_error(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test handling database error when getting stats."""
mock_transaction_repo.get_transactions.side_effect = Exception(
"Database connection failed"
)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"),
),
):
response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500
assert "Failed to get transaction stats" in response.json()["detail"]
def test_get_transaction_stats_custom_period(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
self, api_client, mock_config, mock_auth_token
):
"""Test getting transaction stats for custom time period."""
mock_transactions = [
@@ -407,23 +348,21 @@ class TestTransactionsAPI:
}
]
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
):
response = api_client.get("/api/v1/transactions/stats?days=7")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["period_days"] == 7
# Verify the date range was calculated correctly for 7 days
mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
mock_get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs
assert "date_from" in call_kwargs
assert "date_to" in call_kwargs

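The hunks above repeatedly swap a mocked repository in via `fastapi_app.dependency_overrides` and clear it after the request. A stdlib-only sketch of that mechanism, with a tiny `App.resolve` helper standing in for FastAPI's real injector (the `get_transaction_repository` name follows the diff; everything else here is illustrative):

```python
# Sketch of the dependency_overrides pattern used in the tests above.
# `App` and `resolve` are illustrative stand-ins for FastAPI's injector,
# not its real API; get_transaction_repository mirrors the diff's name.
from unittest.mock import MagicMock


class App:
    def __init__(self):
        self.dependency_overrides = {}

    def resolve(self, provider):
        # FastAPI consults dependency_overrides before calling the provider,
        # so a test can substitute a lambda that returns a mock.
        return self.dependency_overrides.get(provider, provider)()


def get_transaction_repository():
    raise RuntimeError("would build a real repository against the database")


app = App()
mock_repo = MagicMock()
mock_repo.get_transactions.return_value = [{"amount": -3.5}]

app.dependency_overrides[get_transaction_repository] = lambda: mock_repo
try:
    repo = app.resolve(get_transaction_repository)
    assert len(repo.get_transactions()) == 1
finally:
    # Mirror fastapi_app.dependency_overrides.clear() in the tests:
    # overrides must not leak between test cases.
    app.dependency_overrides.clear()
```

Clearing the override in a `finally` (or a fixture teardown) is what the `fastapi_app.dependency_overrides.clear()` calls in the diff accomplish.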
View File

@@ -120,10 +120,12 @@ class TestConfigurablePaths:
"iban": "TEST_IBAN",
}
# Use the public balance persistence method
# Use the internal balance persistence method since the test needs direct database access
import asyncio
asyncio.run(database_service.persist_balance("test-account", balance_data))
asyncio.run(
database_service._persist_balance_sqlite("test-account", balance_data)
)
# Retrieve balances
balances = asyncio.run(

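The configurable-paths test above drives an async persistence method from synchronous test code with `asyncio.run`. A minimal stdlib sketch of that shape, using an in-memory dict as an illustrative stand-in for the SQLite-backed service in the diff:

```python
# Sketch of calling an async persistence method from a synchronous test,
# as test_configurable_paths does with asyncio.run. The service below is an
# illustrative in-memory stand-in, not the real SQLite implementation.
import asyncio


class DatabaseService:
    def __init__(self):
        self.balances = {}

    async def _persist_balance_sqlite(self, account_id, balance_data):
        self.balances.setdefault(account_id, []).append(balance_data)

    async def get_balances(self, account_id):
        return self.balances.get(account_id, [])


service = DatabaseService()
balance = {"amount": 1000.0, "currency": "EUR", "type": "interimAvailable"}

# asyncio.run creates an event loop, runs the coroutine to completion,
# and tears the loop down -- one call per coroutine, as in the test.
asyncio.run(service._persist_balance_sqlite("test-account", balance))
assert asyncio.run(service.get_balances("test-account")) == [balance]
```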
View File

@@ -85,7 +85,7 @@ class TestDatabaseService:
):
"""Test successful retrieval of transactions from database."""
with patch.object(
database_service.transactions, "get_transactions"
database_service, "_get_transactions"
) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format
@@ -111,7 +111,7 @@ class TestDatabaseService:
):
"""Test retrieving transactions with filters."""
with patch.object(
database_service.transactions, "get_transactions"
database_service, "_get_transactions"
) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format
@@ -149,7 +149,7 @@ class TestDatabaseService:
async def test_get_transactions_from_db_error(self, database_service):
"""Test handling error when getting transactions."""
with patch.object(
database_service.transactions, "get_transactions"
database_service, "_get_transactions"
) as mock_get_transactions:
mock_get_transactions.side_effect = Exception("Database error")
@@ -159,7 +159,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_success(self, database_service):
"""Test successful retrieval of transaction count."""
with patch.object(database_service.transactions, "get_count") as mock_get_count:
with patch.object(database_service, "_get_transaction_count") as mock_get_count:
mock_get_count.return_value = 42
result = await database_service.get_transaction_count_from_db(
@@ -167,18 +167,11 @@ class TestDatabaseService:
)
assert result == 42
mock_get_count.assert_called_once_with(
account_id="test-account-123",
date_from=None,
date_to=None,
min_amount=None,
max_amount=None,
search=None,
)
mock_get_count.assert_called_once_with(account_id="test-account-123")
async def test_get_transaction_count_from_db_with_filters(self, database_service):
"""Test getting transaction count with filters."""
with patch.object(database_service.transactions, "get_count") as mock_get_count:
with patch.object(database_service, "_get_transaction_count") as mock_get_count:
mock_get_count.return_value = 15
result = await database_service.get_transaction_count_from_db(
@@ -192,9 +185,7 @@ class TestDatabaseService:
mock_get_count.assert_called_once_with(
account_id="test-account-123",
date_from="2025-09-01",
date_to=None,
min_amount=-100.0,
max_amount=None,
search="Coffee",
)
@@ -210,7 +201,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_error(self, database_service):
"""Test handling error when getting count."""
with patch.object(database_service.transactions, "get_count") as mock_get_count:
with patch.object(database_service, "_get_transaction_count") as mock_get_count:
mock_get_count.side_effect = Exception("Database error")
result = await database_service.get_transaction_count_from_db()
@@ -221,9 +212,7 @@ class TestDatabaseService:
self, database_service, sample_balances_db_format
):
"""Test successful retrieval of balances from database."""
with patch.object(
database_service.balances, "get_balances"
) as mock_get_balances:
with patch.object(database_service, "_get_balances") as mock_get_balances:
mock_get_balances.return_value = sample_balances_db_format
result = await database_service.get_balances_from_db(
@@ -245,9 +234,7 @@ class TestDatabaseService:
async def test_get_balances_from_db_error(self, database_service):
"""Test handling error when getting balances."""
with patch.object(
database_service.balances, "get_balances"
) as mock_get_balances:
with patch.object(database_service, "_get_balances") as mock_get_balances:
mock_get_balances.side_effect = Exception("Database error")
result = await database_service.get_balances_from_db()
@@ -262,9 +249,7 @@ class TestDatabaseService:
"iban": "LT313250081177977789",
}
with patch.object(
database_service.transactions, "get_account_summary"
) as mock_get_summary:
with patch.object(database_service, "_get_account_summary") as mock_get_summary:
mock_get_summary.return_value = mock_summary
result = await database_service.get_account_summary_from_db(
@@ -284,9 +269,7 @@ class TestDatabaseService:
async def test_get_account_summary_from_db_error(self, database_service):
"""Test handling error when getting summary."""
with patch.object(
database_service.transactions, "get_account_summary"
) as mock_get_summary:
with patch.object(database_service, "_get_account_summary") as mock_get_summary:
mock_get_summary.side_effect = Exception("Database error")
result = await database_service.get_account_summary_from_db(
@@ -308,87 +291,87 @@ class TestDatabaseService:
],
}
with (
patch.object(database_service.balances, "persist") as mock_persist,
patch.object(
database_service.balance_transformer, "transform_to_database_format"
) as mock_transform,
):
mock_transform.return_value = [
(
"test-account-123",
"REVOLUT_REVOLT21",
"active",
"LT313250081177977789",
1000.0,
"EUR",
"interimAvailable",
"2025-09-01T10:00:00",
)
]
with patch("sqlite3.connect") as mock_connect:
mock_conn = mock_connect.return_value
mock_cursor = mock_conn.cursor.return_value
await database_service.persist_balance("test-account-123", balance_data)
await database_service._persist_balance_sqlite(
"test-account-123", balance_data
)
# Verify transformation and persistence were called
mock_transform.assert_called_once_with("test-account-123", balance_data)
mock_persist.assert_called_once()
# Verify database operations
mock_connect.assert_called()
mock_cursor.execute.assert_called() # Table creation and insert
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
async def test_persist_balance_sqlite_error(self, database_service):
"""Test handling error during balance persistence."""
balance_data = {"balances": []}
with (
patch.object(database_service.balances, "persist") as mock_persist,
patch.object(
database_service.balance_transformer, "transform_to_database_format"
) as mock_transform,
):
mock_persist.side_effect = Exception("Database error")
mock_transform.return_value = []
with patch("sqlite3.connect") as mock_connect:
mock_connect.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
await database_service.persist_balance("test-account-123", balance_data)
await database_service._persist_balance_sqlite(
"test-account-123", balance_data
)
async def test_persist_transactions_sqlite_success(
self, database_service, sample_transactions_db_format
):
"""Test successful transaction persistence."""
with patch.object(database_service.transactions, "persist") as mock_persist:
mock_persist.return_value = sample_transactions_db_format
with patch("sqlite3.connect") as mock_connect:
mock_conn = mock_connect.return_value
mock_cursor = mock_conn.cursor.return_value
# Mock fetchone to return (0,) indicating transaction doesn't exist yet
mock_cursor.fetchone.return_value = (0,)
result = await database_service.persist_transactions(
result = await database_service._persist_transactions_sqlite(
"test-account-123", sample_transactions_db_format
)
# Should return the new transactions
assert len(result) == 2
mock_persist.assert_called_once_with(
"test-account-123", sample_transactions_db_format
)
# Should return the transactions (assuming no duplicates)
assert len(result) >= 0 # Could be empty if all are duplicates
# Verify database operations
mock_connect.assert_called()
mock_cursor.execute.assert_called()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
async def test_persist_transactions_sqlite_duplicate_detection(
self, database_service, sample_transactions_db_format
):
"""Test that existing transactions are not returned as new."""
with patch.object(database_service.transactions, "persist") as mock_persist:
# Return empty list indicating all were duplicates
mock_persist.return_value = []
with patch("sqlite3.connect") as mock_connect:
mock_conn = mock_connect.return_value
mock_cursor = mock_conn.cursor.return_value
# Mock fetchone to return (1,) indicating transaction already exists
mock_cursor.fetchone.return_value = (1,)
result = await database_service.persist_transactions(
result = await database_service._persist_transactions_sqlite(
"test-account-123", sample_transactions_db_format
)
# Should return empty list since all transactions already exist
assert len(result) == 0
mock_persist.assert_called_once()
# Verify database operations still happened (INSERT OR REPLACE executed)
mock_connect.assert_called()
mock_cursor.execute.assert_called()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
async def test_persist_transactions_sqlite_error(self, database_service):
"""Test handling error during transaction persistence."""
with patch.object(database_service.transactions, "persist") as mock_persist:
mock_persist.side_effect = Exception("Database error")
with patch("sqlite3.connect") as mock_connect:
mock_connect.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
await database_service.persist_transactions("test-account-123", [])
await database_service._persist_transactions_sqlite(
"test-account-123", []
)
async def test_process_transactions_booked_and_pending(self, database_service):
"""Test processing transactions with both booked and pending."""

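The TestDatabaseService hunks above lean on `patch.object` to stub a collaborator's method on a live service instance and then assert on the call. A self-contained sketch of that pattern, under the assumption (taken from the diff) that the service exposes a `transactions` repo and swallows errors by returning a fallback value; the concrete class bodies here are illustrative:

```python
# Sketch of the patch.object pattern used throughout TestDatabaseService:
# stub the repo method on a service instance, assert the forwarded kwargs,
# and exercise the error path. Class internals are illustrative stand-ins.
from unittest.mock import patch


class TransactionRepo:
    def get_count(self, **filters):
        raise RuntimeError("would hit the real database")


class DatabaseService:
    def __init__(self):
        self.transactions = TransactionRepo()

    def get_transaction_count_from_db(self, account_id=None):
        try:
            return self.transactions.get_count(account_id=account_id)
        except Exception:
            return 0  # mirror the error-swallowing behavior under test


service = DatabaseService()

# Happy path: the repo method is stubbed, so no database is touched.
with patch.object(service.transactions, "get_count") as mock_get_count:
    mock_get_count.return_value = 42
    assert service.get_transaction_count_from_db("acct-123") == 42
    mock_get_count.assert_called_once_with(account_id="acct-123")

# Error path: the stub raises, and the service falls back to 0.
with patch.object(service.transactions, "get_count") as mock_get_count:
    mock_get_count.side_effect = Exception("Database error")
    assert service.get_transaction_count_from_db() == 0
```

Patching the attribute path (`service.transactions`, `"get_count"`) rather than a private method (`service`, `"_get_transaction_count"`) is exactly the distinction the before/after pairs in these hunks are flipping between.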
View File

@@ -27,7 +27,9 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.sync, "persist", return_value=1),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {
@@ -67,7 +69,9 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "send_expiry_notification"
) as mock_send_expiry,
patch.object(sync_service.sync, "persist", return_value=1),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
):
# Setup: One expired requisition
mock_get_requisitions.return_value = {
@@ -108,7 +112,9 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.sync, "persist", return_value=1),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
):
# Setup: One requisition with two accounts that will fail
mock_get_requisitions.return_value = {
@@ -154,15 +160,17 @@ class TestSyncNotifications:
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.notifications, "send_transaction_notifications"),
patch.object(sync_service.accounts, "persist"),
patch.object(sync_service.balances, "persist"),
patch.object(sync_service.database, "persist_account_details"),
patch.object(sync_service.database, "persist_balance"),
patch.object(
sync_service.transaction_processor,
"process_transactions",
return_value=[],
sync_service.database, "process_transactions", return_value=[]
),
patch.object(
sync_service.database, "persist_transactions", return_value=[]
),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
patch.object(sync_service.transactions, "persist", return_value=[]),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that succeeds
mock_get_requisitions.return_value = {
@@ -214,7 +222,9 @@ class TestSyncNotifications:
patch.object(
sync_service.notifications, "_send_telegram_sync_failure"
) as mock_telegram_notification,
patch.object(sync_service.sync, "persist", return_value=1),
patch.object(
sync_service.database, "persist_sync_operation", return_value=1
),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {