Compare commits

...

17 Commits

Author SHA1 Message Date
Elisiário Couto
9e9b1cf15f refactor(api): Update all modified files with dependency injection changes. 2025-12-10 00:18:53 +00:00
Elisiário Couto
9dc6357905 refactor(api): Remove DatabaseService layer and implement dependency injection.
- Remove DatabaseService abstraction layer from API routes
- Implement FastAPI dependency injection for repositories
- Create leggen/api/dependencies.py with factory functions
- Update routes to use AccountRepo, BalanceRepo, TransactionRepo directly
- Refactor SyncService to use repositories instead of DatabaseService
- Deprecate DatabaseService with warnings for backward compatibility
- Update all tests to use FastAPI dependency overrides pattern
- Fix mypy type errors in routes

Benefits:
- Simpler architecture with one less abstraction layer
- More explicit dependencies via function signatures
- Better testability with FastAPI's app.dependency_overrides
- Clearer separation of concerns

All 114 tests passing, mypy clean.
2025-12-10 00:18:53 +00:00
Elisiário Couto
5f87991076 refactor(api): Split DatabaseService into repository pattern.
Split the monolithic DatabaseService (1,492 lines) into focused repository
modules using the repository pattern for better maintainability and
separation of concerns.

Changes:
- Create new repositories/ directory with 5 focused repositories:
  - TransactionRepository: transaction data operations (264 lines)
  - AccountRepository: account data operations (128 lines)
  - BalanceRepository: balance data operations (107 lines)
  - MigrationRepository: all database migrations (629 lines)
  - SyncRepository: sync operation tracking (132 lines)
  - BaseRepository: shared database connection logic (28 lines)

- Refactor DatabaseService into a clean facade (287 lines):
  - Delegates data access to repositories
  - Maintains public API (no breaking changes)
  - Keeps data processors in service layer
  - Preserves require_sqlite decorator

- Update tests to mock repository methods instead of private methods
- Fix test references to internal methods (_persist_*, _get_*)

Benefits:
- Clear separation of concerns (one repository per domain)
- Easier maintenance (changes isolated to specific repositories)
- Better testability (repositories can be mocked individually)
- Improved code organization (from 1 file to 7 focused files)

All 114 tests passing.
2025-12-08 23:21:55 +00:00
Elisiário Couto
267db8ac63 refactor(api): Improve database connection management and reduce boilerplate.
- Add context manager for database connections with proper cleanup
- Add @require_sqlite decorator to eliminate duplicate checks
- Refactor 9 core CRUD methods to use managed connections
- Reduce code by 50 lines while improving resource management
- All 114 tests passing
2025-12-08 22:54:57 +00:00
Elisiário Couto
7007043521 Reformat. 2025-12-08 22:05:04 +00:00
Elisiário Couto
fbb3eb9e64 refactor: Consolidate service layer with dedicated data processors.
Introduces a DataProcessor layer to separate transformation logic from orchestration and persistence concerns:

- Created data_processors/ directory with AccountEnricher, BalanceTransformer, AnalyticsProcessor, and moved TransactionProcessor
- Refactored SyncService to pure orchestrator, removing account/balance enrichment logic
- Refactored DatabaseService to pure CRUD, removing analytics and transformation logic
- Extracted 90+ lines of analytics SQL from DatabaseService to AnalyticsProcessor
- Extracted 80+ lines of balance transformation logic to BalanceTransformer
- Maintained backward compatibility - all 109 tests pass
- No API contract changes

This improves code clarity, testability, and maintainability while maintaining the existing API surface.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2025-12-08 22:04:00 +00:00
Elisiário Couto
3d5994bf30 Fix lint issues. 2025-12-08 21:59:54 +00:00
Elisiário Couto
edbc1cb39e Update frontend/src/components/TransactionsTable.tsx
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
Elisiário Couto
504f78aa85 Update frontend/src/components/filters/FilterBar.tsx
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
Elisiário Couto
cbbc316537 chore(ci): Fix workflow permissions. 2025-12-08 21:59:54 +00:00
Elisiário Couto
18ee52bdff fix(frontend): Prevent full transactions page reload on search. 2025-12-08 21:59:54 +00:00
Elisiário Couto
07edfeaf25 fix(frontend): Blur balances in transactions page cards. 2025-12-08 21:59:54 +00:00
copilot-swe-agent[bot]
c8b161e7f2 refactor(frontend): Address code review feedback on focus and currency handling.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
copilot-swe-agent[bot]
2c85722fd0 feat(frontend): Fix search focus issue and add transaction statistics.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
copilot-swe-agent[bot]
88037f328d fix: Address code review feedback on notification error handling.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 19:05:28 +00:00
copilot-swe-agent[bot]
d58894d07c refactor: Replace magic numbers with named constants.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 19:05:28 +00:00
copilot-swe-agent[bot]
1a2ec45f89 feat: Add sync error and account expiry notifications.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 19:05:28 +00:00
34 changed files with 3048 additions and 1988 deletions

View File

@@ -6,6 +6,9 @@ on:
pull_request:
branches: ["main", "dev"]
permissions:
contents: read
jobs:
test-python:
name: Test Python

View File

@@ -5,6 +5,11 @@ on:
tags:
- "**"
permissions:
contents: write
packages: write
id-token: write
jobs:
build:
runs-on: ubuntu-latest
@@ -44,6 +49,9 @@ jobs:
push-docker-backend:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -90,6 +98,9 @@ jobs:
push-docker-frontend:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -137,6 +148,8 @@ jobs:
create-github-release:
name: Create GitHub Release
runs-on: ubuntu-latest
permissions:
contents: write
needs: [build, publish-to-pypi, push-docker-backend, push-docker-frontend]
steps:
- name: Checkout

127
REFACTORING_SUMMARY.md Normal file
View File

@@ -0,0 +1,127 @@
# Backend Refactoring Summary
## What Was Accomplished ✅
### 1. Removed DatabaseService Layer from Production Code
- **Removed**: The `DatabaseService` class is no longer used in production API routes
- **Replaced with**: Direct repository usage via FastAPI dependency injection
- **Files changed**:
- `leggen/api/routes/accounts.py` - Now uses `AccountRepo`, `BalanceRepo`, `TransactionRepo`, `AnalyticsProc`
- `leggen/api/routes/transactions.py` - Now uses `TransactionRepo`, `AnalyticsProc`
- `leggen/services/sync_service.py` - Now uses repositories directly
- `leggen/commands/server.py` - Now uses `MigrationRepository` directly
### 2. Created Dependency Injection System
- **New file**: `leggen/api/dependencies.py`
- **Provides**: Centralized dependency injection setup for FastAPI
- **Includes**: Factory functions for all repositories and data processors
- **Type annotations**: `AccountRepo`, `BalanceRepo`, `TransactionRepo`, etc.
### 3. Simplified Code Architecture
- **Before**: Routes → DatabaseService → Repositories
- **After**: Routes → Repositories (via DI)
- **Benefits**:
- One less layer of indirection
- Clearer dependencies
- Easier to test with FastAPI's `app.dependency_overrides`
- Better separation of concerns
### 4. Maintained Backward Compatibility
- **DatabaseService** is kept but deprecated for test compatibility
- Added deprecation warning when instantiated
- Tests continue to work without immediate changes required
## Code Statistics
- **Lines removed from API layer**: ~16 imports of DatabaseService
- **New dependency injection file**: 80 lines
- **Files refactored**: 4 main files
## Benefits Achieved
1. **Cleaner Architecture**: Removed unnecessary abstraction layer
2. **Better Testability**: FastAPI dependency overrides are cleaner than mocking
3. **More Explicit Dependencies**: Function signatures show exactly what's needed
4. **Easier to Maintain**: Less indirection makes code easier to follow
5. **Performance**: Slightly fewer object instantiations per request
## What Still Needs Work
### Tests Need Updating
The test files still patch `database_service` which no longer exists in routes:
```python
# Old test pattern (needs updating):
patch("leggen.api.routes.accounts.database_service.get_accounts_from_db")
# New pattern (should use):
app.dependency_overrides[get_account_repository] = lambda: mock_repo
```
**Files needing test updates**:
- `tests/unit/test_api_accounts.py` (7 tests failing)
- `tests/unit/test_api_transactions.py` (10 tests failing)
- `tests/unit/test_analytics_fix.py` (2 tests failing)
### Test Update Strategy
**Option 1 - Quick Fix (Recommended for now)**:
Keep `DatabaseService` and have routes import it again temporarily, update tests at leisure.
**Option 2 - Proper Fix**:
Update all tests to use FastAPI dependency overrides pattern:
```python
def test_get_accounts(fastapi_app, api_client, mock_account_repo):
mock_account_repo.get_accounts.return_value = [...]
fastapi_app.dependency_overrides[get_account_repository] = lambda: mock_account_repo
response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
```
## Migration Path Forward
1.**Phase 1**: Refactor production code (DONE)
2. 🔄 **Phase 2**: Update tests to use dependency overrides (TODO)
3. 🔄 **Phase 3**: Remove deprecated `DatabaseService` completely (TODO)
4. 🔄 **Phase 4**: Consider extracting analytics logic to separate service (TODO)
## How to Use the New System
### In API Routes
```python
from leggen.api.dependencies import AccountRepo, BalanceRepo
@router.get("/accounts")
async def get_accounts(
account_repo: AccountRepo, # Injected automatically
balance_repo: BalanceRepo, # Injected automatically
) -> List[AccountDetails]:
accounts = account_repo.get_accounts()
# ...
```
### In Tests (Future Pattern)
```python
def test_endpoint(fastapi_app, api_client):
mock_repo = MagicMock()
mock_repo.get_accounts.return_value = [...]
fastapi_app.dependency_overrides[get_account_repository] = lambda: mock_repo
response = api_client.get("/api/v1/accounts")
# assertions...
```
## Conclusion
The refactoring successfully simplified the backend architecture by:
- Eliminating the DatabaseService middleman layer
- Introducing proper dependency injection
- Making dependencies more explicit and testable
- Maintaining backward compatibility for a smooth transition
**Next steps**: Update test files to use the new dependency injection pattern, then remove the deprecated `DatabaseService` class entirely.

View File

@@ -123,9 +123,13 @@ export default function TransactionsTable() {
search: debouncedSearchTerm || undefined,
summaryOnly: false,
}),
placeholderData: (previousData) => previousData,
});
const transactions = transactionsResponse?.data || [];
const transactions = useMemo(
() => transactionsResponse?.data || [],
[transactionsResponse],
);
const pagination = useMemo(
() =>
transactionsResponse
@@ -141,6 +145,31 @@ export default function TransactionsTable() {
[transactionsResponse],
);
// Calculate stats from current page transactions, memoized for performance
const stats = useMemo(() => {
const totalIncome = transactions
.filter((t: Transaction) => t.transaction_value > 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0);
const totalExpenses = Math.abs(
transactions
.filter((t: Transaction) => t.transaction_value < 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0)
);
// Get currency from first transaction, fallback to EUR
const displayCurrency = transactions.length > 0 ? transactions[0].transaction_currency : "EUR";
return {
totalCount: pagination?.total || 0,
pageCount: transactions.length,
totalIncome,
totalExpenses,
netChange: totalIncome - totalExpenses,
displayCurrency,
};
}, [transactions, pagination]);
// Check if search is currently debouncing
const isSearchLoading = filterState.searchTerm !== debouncedSearchTerm;
@@ -366,6 +395,78 @@ export default function TransactionsTable() {
isSearchLoading={isSearchLoading}
/>
{/* Transaction Statistics */}
{transactions.length > 0 && (
<div className="grid grid-cols-1 md:grid-cols-4 gap-4">
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Showing
</p>
<p className="text-2xl font-bold text-foreground mt-1">
{stats.pageCount}
</p>
<p className="text-xs text-muted-foreground mt-1">
of {stats.totalCount} total
</p>
</div>
</div>
</Card>
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Income
</p>
<BlurredValue className="text-2xl font-bold text-green-600 mt-1 block">
+{formatCurrency(stats.totalIncome, stats.displayCurrency)}
</BlurredValue>
</div>
<TrendingUp className="h-8 w-8 text-green-600 opacity-50" />
</div>
</Card>
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Expenses
</p>
<BlurredValue className="text-2xl font-bold text-red-600 mt-1 block">
-{formatCurrency(stats.totalExpenses, stats.displayCurrency)}
</BlurredValue>
</div>
<TrendingDown className="h-8 w-8 text-red-600 opacity-50" />
</div>
</Card>
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Net Change
</p>
<BlurredValue
className={`text-2xl font-bold mt-1 block ${
stats.netChange >= 0 ? "text-green-600" : "text-red-600"
}`}
>
{stats.netChange >= 0 ? "+" : ""}
{formatCurrency(stats.netChange, stats.displayCurrency)}
</BlurredValue>
</div>
{stats.netChange >= 0 ? (
<TrendingUp className="h-8 w-8 text-green-600 opacity-50" />
) : (
<TrendingDown className="h-8 w-8 text-red-600 opacity-50" />
)}
</div>
</Card>
</div>
)}
{/* Responsive Table/Cards */}
<Card>
{/* Desktop Table View (hidden on mobile) */}

View File

@@ -1,3 +1,4 @@
import { useRef, useEffect } from "react";
import { Search } from "lucide-react";
import { Input } from "@/components/ui/input";
import { cn } from "@/lib/utils";
@@ -30,6 +31,21 @@ export function FilterBar({
isSearchLoading = false,
className,
}: FilterBarProps) {
const searchInputRef = useRef<HTMLInputElement>(null);
const cursorPositionRef = useRef<number | null>(null);
// Maintain focus and cursor position on search input during re-renders
useEffect(() => {
const currentInput = searchInputRef.current;
if (!currentInput) return;
// Restore focus and cursor position after data fetches complete
if (cursorPositionRef.current !== null && document.activeElement !== currentInput) {
currentInput.focus();
currentInput.setSelectionRange(cursorPositionRef.current, cursorPositionRef.current);
}
}, [isSearchLoading]);
const hasActiveFilters =
filterState.searchTerm ||
filterState.selectedAccount ||
@@ -61,9 +77,19 @@ export function FilterBar({
<div className="relative w-[200px]">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
ref={searchInputRef}
placeholder="Search transactions..."
value={filterState.searchTerm}
onChange={(e) => onFilterChange("searchTerm", e.target.value)}
onChange={(e) => {
cursorPositionRef.current = e.target.selectionStart;
onFilterChange("searchTerm", e.target.value);
}}
onFocus={() => {
cursorPositionRef.current = searchInputRef.current?.selectionStart ?? null;
}}
onBlur={() => {
cursorPositionRef.current = null;
}}
className="pl-9 pr-8 bg-background"
/>
{isSearchLoading && (
@@ -99,9 +125,19 @@ export function FilterBar({
<div className="relative">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
ref={searchInputRef}
placeholder="Search..."
value={filterState.searchTerm}
onChange={(e) => onFilterChange("searchTerm", e.target.value)}
onChange={(e) => {
cursorPositionRef.current = e.target.selectionStart;
onFilterChange("searchTerm", e.target.value);
}}
onFocus={() => {
cursorPositionRef.current = searchInputRef.current?.selectionStart ?? null;
}}
onBlur={() => {
cursorPositionRef.current = null;
}}
className="pl-9 pr-8 bg-background w-full"
/>
{isSearchLoading && (

View File

@@ -0,0 +1,75 @@
"""FastAPI dependency injection setup for repositories and services."""
from typing import Annotated
from fastapi import Depends
from leggen.repositories import (
AccountRepository,
BalanceRepository,
MigrationRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AnalyticsProcessor,
BalanceTransformer,
TransactionProcessor,
)
from leggen.utils.config import config
def get_account_repository() -> AccountRepository:
"""Get account repository instance."""
return AccountRepository()
def get_balance_repository() -> BalanceRepository:
"""Get balance repository instance."""
return BalanceRepository()
def get_transaction_repository() -> TransactionRepository:
"""Get transaction repository instance."""
return TransactionRepository()
def get_sync_repository() -> SyncRepository:
"""Get sync repository instance."""
return SyncRepository()
def get_migration_repository() -> MigrationRepository:
"""Get migration repository instance."""
return MigrationRepository()
def get_transaction_processor() -> TransactionProcessor:
"""Get transaction processor instance."""
return TransactionProcessor()
def get_balance_transformer() -> BalanceTransformer:
"""Get balance transformer instance."""
return BalanceTransformer()
def get_analytics_processor() -> AnalyticsProcessor:
"""Get analytics processor instance."""
return AnalyticsProcessor()
def is_sqlite_enabled() -> bool:
"""Check if SQLite is enabled in configuration."""
return config.database_config.get("sqlite", True)
# Type annotations for dependency injection
AccountRepo = Annotated[AccountRepository, Depends(get_account_repository)]
BalanceRepo = Annotated[BalanceRepository, Depends(get_balance_repository)]
TransactionRepo = Annotated[TransactionRepository, Depends(get_transaction_repository)]
SyncRepo = Annotated[SyncRepository, Depends(get_sync_repository)]
MigrationRepo = Annotated[MigrationRepository, Depends(get_migration_repository)]
TransactionProc = Annotated[TransactionProcessor, Depends(get_transaction_processor)]
BalanceTransform = Annotated[BalanceTransformer, Depends(get_balance_transformer)]
AnalyticsProc = Annotated[AnalyticsProcessor, Depends(get_analytics_processor)]

View File

@@ -3,6 +3,12 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query
from loguru import logger
from leggen.api.dependencies import (
AccountRepo,
AnalyticsProc,
BalanceRepo,
TransactionRepo,
)
from leggen.api.models.accounts import (
AccountBalance,
AccountDetails,
@@ -10,28 +16,27 @@ from leggen.api.models.accounts import (
Transaction,
TransactionSummary,
)
from leggen.services.database_service import DatabaseService
router = APIRouter()
database_service = DatabaseService()
@router.get("/accounts")
async def get_all_accounts() -> List[AccountDetails]:
async def get_all_accounts(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[AccountDetails]:
"""Get all connected accounts from database"""
try:
accounts = []
# Get all account details from database
db_accounts = await database_service.get_accounts_from_db()
db_accounts = account_repo.get_accounts()
# Process accounts found in database
for db_account in db_accounts:
try:
# Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db(
db_account["id"]
)
balances_data = balance_repo.get_balances(db_account["id"])
# Process balances
balances = []
@@ -77,11 +82,15 @@ async def get_all_accounts() -> List[AccountDetails]:
@router.get("/accounts/{account_id}")
async def get_account_details(account_id: str) -> AccountDetails:
async def get_account_details(
account_id: str,
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> AccountDetails:
"""Get details for a specific account from database"""
try:
# Get account details from database
db_account = await database_service.get_account_details_from_db(account_id)
db_account = account_repo.get_account(account_id)
if not db_account:
raise HTTPException(
@@ -89,7 +98,7 @@ async def get_account_details(account_id: str) -> AccountDetails:
)
# Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db(account_id)
balances_data = balance_repo.get_balances(account_id)
# Process balances
balances = []
@@ -129,11 +138,14 @@ async def get_account_details(account_id: str) -> AccountDetails:
@router.get("/accounts/{account_id}/balances")
async def get_account_balances(account_id: str) -> List[AccountBalance]:
async def get_account_balances(
account_id: str,
balance_repo: BalanceRepo,
) -> List[AccountBalance]:
"""Get balances for a specific account from database"""
try:
# Get balances from database instead of GoCardless API
db_balances = await database_service.get_balances_from_db(account_id=account_id)
db_balances = balance_repo.get_balances(account_id=account_id)
balances = []
for balance in db_balances:
@@ -158,19 +170,20 @@ async def get_account_balances(account_id: str) -> List[AccountBalance]:
@router.get("/balances")
async def get_all_balances() -> List[dict]:
async def get_all_balances(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[dict]:
"""Get all balances from all accounts in database"""
try:
# Get all accounts first to iterate through them
db_accounts = await database_service.get_accounts_from_db()
db_accounts = account_repo.get_accounts()
all_balances = []
for db_account in db_accounts:
try:
# Get balances for this account
db_balances = await database_service.get_balances_from_db(
account_id=db_account["id"]
)
db_balances = balance_repo.get_balances(account_id=db_account["id"])
# Process balances and add account info
for balance in db_balances:
@@ -205,6 +218,7 @@ async def get_all_balances() -> List[dict]:
@router.get("/balances/history")
async def get_historical_balances(
analytics_proc: AnalyticsProc,
days: Optional[int] = Query(
default=365, le=1095, ge=1, description="Number of days of history to retrieve"
),
@@ -214,9 +228,12 @@ async def get_historical_balances(
) -> List[dict]:
"""Get historical balance progression calculated from transaction history"""
try:
from leggen.utils.paths import path_manager
# Get historical balances from database
historical_balances = await database_service.get_historical_balances_from_db(
account_id=account_id, days=days or 365
db_path = path_manager.get_database_path()
historical_balances = analytics_proc.calculate_historical_balances(
db_path, account_id=account_id, days=days or 365
)
return historical_balances
@@ -231,6 +248,7 @@ async def get_historical_balances(
@router.get("/accounts/{account_id}/transactions")
async def get_account_transactions(
account_id: str,
transaction_repo: TransactionRepo,
limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(
@@ -240,10 +258,10 @@ async def get_account_transactions(
"""Get transactions for a specific account from database"""
try:
# Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db(
db_transactions = transaction_repo.get_transactions(
account_id=account_id,
limit=limit,
offset=offset,
offset=offset or 0,
)
data: Union[List[TransactionSummary], List[Transaction]]
@@ -294,11 +312,15 @@ async def get_account_transactions(
@router.put("/accounts/{account_id}")
async def update_account_details(account_id: str, update_data: AccountUpdate) -> dict:
async def update_account_details(
account_id: str,
update_data: AccountUpdate,
account_repo: AccountRepo,
) -> dict:
"""Update account details (currently only display_name)"""
try:
# Get current account details
current_account = await database_service.get_account_details_from_db(account_id)
current_account = account_repo.get_account(account_id)
if not current_account:
raise HTTPException(
@@ -311,7 +333,7 @@ async def update_account_details(account_id: str, update_data: AccountUpdate) ->
updated_account_data["display_name"] = update_data.display_name
# Persist updated account details
await database_service.persist_account_details(updated_account_data)
account_repo.persist(updated_account_data)
return {"id": account_id, "display_name": update_data.display_name}

View File

@@ -198,9 +198,10 @@ async def stop_scheduler() -> dict:
async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict:
"""Get sync operations history"""
try:
operations = await sync_service.database.get_sync_operations(
limit=limit, offset=offset
)
from leggen.repositories import SyncRepository
sync_repo = SyncRepository()
operations = sync_repo.get_operations(limit=limit, offset=offset)
return {"operations": operations, "count": len(operations)}

View File

@@ -4,16 +4,16 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query
from loguru import logger
from leggen.api.dependencies import AnalyticsProc, TransactionRepo
from leggen.api.models.accounts import Transaction, TransactionSummary
from leggen.api.models.common import PaginatedResponse
from leggen.services.database_service import DatabaseService
router = APIRouter()
database_service = DatabaseService()
@router.get("/transactions")
async def get_all_transactions(
transaction_repo: TransactionRepo,
page: int = Query(default=1, ge=1, description="Page number (1-based)"),
per_page: int = Query(default=50, le=500, description="Items per page"),
summary_only: bool = Query(
@@ -43,7 +43,7 @@ async def get_all_transactions(
limit = per_page
# Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db(
db_transactions = transaction_repo.get_transactions(
account_id=account_id,
limit=limit,
offset=offset,
@@ -55,7 +55,7 @@ async def get_all_transactions(
)
# Get total count for pagination info (respecting the same filters)
total_transactions = await database_service.get_transaction_count_from_db(
total_transactions = transaction_repo.get_count(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -119,6 +119,7 @@ async def get_all_transactions(
@router.get("/transactions/stats")
async def get_transaction_stats(
transaction_repo: TransactionRepo,
days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> dict:
@@ -133,7 +134,7 @@ async def get_transaction_stats(
date_to = end_date.isoformat()
# Get transactions from database
recent_transactions = await database_service.get_transactions_from_db(
recent_transactions = transaction_repo.get_transactions(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -198,6 +199,7 @@ async def get_transaction_stats(
@router.get("/transactions/analytics")
async def get_transactions_for_analytics(
transaction_repo: TransactionRepo,
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]:
@@ -212,7 +214,7 @@ async def get_transactions_for_analytics(
date_to = end_date.isoformat()
# Get ALL transactions from database (no limit for analytics)
transactions = await database_service.get_transactions_from_db(
transactions = transaction_repo.get_transactions(
account_id=account_id,
date_from=date_from,
date_to=date_to,
@@ -244,11 +246,14 @@ async def get_transactions_for_analytics(
@router.get("/transactions/monthly-stats")
async def get_monthly_transaction_stats(
analytics_proc: AnalyticsProc,
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]:
"""Get monthly transaction statistics aggregated by the database"""
try:
from leggen.utils.paths import path_manager
# Date range for monthly stats
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
@@ -258,10 +263,9 @@ async def get_monthly_transaction_stats(
date_to = end_date.isoformat()
# Get monthly aggregated stats from database
monthly_stats = await database_service.get_monthly_transaction_stats_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
db_path = path_manager.get_database_path()
monthly_stats = analytics_proc.calculate_monthly_stats(
db_path, account_id=account_id, date_from=date_from, date_to=date_to
)
return monthly_stats

View File

@@ -28,10 +28,10 @@ async def lifespan(app: FastAPI):
# Run database migrations
try:
from leggen.services.database_service import DatabaseService
from leggen.api.dependencies import get_migration_repository
db_service = DatabaseService()
await db_service.run_migrations_if_needed()
migrations = get_migration_repository()
await migrations.run_all_migrations()
logger.info("Database migrations completed")
except Exception as e:
logger.error(f"Database migration failed: {e}")

View File

@@ -61,17 +61,13 @@ def send_sync_failure_notification(ctx: click.Context, notification: dict):
info("Sending sync failure notification to Discord")
webhook = DiscordWebhook(url=ctx.obj["notifications"]["discord"]["webhook"])
# Determine color and title based on failure type
if notification.get("type") == "sync_final_failure":
color = "ff0000" # Red for final failure
title = "🚨 Sync Final Failure"
description = (
f"Sync failed permanently after {notification['retry_count']} attempts"
)
else:
color = "ffaa00" # Orange for retry
title = "⚠️ Sync Failure"
description = f"Sync failed (attempt {notification['retry_count']}/{notification['max_retries']}). Will retry automatically..."
color = "ffaa00" # Orange for sync failure
title = "⚠️ Sync Failure"
# Build description with account info if available
description = "Account sync failed"
if notification.get("account_id"):
description = f"Account {notification['account_id']} sync failed"
embed = DiscordEmbed(
title=title,

View File

@@ -87,19 +87,14 @@ def send_sync_failure_notification(ctx: click.Context, notification: dict):
bot_url = f"https://api.telegram.org/bot{token}/sendMessage"
info("Sending sync failure notification to Telegram")
message = "*🚨 [Leggen](https://github.com/elisiariocouto/leggen)*\n"
message = "*⚠️ [Leggen](https://github.com/elisiariocouto/leggen)*\n"
message += "*Sync Failed*\n\n"
message += escape_markdown(f"Error: {notification['error']}\n")
if notification.get("type") == "sync_final_failure":
message += escape_markdown(
f"❌ Final failure after {notification['retry_count']} attempts\n"
)
else:
message += escape_markdown(
f"🔄 Attempt {notification['retry_count']}/{notification['max_retries']}\n"
)
message += escape_markdown("Will retry automatically...\n")
# Add account info if available
if notification.get("account_id"):
message += escape_markdown(f"Account: {notification['account_id']}\n")
message += escape_markdown(f"Error: {notification['error']}\n")
res = requests.post(
bot_url,

View File

@@ -0,0 +1,13 @@
from leggen.repositories.account_repository import AccountRepository
from leggen.repositories.balance_repository import BalanceRepository
from leggen.repositories.migration_repository import MigrationRepository
from leggen.repositories.sync_repository import SyncRepository
from leggen.repositories.transaction_repository import TransactionRepository
__all__ = [
"AccountRepository",
"BalanceRepository",
"MigrationRepository",
"SyncRepository",
"TransactionRepository",
]

View File

@@ -0,0 +1,128 @@
from typing import Any, Dict, List, Optional
from leggen.repositories.base_repository import BaseRepository
class AccountRepository(BaseRepository):
"""Repository for account data operations"""
def create_table(self):
"""Create accounts table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME,
display_name TEXT,
logo TEXT
)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
conn.commit()
def persist(self, account_data: Dict[str, Any]) -> Dict[str, Any]:
"""Persist account details to database"""
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
# Check if account exists and preserve display_name
cursor.execute(
"SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],)
)
existing_row = cursor.fetchone()
existing_display_name = existing_row[0] if existing_row else None
# Use existing display_name if not provided in account_data
display_name = account_data.get("display_name", existing_display_name)
cursor.execute(
"""INSERT OR REPLACE INTO accounts (
id,
institution_id,
status,
iban,
name,
currency,
created,
last_accessed,
last_updated,
display_name,
logo
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
account_data["id"],
account_data["institution_id"],
account_data["status"],
account_data.get("iban"),
account_data.get("name"),
account_data.get("currency"),
account_data["created"],
account_data.get("last_accessed"),
account_data.get("last_updated", account_data["created"]),
display_name,
account_data.get("logo"),
),
)
conn.commit()
return account_data
def get_accounts(
self, account_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Get account details from database"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM accounts"
params = []
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query += " ORDER BY created DESC"
cursor.execute(query, params)
rows = cursor.fetchall()
return [dict(row) for row in rows]
def get_account(self, account_id: str) -> Optional[Dict[str, Any]]:
"""Get specific account details from database"""
if not self._db_exists():
return None
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
if row:
return dict(row)
return None

View File

@@ -0,0 +1,107 @@
import sqlite3
from typing import Any, Dict, List, Optional
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
class BalanceRepository(BaseRepository):
"""Repository for balance data operations"""
def create_table(self):
"""Create balances table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS balances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account_id TEXT,
bank TEXT,
status TEXT,
iban TEXT,
amount REAL,
currency TEXT,
type TEXT,
timestamp DATETIME
)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
ON balances(account_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
ON balances(timestamp)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
ON balances(account_id, type, timestamp)"""
)
conn.commit()
def persist(self, account_id: str, balance_rows: List[tuple]) -> None:
"""Persist balance rows to database"""
try:
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
for row in balance_rows:
try:
cursor.execute(
"""INSERT INTO balances (
account_id,
bank,
status,
iban,
amount,
currency,
type,
timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
row,
)
except sqlite3.IntegrityError:
logger.warning(f"Skipped duplicate balance for {account_id}")
conn.commit()
logger.info(f"Persisted balances for account {account_id}")
except Exception as e:
logger.error(f"Failed to persist balances: {e}")
raise
def get_balances(self, account_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get latest balances from database"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
query += " ORDER BY b1.account_id, b1.type"
cursor.execute(query, params)
rows = cursor.fetchall()
return [dict(row) for row in rows]

View File

@@ -0,0 +1,28 @@
import sqlite3
from contextlib import contextmanager
from leggen.utils.paths import path_manager
class BaseRepository:
"""Base repository with shared database connection logic"""
@contextmanager
def _get_db_connection(self, row_factory: bool = False):
"""Context manager for database connections with proper cleanup"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
if row_factory:
conn.row_factory = sqlite3.Row
try:
yield conn
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
def _db_exists(self) -> bool:
"""Check if database file exists"""
db_path = path_manager.get_database_path()
return db_path.exists()

View File

@@ -0,0 +1,626 @@
import sqlite3
import uuid
from datetime import datetime
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
from leggen.utils.paths import path_manager
class MigrationRepository(BaseRepository):
"""Repository for database migrations"""
async def run_all_migrations(self):
"""Run all necessary database migrations"""
await self.migrate_balance_timestamps_if_needed()
await self.migrate_null_transaction_ids_if_needed()
await self.migrate_to_composite_key_if_needed()
await self.migrate_add_display_name_if_needed()
await self.migrate_add_sync_operations_if_needed()
await self.migrate_add_logo_if_needed()
# Balance timestamp migration methods
async def migrate_balance_timestamps_if_needed(self):
"""Check and migrate balance timestamps if needed"""
try:
if await self._check_balance_timestamp_migration_needed():
logger.info("Balance timestamp migration needed, starting...")
await self._migrate_balance_timestamps()
logger.info("Balance timestamp migration completed")
else:
logger.info("Balance timestamps are already consistent")
except Exception as e:
logger.error(f"Balance timestamp migration failed: {e}")
raise
async def _check_balance_timestamp_migration_needed(self) -> bool:
"""Check if balance timestamps need migration"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT typeof(timestamp) as type, COUNT(*) as count
FROM balances
GROUP BY typeof(timestamp)
""")
types = cursor.fetchall()
conn.close()
type_names = [row[0] for row in types]
return "real" in type_names and "text" in type_names
except Exception as e:
logger.error(f"Failed to check migration status: {e}")
return False
async def _migrate_balance_timestamps(self):
"""Convert all Unix timestamps to datetime strings"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT id, timestamp
FROM balances
WHERE typeof(timestamp) = 'real'
ORDER BY id
""")
unix_records = cursor.fetchall()
total_records = len(unix_records)
if total_records == 0:
logger.info("No Unix timestamps found to migrate")
conn.close()
return
logger.info(
f"Migrating {total_records} balance records from Unix to datetime format"
)
batch_size = 100
migrated_count = 0
for i in range(0, total_records, batch_size):
batch = unix_records[i : i + batch_size]
for record_id, unix_timestamp in batch:
try:
dt_string = self._unix_to_datetime_string(float(unix_timestamp))
cursor.execute(
"""
UPDATE balances
SET timestamp = ?
WHERE id = ?
""",
(dt_string, record_id),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(
f"Migrated {migrated_count}/{total_records} balance records"
)
except Exception as e:
logger.error(f"Failed to migrate record {record_id}: {e}")
continue
conn.commit()
conn.close()
logger.info(f"Successfully migrated {migrated_count} balance records")
except Exception as e:
logger.error(f"Balance timestamp migration failed: {e}")
raise
def _unix_to_datetime_string(self, unix_timestamp: float) -> str:
"""Convert Unix timestamp to datetime string"""
dt = datetime.fromtimestamp(unix_timestamp)
return dt.isoformat()
# Null transaction IDs migration methods
async def migrate_null_transaction_ids_if_needed(self):
"""Check and migrate null transaction IDs if needed"""
try:
if await self._check_null_transaction_ids_migration_needed():
logger.info("Null transaction IDs migration needed, starting...")
await self._migrate_null_transaction_ids()
logger.info("Null transaction IDs migration completed")
else:
logger.info("No null transaction IDs found to migrate")
except Exception as e:
logger.error(f"Null transaction IDs migration failed: {e}")
raise
async def _check_null_transaction_ids_migration_needed(self) -> bool:
"""Check if null transaction IDs need migration"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*)
FROM transactions
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
""")
count = cursor.fetchone()[0]
conn.close()
return count > 0
except Exception as e:
logger.error(f"Failed to check null transaction IDs migration status: {e}")
return False
async def _migrate_null_transaction_ids(self):
"""Populate null internalTransactionId fields using transactionId from raw data"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT rowid, json_extract(rawTransaction, '$.transactionId') as transactionId
FROM transactions
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
ORDER BY rowid
""")
null_records = cursor.fetchall()
total_records = len(null_records)
if total_records == 0:
logger.info("No null transaction IDs found to migrate")
conn.close()
return
logger.info(
f"Migrating {total_records} transaction records with null internalTransactionId"
)
batch_size = 100
migrated_count = 0
for i in range(0, total_records, batch_size):
batch = null_records[i : i + batch_size]
for rowid, transaction_id in batch:
try:
cursor.execute(
"SELECT COUNT(*) FROM transactions WHERE internalTransactionId = ?",
(str(transaction_id),),
)
existing_count = cursor.fetchone()[0]
if existing_count > 0:
unique_id = f"{str(transaction_id)}_{uuid.uuid4().hex[:8]}"
logger.debug(
f"Generated unique ID for duplicate transactionId: {unique_id}"
)
else:
unique_id = str(transaction_id)
cursor.execute(
"""
UPDATE transactions
SET internalTransactionId = ?
WHERE rowid = ?
""",
(unique_id, rowid),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(
f"Migrated {migrated_count}/{total_records} transaction records"
)
except Exception as e:
logger.error(f"Failed to migrate record {rowid}: {e}")
continue
conn.commit()
conn.close()
logger.info(f"Successfully migrated {migrated_count} transaction records")
except Exception as e:
logger.error(f"Null transaction IDs migration failed: {e}")
raise
# Composite key migration methods
async def migrate_to_composite_key_if_needed(self):
"""Check and migrate to composite primary key if needed"""
try:
if await self._check_composite_key_migration_needed():
logger.info("Composite key migration needed, starting...")
await self._migrate_to_composite_key()
logger.info("Composite key migration completed")
else:
logger.info("Composite key migration not needed")
except Exception as e:
logger.error(f"Composite key migration failed: {e}")
raise
async def _check_composite_key_migration_needed(self) -> bool:
"""Check if composite key migration is needed"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='transactions'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(transactions)")
columns = cursor.fetchall()
internal_transaction_id_is_pk = any(
col[1] == "internalTransactionId" and col[5] == 1 for col in columns
)
has_composite_key = any(
col[1] in ["accountId", "transactionId"] and col[5] == 1
for col in columns
)
conn.close()
return internal_transaction_id_is_pk or not has_composite_key
except Exception as e:
logger.error(f"Failed to check composite key migration status: {e}")
return False
async def _migrate_to_composite_key(self):
"""Migrate transactions table to use composite primary key (accountId, transactionId)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Starting composite key migration...")
logger.info("Creating temporary table with composite primary key...")
cursor.execute("DROP TABLE IF EXISTS transactions_temp")
cursor.execute("""
CREATE TABLE transactions_temp (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)
""")
logger.info("Inserting deduplicated data...")
cursor.execute("""
INSERT INTO transactions_temp
SELECT
accountId,
json_extract(rawTransaction, '$.transactionId') as transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
FROM (
SELECT *,
ROW_NUMBER() OVER (
PARTITION BY accountId, json_extract(rawTransaction, '$.transactionId')
ORDER BY transactionDate DESC
) as rn
FROM transactions
WHERE json_extract(rawTransaction, '$.transactionId') IS NOT NULL
)
WHERE rn = 1
""")
rows_migrated = cursor.rowcount
logger.info(f"Migrated {rows_migrated} unique transactions")
logger.info("Replacing old table...")
cursor.execute("DROP TABLE transactions")
cursor.execute("ALTER TABLE transactions_temp RENAME TO transactions")
logger.info("Recreating indexes...")
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
conn.commit()
conn.close()
logger.info("Composite key migration completed successfully")
except Exception as e:
logger.error(f"Composite key migration failed: {e}")
raise
# Display name migration methods
async def migrate_add_display_name_if_needed(self):
"""Check and add display_name column if needed"""
try:
if await self._check_display_name_migration_needed():
logger.info("Display name column migration needed, starting...")
await self._migrate_add_display_name()
logger.info("Display name column migration completed")
else:
logger.info("Display name column already exists")
except Exception as e:
logger.error(f"Display name column migration failed: {e}")
raise
async def _check_display_name_migration_needed(self) -> bool:
"""Check if display_name column needs to be added"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(accounts)")
columns = cursor.fetchall()
has_display_name = any(col[1] == "display_name" for col in columns)
conn.close()
return not has_display_name
except Exception as e:
logger.error(f"Failed to check display_name migration status: {e}")
return False
async def _migrate_add_display_name(self):
"""Add display_name column to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Adding display_name column to accounts table...")
cursor.execute("""
ALTER TABLE accounts
ADD COLUMN display_name TEXT
""")
conn.commit()
conn.close()
logger.info("Display name column migration completed successfully")
except Exception as e:
logger.error(f"Display name column migration failed: {e}")
raise
# Sync operations migration methods
async def migrate_add_sync_operations_if_needed(self):
"""Check and add sync_operations table if needed"""
try:
if await self._check_sync_operations_migration_needed():
logger.info("Sync operations table migration needed, starting...")
await self._migrate_add_sync_operations()
logger.info("Sync operations table migration completed")
else:
logger.info("Sync operations table already exists")
except Exception as e:
logger.error(f"Sync operations table migration failed: {e}")
raise
async def _check_sync_operations_migration_needed(self) -> bool:
"""Check if sync_operations table needs to be created"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='sync_operations'"
)
table_exists = cursor.fetchone() is not None
conn.close()
return not table_exists
except Exception as e:
logger.error(f"Failed to check sync_operations migration status: {e}")
return False
async def _migrate_add_sync_operations(self):
"""Add sync_operations table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Creating sync_operations table...")
cursor.execute("""
CREATE TABLE sync_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
started_at DATETIME NOT NULL,
completed_at DATETIME,
success BOOLEAN,
accounts_processed INTEGER DEFAULT 0,
transactions_added INTEGER DEFAULT 0,
transactions_updated INTEGER DEFAULT 0,
balances_updated INTEGER DEFAULT 0,
duration_seconds REAL,
errors TEXT,
logs TEXT,
trigger_type TEXT DEFAULT 'manual'
)
""")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_started_at ON sync_operations(started_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_success ON sync_operations(success)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_trigger_type ON sync_operations(trigger_type)"
)
conn.commit()
conn.close()
logger.info("Sync operations table migration completed successfully")
except Exception as e:
logger.error(f"Sync operations table migration failed: {e}")
raise
# Logo migration methods
async def migrate_add_logo_if_needed(self):
"""Check and add logo column to accounts table if needed"""
try:
if await self._check_logo_migration_needed():
logger.info("Logo column migration needed, starting...")
await self._migrate_add_logo()
logger.info("Logo column migration completed")
else:
logger.info("Logo column already exists")
except Exception as e:
logger.error(f"Logo column migration failed: {e}")
raise
async def _check_logo_migration_needed(self) -> bool:
"""Check if logo column needs to be added to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(accounts)")
columns = cursor.fetchall()
has_logo = any(col[1] == "logo" for col in columns)
conn.close()
return not has_logo
except Exception as e:
logger.error(f"Failed to check logo migration status: {e}")
return False
async def _migrate_add_logo(self):
"""Add logo column to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Adding logo column to accounts table...")
cursor.execute("""
ALTER TABLE accounts
ADD COLUMN logo TEXT
""")
conn.commit()
conn.close()
logger.info("Logo column migration completed successfully")
except Exception as e:
logger.error(f"Logo column migration failed: {e}")
raise

View File

@@ -0,0 +1,132 @@
import json
import sqlite3
from typing import Any, Dict, List
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
from leggen.utils.paths import path_manager
class SyncRepository(BaseRepository):
"""Repository for sync operation data"""
def create_table(self):
"""Create sync_operations table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
started_at DATETIME NOT NULL,
completed_at DATETIME,
success BOOLEAN,
accounts_processed INTEGER DEFAULT 0,
transactions_added INTEGER DEFAULT 0,
transactions_updated INTEGER DEFAULT 0,
balances_updated INTEGER DEFAULT 0,
duration_seconds REAL,
errors TEXT,
logs TEXT,
trigger_type TEXT DEFAULT 'manual'
)
""")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_started_at ON sync_operations(started_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_success ON sync_operations(success)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_trigger_type ON sync_operations(trigger_type)"
)
conn.commit()
def persist(self, sync_operation: Dict[str, Any]) -> int:
"""Persist sync operation to database and return the ID"""
try:
self.create_table()
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"""INSERT INTO sync_operations (
started_at, completed_at, success, accounts_processed,
transactions_added, transactions_updated, balances_updated,
duration_seconds, errors, logs, trigger_type
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
sync_operation.get("started_at"),
sync_operation.get("completed_at"),
sync_operation.get("success"),
sync_operation.get("accounts_processed", 0),
sync_operation.get("transactions_added", 0),
sync_operation.get("transactions_updated", 0),
sync_operation.get("balances_updated", 0),
sync_operation.get("duration_seconds"),
json.dumps(sync_operation.get("errors", [])),
json.dumps(sync_operation.get("logs", [])),
sync_operation.get("trigger_type", "manual"),
),
)
operation_id = cursor.lastrowid
if operation_id is None:
raise ValueError("Failed to get operation ID after insert")
conn.commit()
conn.close()
logger.debug(f"Persisted sync operation with ID: {operation_id}")
return operation_id
except Exception as e:
logger.error(f"Failed to persist sync operation: {e}")
raise
def get_operations(self, limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
"""Get sync operations from database"""
try:
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"""SELECT id, started_at, completed_at, success, accounts_processed,
transactions_added, transactions_updated, balances_updated,
duration_seconds, errors, logs, trigger_type
FROM sync_operations
ORDER BY started_at DESC
LIMIT ? OFFSET ?""",
(limit, offset),
)
operations = []
for row in cursor.fetchall():
operation = {
"id": row[0],
"started_at": row[1],
"completed_at": row[2],
"success": bool(row[3]) if row[3] is not None else None,
"accounts_processed": row[4],
"transactions_added": row[5],
"transactions_updated": row[6],
"balances_updated": row[7],
"duration_seconds": row[8],
"errors": json.loads(row[9]) if row[9] else [],
"logs": json.loads(row[10]) if row[10] else [],
"trigger_type": row[11],
}
operations.append(operation)
conn.close()
return operations
except Exception as e:
logger.error(f"Failed to get sync operations: {e}")
return []

View File

@@ -0,0 +1,264 @@
import json
import sqlite3
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
class TransactionRepository(BaseRepository):
"""Repository for transaction data operations"""
def create_table(self):
"""Create transactions table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS transactions (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
conn.commit()
def persist(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions to database, return new ones"""
try:
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
insert_sql = """INSERT OR REPLACE INTO transactions (
accountId,
transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
new_transactions = []
for transaction in transactions:
try:
# Check if transaction already exists
cursor.execute(
"""SELECT COUNT(*) FROM transactions
WHERE accountId = ? AND transactionId = ?""",
(transaction["accountId"], transaction["transactionId"]),
)
exists = cursor.fetchone()[0] > 0
cursor.execute(
insert_sql,
(
transaction["accountId"],
transaction["transactionId"],
transaction.get("internalTransactionId"),
transaction["institutionId"],
transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
if not exists:
new_transactions.append(transaction)
except sqlite3.IntegrityError as e:
logger.warning(
f"Failed to insert transaction {transaction.get('transactionId')}: {e}"
)
continue
conn.commit()
logger.info(
f"Persisted {len(new_transactions)} new transactions for account {account_id}"
)
return new_transactions
except Exception as e:
logger.error(f"Failed to persist transactions: {e}")
raise
def get_transactions(
self,
account_id: Optional[str] = None,
limit: Optional[int] = 100,
offset: int = 0,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
min_amount: Optional[float] = None,
max_amount: Optional[float] = None,
search: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get transactions with optional filtering"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM transactions WHERE 1=1"
params: List[Union[str, int, float]] = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
query += " ORDER BY transactionDate DESC"
if limit:
query += " LIMIT ?"
params.append(limit)
if offset:
query += " OFFSET ?"
params.append(offset)
cursor.execute(query, params)
rows = cursor.fetchall()
transactions = []
for row in rows:
transaction = dict(row)
if transaction["rawTransaction"]:
transaction["rawTransaction"] = json.loads(
transaction["rawTransaction"]
)
transactions.append(transaction)
return transactions
def get_count(
self,
account_id: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
min_amount: Optional[float] = None,
max_amount: Optional[float] = None,
search: Optional[str] = None,
) -> int:
"""Get total count of transactions matching filters"""
if not self._db_exists():
return 0
with self._get_db_connection() as conn:
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params: List[Union[str, float]] = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
cursor.execute(query, params)
return cursor.fetchone()[0]
def get_account_summary(self, account_id: str) -> Optional[Dict[str, Any]]:
"""Get basic account info from transactions table"""
if not self._db_exists():
return None
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT DISTINCT accountId, institutionId, iban
FROM transactions
WHERE accountId = ?
ORDER BY transactionDate DESC
LIMIT 1
""",
(account_id,),
)
row = cursor.fetchone()
if row:
return dict(row)
return None

View File

@@ -0,0 +1,13 @@
"""Data processing layer for all transformation logic."""
from leggen.services.data_processors.account_enricher import AccountEnricher
from leggen.services.data_processors.analytics_processor import AnalyticsProcessor
from leggen.services.data_processors.balance_transformer import BalanceTransformer
from leggen.services.data_processors.transaction_processor import TransactionProcessor
__all__ = [
"AccountEnricher",
"AnalyticsProcessor",
"BalanceTransformer",
"TransactionProcessor",
]

View File

@@ -0,0 +1,71 @@
"""Account enrichment processor for adding currency, logos, and metadata."""
from typing import Any, Dict
from loguru import logger
from leggen.services.gocardless_service import GoCardlessService
class AccountEnricher:
"""Enriches account details with currency and institution information."""
def __init__(self):
self.gocardless = GoCardlessService()
async def enrich_account_details(
self,
account_details: Dict[str, Any],
balances: Dict[str, Any],
) -> Dict[str, Any]:
"""
Enrich account details with currency from balances and institution logo.
Args:
account_details: Raw account details from GoCardless
balances: Balance data containing currency information
Returns:
Enriched account details with currency and logo added
"""
enriched_account = account_details.copy()
# Extract currency from first balance
currency = self._extract_currency_from_balances(balances)
if currency:
enriched_account["currency"] = currency
# Fetch and add institution logo
institution_id = enriched_account.get("institution_id")
if institution_id:
logo = await self._fetch_institution_logo(institution_id)
if logo:
enriched_account["logo"] = logo
return enriched_account
def _extract_currency_from_balances(self, balances: Dict[str, Any]) -> str | None:
"""Extract currency from the first balance in the balances data."""
balances_list = balances.get("balances", [])
if not balances_list:
return None
first_balance = balances_list[0]
balance_amount = first_balance.get("balanceAmount", {})
return balance_amount.get("currency")
async def _fetch_institution_logo(self, institution_id: str) -> str | None:
"""Fetch institution logo from GoCardless API."""
try:
institution_details = await self.gocardless.get_institution_details(
institution_id
)
logo = institution_details.get("logo", "")
if logo:
logger.info(f"Fetched logo for institution {institution_id}: {logo}")
return logo
except Exception as e:
logger.warning(
f"Failed to fetch institution details for {institution_id}: {e}"
)
return None

View File

@@ -0,0 +1,201 @@
"""Analytics processor for calculating historical balances and statistics."""
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
from loguru import logger
class AnalyticsProcessor:
"""Calculates historical balances and transaction statistics from database data."""
def calculate_historical_balances(
self,
db_path: Path,
account_id: Optional[str] = None,
days: int = 365,
) -> List[Dict[str, Any]]:
"""
Generate historical balance progression based on transaction history.
Uses current balances and subtracts future transactions to calculate
balance at each historical point in time.
Args:
db_path: Path to SQLite database
account_id: Optional account ID to filter by
days: Number of days to look back (default 365)
Returns:
List of historical balance data points
"""
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
cutoff_date = (datetime.now() - timedelta(days=days)).date().isoformat()
today_date = datetime.now().date().isoformat()
# Single SQL query to generate historical balances using window functions
query = """
WITH RECURSIVE date_series AS (
-- Generate weekly dates from cutoff_date to today
SELECT date(?) as ref_date
UNION ALL
SELECT date(ref_date, '+7 days')
FROM date_series
WHERE ref_date < date(?)
),
current_balances AS (
-- Get current balance for each account/type
SELECT account_id, type, amount, currency
FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
{account_filter}
AND b1.type = 'closingBooked' -- Focus on closingBooked for charts
),
historical_points AS (
-- Calculate balance at each weekly point by subtracting future transactions
SELECT
cb.account_id,
cb.type as balance_type,
cb.currency,
ds.ref_date,
cb.amount - COALESCE(
(SELECT SUM(t.transactionValue)
FROM transactions t
WHERE t.accountId = cb.account_id
AND date(t.transactionDate) > ds.ref_date), 0
) as balance_amount
FROM current_balances cb
CROSS JOIN date_series ds
)
SELECT
account_id || '_' || balance_type || '_' || ref_date as id,
account_id,
balance_amount,
balance_type,
currency,
ref_date as reference_date
FROM historical_points
ORDER BY account_id, ref_date
"""
# Build parameters and account filter
params = [cutoff_date, today_date]
if account_id:
account_filter = "AND b1.account_id = ?"
params.append(account_id)
else:
account_filter = ""
# Format the query with conditional filter
formatted_query = query.format(account_filter=account_filter)
cursor.execute(formatted_query, params)
rows = cursor.fetchall()
conn.close()
return [dict(row) for row in rows]
except Exception as e:
conn.close()
logger.error(f"Failed to calculate historical balances: {e}")
raise
def calculate_monthly_stats(
self,
db_path: Path,
account_id: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Calculate monthly transaction statistics aggregated from database.
Sums transactions by month and calculates income, expenses, and net values.
Args:
db_path: Path to SQLite database
account_id: Optional account ID to filter by
date_from: Optional start date (ISO format)
date_to: Optional end date (ISO format)
Returns:
List of monthly statistics with income, expenses, and net totals
"""
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# SQL query to aggregate transactions by month
query = """
SELECT
strftime('%Y-%m', transactionDate) as month,
COALESCE(SUM(CASE WHEN transactionValue > 0 THEN transactionValue ELSE 0 END), 0) as income,
COALESCE(SUM(CASE WHEN transactionValue < 0 THEN ABS(transactionValue) ELSE 0 END), 0) as expenses,
COALESCE(SUM(transactionValue), 0) as net
FROM transactions
WHERE 1=1
"""
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
query += """
GROUP BY strftime('%Y-%m', transactionDate)
ORDER BY month ASC
"""
cursor.execute(query, params)
rows = cursor.fetchall()
# Convert to desired format with proper month display
monthly_stats = []
for row in rows:
# Convert YYYY-MM to display format like "Mar 2024"
year, month_num = row["month"].split("-")
month_date = datetime.strptime(f"{year}-{month_num}-01", "%Y-%m-%d")
display_month = month_date.strftime("%b %Y")
monthly_stats.append(
{
"month": display_month,
"income": round(row["income"], 2),
"expenses": round(row["expenses"], 2),
"net": round(row["net"], 2),
}
)
conn.close()
return monthly_stats
except Exception as e:
conn.close()
logger.error(f"Failed to calculate monthly stats: {e}")
raise

View File

@@ -0,0 +1,69 @@
"""Balance data transformation processor for format conversions."""
from datetime import datetime
from typing import Any, Dict, List, Tuple
class BalanceTransformer:
"""Transforms balance data between GoCardless and internal database formats."""
def merge_account_metadata_into_balances(
self,
balances: Dict[str, Any],
account_details: Dict[str, Any],
) -> Dict[str, Any]:
"""
Merge account metadata into balance data for proper persistence.
This adds institution_id, iban, and account_status to the balances
so they can be persisted alongside the balance data.
Args:
balances: Raw balance data from GoCardless
account_details: Enriched account details containing metadata
Returns:
Balance data with account metadata merged in
"""
balances_with_metadata = balances.copy()
balances_with_metadata["institution_id"] = account_details.get("institution_id")
balances_with_metadata["iban"] = account_details.get("iban")
balances_with_metadata["account_status"] = account_details.get("status")
return balances_with_metadata
def transform_to_database_format(
self,
account_id: str,
balance_data: Dict[str, Any],
) -> List[Tuple[Any, ...]]:
"""
Transform GoCardless balance format to database row format.
Converts nested GoCardless balance structure into flat tuples
ready for SQLite insertion.
Args:
account_id: The account ID
balance_data: Balance data with merged account metadata
Returns:
List of tuples in database row format (account_id, bank, status, ...)
"""
rows = []
for balance in balance_data.get("balances", []):
balance_amount = balance.get("balanceAmount", {})
row = (
account_id,
balance_data.get("institution_id", "unknown"),
balance_data.get("account_status"),
balance_data.get("iban", "N/A"),
float(balance_amount.get("amount", 0)),
balance_amount.get("currency"),
balance.get("balanceType"),
datetime.now().isoformat(),
)
rows.append(row)
return rows

File diff suppressed because it is too large Load Diff

View File

@@ -52,11 +52,17 @@ class NotificationService:
async def send_expiry_notification(self, notification_data: Dict[str, Any]) -> None:
"""Send notification about account expiry"""
if self._is_discord_enabled():
await self._send_discord_expiry(notification_data)
try:
if self._is_discord_enabled():
await self._send_discord_expiry(notification_data)
except Exception as e:
logger.error(f"Failed to send Discord expiry notification: {e}")
if self._is_telegram_enabled():
await self._send_telegram_expiry(notification_data)
try:
if self._is_telegram_enabled():
await self._send_telegram_expiry(notification_data)
except Exception as e:
logger.error(f"Failed to send Telegram expiry notification: {e}")
def _filter_transactions(
self, transactions: List[Dict[str, Any]]
@@ -262,7 +268,6 @@ class NotificationService:
logger.info(f"Sent Discord expiry notification: {notification_data}")
except Exception as e:
logger.error(f"Failed to send Discord expiry notification: {e}")
raise
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
"""Send Telegram expiry notification"""
@@ -288,17 +293,22 @@ class NotificationService:
logger.info(f"Sent Telegram expiry notification: {notification_data}")
except Exception as e:
logger.error(f"Failed to send Telegram expiry notification: {e}")
raise
async def send_sync_failure_notification(
self, notification_data: Dict[str, Any]
) -> None:
"""Send notification about sync failure"""
if self._is_discord_enabled():
await self._send_discord_sync_failure(notification_data)
try:
if self._is_discord_enabled():
await self._send_discord_sync_failure(notification_data)
except Exception as e:
logger.error(f"Failed to send Discord sync failure notification: {e}")
if self._is_telegram_enabled():
await self._send_telegram_sync_failure(notification_data)
try:
if self._is_telegram_enabled():
await self._send_telegram_sync_failure(notification_data)
except Exception as e:
logger.error(f"Failed to send Telegram sync failure notification: {e}")
async def _send_discord_sync_failure(
self, notification_data: Dict[str, Any]
@@ -326,7 +336,6 @@ class NotificationService:
logger.info(f"Sent Discord sync failure notification: {notification_data}")
except Exception as e:
logger.error(f"Failed to send Discord sync failure notification: {e}")
raise
async def _send_telegram_sync_failure(
self, notification_data: Dict[str, Any]
@@ -354,4 +363,3 @@ class NotificationService:
logger.info(f"Sent Telegram sync failure notification: {notification_data}")
except Exception as e:
logger.error(f"Failed to send Telegram sync failure notification: {e}")
raise

View File

@@ -4,18 +4,41 @@ from typing import List
from loguru import logger
from leggen.api.models.sync import SyncResult, SyncStatus
from leggen.services.database_service import DatabaseService
from leggen.repositories import (
AccountRepository,
BalanceRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AccountEnricher,
BalanceTransformer,
TransactionProcessor,
)
from leggen.services.gocardless_service import GoCardlessService
from leggen.services.notification_service import NotificationService
# Constants for notification
EXPIRED_DAYS_LEFT = 0
class SyncService:
def __init__(self):
self.gocardless = GoCardlessService()
self.database = DatabaseService()
self.notifications = NotificationService()
# Repositories
self.accounts = AccountRepository()
self.balances = BalanceRepository()
self.transactions = TransactionRepository()
self.sync = SyncRepository()
# Data processors
self.account_enricher = AccountEnricher()
self.balance_transformer = BalanceTransformer()
self.transaction_processor = TransactionProcessor()
self._sync_status = SyncStatus(is_running=False)
self._institution_logos = {} # Cache for institution logos
async def get_sync_status(self) -> SyncStatus:
"""Get current sync status"""
@@ -67,6 +90,9 @@ class SyncService:
self._sync_status.total_accounts = len(all_accounts)
logs.append(f"Found {len(all_accounts)} accounts to sync")
# Check for expired or expiring requisitions
await self._check_requisition_expiry(requisitions.get("results", []))
# Process each account
for account_id in all_accounts:
try:
@@ -78,72 +104,44 @@ class SyncService:
# Get balances to extract currency information
balances = await self.gocardless.get_account_balances(account_id)
# Enrich account details with currency and institution logo
# Enrich and persist account details
if account_details and balances:
enriched_account_details = account_details.copy()
# Extract currency from first balance
balances_list = balances.get("balances", [])
if balances_list:
first_balance = balances_list[0]
balance_amount = first_balance.get("balanceAmount", {})
currency = balance_amount.get("currency")
if currency:
enriched_account_details["currency"] = currency
# Get institution details to fetch logo
institution_id = enriched_account_details.get("institution_id")
if institution_id:
try:
institution_details = (
await self.gocardless.get_institution_details(
institution_id
)
)
enriched_account_details["logo"] = (
institution_details.get("logo", "")
)
logger.info(
f"Fetched logo for institution {institution_id}: {enriched_account_details.get('logo', 'No logo')}"
)
except Exception as e:
logger.warning(
f"Failed to fetch institution details for {institution_id}: {e}"
)
# Enrich account with currency and institution logo
enriched_account_details = (
await self.account_enricher.enrich_account_details(
account_details, balances
)
)
# Persist enriched account details to database
await self.database.persist_account_details(
enriched_account_details
)
self.accounts.persist(enriched_account_details)
# Merge account details into balances data for proper persistence
balances_with_account_info = balances.copy()
balances_with_account_info["institution_id"] = (
enriched_account_details.get("institution_id")
# Merge account metadata into balances for persistence
balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances(
balances, enriched_account_details
)
balances_with_account_info["iban"] = (
enriched_account_details.get("iban")
)
balances_with_account_info["account_status"] = (
enriched_account_details.get("status")
)
await self.database.persist_balance(
account_id, balances_with_account_info
balance_rows = (
self.balance_transformer.transform_to_database_format(
account_id, balances_with_account_info
)
)
self.balances.persist(account_id, balance_rows)
balances_updated += len(balances.get("balances", []))
elif account_details:
# Fallback: persist account details without currency if balances failed
await self.database.persist_account_details(account_details)
self.accounts.persist(account_details)
# Get and save transactions
transactions = await self.gocardless.get_account_transactions(
account_id
)
if transactions:
processed_transactions = self.database.process_transactions(
account_id, account_details, transactions
processed_transactions = (
self.transaction_processor.process_transactions(
account_id, account_details, transactions
)
)
new_transactions = await self.database.persist_transactions(
new_transactions = self.transactions.persist(
account_id, processed_transactions
)
transactions_added += len(new_transactions)
@@ -166,6 +164,15 @@ class SyncService:
logger.error(error_msg)
logs.append(error_msg)
# Send notification for account sync failure
await self.notifications.send_sync_failure_notification(
{
"account_id": account_id,
"error": error_msg,
"type": "account_sync_failure",
}
)
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
@@ -188,9 +195,7 @@ class SyncService:
# Persist sync operation to database
try:
operation_id = await self.database.persist_sync_operation(
sync_operation
)
operation_id = self.sync.persist(sync_operation)
logger.debug(f"Saved sync operation with ID: {operation_id}")
except Exception as e:
logger.error(f"Failed to persist sync operation: {e}")
@@ -239,9 +244,7 @@ class SyncService:
)
try:
operation_id = await self.database.persist_sync_operation(
sync_operation
)
operation_id = self.sync.persist(sync_operation)
logger.debug(f"Saved failed sync operation with ID: {operation_id}")
except Exception as persist_error:
logger.error(
@@ -252,6 +255,31 @@ class SyncService:
finally:
self._sync_status.is_running = False
async def _check_requisition_expiry(self, requisitions: List[dict]) -> None:
"""Check requisitions for expiry and send notifications.
Args:
requisitions: List of requisition dictionaries to check
"""
for req in requisitions:
requisition_id = req.get("id", "unknown")
institution_id = req.get("institution_id", "unknown")
status = req.get("status", "")
# Check if requisition is expired
if status == "EX":
logger.warning(
f"Requisition {requisition_id} for {institution_id} has expired"
)
await self.notifications.send_expiry_notification(
{
"bank": institution_id,
"requisition_id": requisition_id,
"status": "expired",
"days_left": EXPIRED_DAYS_LEFT,
}
)
async def sync_specific_accounts(
self, account_ids: List[str], force: bool = False, trigger_type: str = "manual"
) -> SyncResult:

View File

@@ -126,6 +126,38 @@ def api_client(fastapi_app):
return TestClient(fastapi_app)
@pytest.fixture
def mock_account_repo():
"""Create mock AccountRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_balance_repo():
"""Create mock BalanceRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_transaction_repo():
"""Create mock TransactionRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_analytics_proc():
"""Create mock AnalyticsProcessor for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_db_path(temp_db_path):
"""Mock the database path to use temporary database for testing."""

View File

@@ -1,13 +1,13 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from leggen.api.dependencies import get_transaction_repository
from leggen.commands.server import create_app
from leggen.services.database_service import DatabaseService
class TestAnalyticsFix:
@@ -19,11 +19,11 @@ class TestAnalyticsFix:
return TestClient(app)
@pytest.fixture
def mock_database_service(self):
return Mock(spec=DatabaseService)
def mock_transaction_repo(self):
return Mock()
@pytest.mark.asyncio
async def test_transaction_stats_uses_all_transactions(self, mock_database_service):
async def test_transaction_stats_uses_all_transactions(self, mock_transaction_repo):
"""Test that transaction stats endpoint uses all transactions (not limited to 100)"""
# Mock data for 600 transactions (simulating the issue)
mock_transactions = []
@@ -42,53 +42,50 @@ class TestAnalyticsFix:
}
)
mock_database_service.get_transactions_from_db = AsyncMock(
return_value=mock_transactions
mock_transaction_repo.get_transactions.return_value = mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
)
# Test that the endpoint calls get_transactions_from_db with limit=None
with patch(
"leggen.api.routes.transactions.database_service", mock_database_service
):
app = create_app()
client = TestClient(app)
# Verify that the response contains stats for all 600 transactions
stats = data
assert stats["total_transactions"] == 600, (
"Should process all 600 transactions, not just 100"
)
response = client.get("/api/v1/transactions/stats?days=365")
# Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
)
# Verify that the response contains stats for all 600 transactions
stats = data
assert stats["total_transactions"] == 600, (
"Should process all 600 transactions, not just 100"
)
# Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
@pytest.mark.asyncio
async def test_analytics_endpoint_returns_all_transactions(
self, mock_database_service
self, mock_transaction_repo
):
"""Test that the new analytics endpoint returns all transactions without pagination"""
# Mock data for 600 transactions
@@ -108,30 +105,28 @@ class TestAnalyticsFix:
}
)
mock_database_service.get_transactions_from_db = AsyncMock(
return_value=mock_transactions
mock_transaction_repo.get_transactions.return_value = mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
)
with patch(
"leggen.api.routes.transactions.database_service", mock_database_service
):
app = create_app()
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
)
# Verify that all 600 transactions are returned
transactions_data = data
assert len(transactions_data) == 600, (
"Analytics endpoint should return all 600 transactions"
)
# Verify that all 600 transactions are returned
transactions_data = data
assert len(transactions_data) == 600, (
"Analytics endpoint should return all 600 transactions"
)

View File

@@ -4,6 +4,12 @@ from unittest.mock import patch
import pytest
from leggen.api.dependencies import (
get_account_repository,
get_balance_repository,
get_transaction_repository,
)
@pytest.mark.api
class TestAccountsAPI:
@@ -11,11 +17,14 @@ class TestAccountsAPI:
def test_get_all_accounts_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
mock_db_path,
mock_account_repo,
mock_balance_repo,
):
"""Test successful retrieval of all accounts from database."""
mock_accounts = [
@@ -45,19 +54,21 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_accounts_from_db",
return_value=mock_accounts,
),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
mock_account_repo.get_accounts.return_value = mock_accounts
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -69,11 +80,14 @@ class TestAccountsAPI:
def test_get_account_details_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
mock_db_path,
mock_account_repo,
mock_balance_repo,
):
"""Test successful retrieval of specific account details from database."""
mock_account = {
@@ -101,19 +115,21 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=mock_account,
),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
mock_account_repo.get_account.return_value = mock_account
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["id"] == "test-account-123"
@@ -121,7 +137,13 @@ class TestAccountsAPI:
assert len(data["balances"]) == 1
def test_get_account_balances_success(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_balance_repo,
):
"""Test successful retrieval of account balances from database."""
mock_balances = [
@@ -149,15 +171,17 @@ class TestAccountsAPI:
},
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_balances_from_db",
return_value=mock_balances,
),
):
mock_balance_repo.get_balances.return_value = mock_balances
fastapi_app.dependency_overrides[get_balance_repository] = (
lambda: mock_balance_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/balances")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 2
@@ -167,12 +191,14 @@ class TestAccountsAPI:
def test_get_account_transactions_success(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
mock_db_path,
mock_transaction_repo,
):
"""Test successful retrieval of account transactions from database."""
mock_transactions = [
@@ -191,21 +217,19 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -218,12 +242,14 @@ class TestAccountsAPI:
def test_get_account_transactions_full_details(
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
mock_db_path,
mock_transaction_repo,
):
"""Test retrieval of full transaction details from database."""
mock_transactions = [
@@ -242,21 +268,19 @@ class TestAccountsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data) == 1
@@ -268,22 +292,36 @@ class TestAccountsAPI:
assert "raw_transaction" in transaction
def test_get_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
):
"""Test handling of non-existent account."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=None,
),
):
mock_account_repo.get_account.return_value = None
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/nonexistent")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404
def test_update_account_display_name_success(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
):
"""Test successful update of account display name."""
mock_account = {
@@ -297,41 +335,48 @@ class TestAccountsAPI:
"last_accessed": "2025-09-01T09:30:00Z",
}
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=mock_account,
),
patch(
"leggen.api.routes.accounts.database_service.persist_account_details",
return_value=None,
),
):
mock_account_repo.get_account.return_value = mock_account
mock_account_repo.persist.return_value = mock_account
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.put(
"/api/v1/accounts/test-account-123",
json={"display_name": "My Custom Account Name"},
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["id"] == "test-account-123"
assert data["display_name"] == "My Custom Account Name"
def test_update_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
):
"""Test updating non-existent account."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db",
return_value=None,
),
):
mock_account_repo.get_account.return_value = None
fastapi_app.dependency_overrides[get_account_repository] = (
lambda: mock_account_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.put(
"/api/v1/accounts/nonexistent",
json={"display_name": "New Name"},
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404

View File

@@ -5,13 +5,20 @@ from unittest.mock import patch
import pytest
from leggen.api.dependencies import get_transaction_repository
@pytest.mark.api
class TestTransactionsAPI:
"""Test transaction-related API endpoints."""
def test_get_all_transactions_success(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test successful retrieval of all transactions from database."""
mock_transactions = [
@@ -43,19 +50,17 @@ class TestTransactionsAPI:
},
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=2,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = len(mock_transactions)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions?summary_only=true")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
@@ -70,7 +75,12 @@ class TestTransactionsAPI:
assert transaction["account_id"] == "test-account-123"
def test_get_all_transactions_full_details(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test retrieval of full transaction details from database."""
mock_transactions = [
@@ -89,19 +99,17 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = len(mock_transactions)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions?summary_only=false")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
@@ -114,7 +122,12 @@ class TestTransactionsAPI:
assert "raw_transaction" in transaction
def test_get_transactions_with_filters(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transactions with various filters."""
mock_transactions = [
@@ -133,17 +146,14 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
mock_transaction_repo.get_count.return_value = 1
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/transactions?"
"account_id=test-account-123&"
@@ -156,10 +166,12 @@ class TestTransactionsAPI:
"per_page=10"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
# Verify the database service was called with correct filters
mock_get_transactions.assert_called_once_with(
# Verify the repository was called with correct filters
mock_transaction_repo.get_transactions.assert_called_once_with(
account_id="test-account-123",
limit=10,
offset=10, # (page-1) * per_page = (2-1) * 10 = 10
@@ -171,22 +183,26 @@ class TestTransactionsAPI:
)
def test_get_transactions_empty_result(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transactions when database returns empty result."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=[],
),
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=0,
),
):
mock_transaction_repo.get_transactions.return_value = []
mock_transaction_repo.get_count.return_value = 0
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 0
@@ -195,23 +211,37 @@ class TestTransactionsAPI:
assert data["total_pages"] == 0
def test_get_transactions_database_error(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test handling database error when getting transactions."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"),
),
):
mock_transaction_repo.get_transactions.side_effect = Exception(
"Database connection failed"
)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500
assert "Failed to get transactions" in response.json()["detail"]
def test_get_transaction_stats_success(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test successful retrieval of transaction statistics from database."""
mock_transactions = [
@@ -238,15 +268,16 @@ class TestTransactionsAPI:
},
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
),
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats?days=30")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
@@ -264,7 +295,12 @@ class TestTransactionsAPI:
assert data["average_transaction"] == expected_avg
def test_get_transaction_stats_with_account_filter(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transaction stats filtered by account."""
mock_transactions = [
@@ -277,37 +313,46 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get(
"/api/v1/transactions/stats?account_id=test-account-123"
)
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
# Verify the database service was called with account filter
mock_get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs
# Verify the repository was called with account filter
mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert call_kwargs["account_id"] == "test-account-123"
def test_get_transaction_stats_empty_result(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting stats when no transactions match criteria."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=[],
),
):
mock_transaction_repo.get_transactions.return_value = []
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
@@ -319,23 +364,37 @@ class TestTransactionsAPI:
assert data["accounts_included"] == 0
def test_get_transaction_stats_database_error(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test handling database error when getting stats."""
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"),
),
):
mock_transaction_repo.get_transactions.side_effect = Exception(
"Database connection failed"
)
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500
assert "Failed to get transaction stats" in response.json()["detail"]
def test_get_transaction_stats_custom_period(
self, api_client, mock_config, mock_auth_token
self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
):
"""Test getting transaction stats for custom time period."""
mock_transactions = [
@@ -348,21 +407,23 @@ class TestTransactionsAPI:
}
]
with (
patch("leggen.utils.config.config", mock_config),
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
return_value=mock_transactions,
) as mock_get_transactions,
):
mock_transaction_repo.get_transactions.return_value = mock_transactions
fastapi_app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats?days=7")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200
data = response.json()
assert data["period_days"] == 7
# Verify the date range was calculated correctly for 7 days
mock_get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs
mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert "date_from" in call_kwargs
assert "date_to" in call_kwargs

View File

@@ -120,12 +120,10 @@ class TestConfigurablePaths:
"iban": "TEST_IBAN",
}
# Use the internal balance persistence method since the test needs direct database access
# Use the public balance persistence method
import asyncio
asyncio.run(
database_service._persist_balance_sqlite("test-account", balance_data)
)
asyncio.run(database_service.persist_balance("test-account", balance_data))
# Retrieve balances
balances = asyncio.run(

View File

@@ -85,7 +85,7 @@ class TestDatabaseService:
):
"""Test successful retrieval of transactions from database."""
with patch.object(
database_service, "_get_transactions"
database_service.transactions, "get_transactions"
) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format
@@ -111,7 +111,7 @@ class TestDatabaseService:
):
"""Test retrieving transactions with filters."""
with patch.object(
database_service, "_get_transactions"
database_service.transactions, "get_transactions"
) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format
@@ -149,7 +149,7 @@ class TestDatabaseService:
async def test_get_transactions_from_db_error(self, database_service):
"""Test handling error when getting transactions."""
with patch.object(
database_service, "_get_transactions"
database_service.transactions, "get_transactions"
) as mock_get_transactions:
mock_get_transactions.side_effect = Exception("Database error")
@@ -159,7 +159,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_success(self, database_service):
"""Test successful retrieval of transaction count."""
with patch.object(database_service, "_get_transaction_count") as mock_get_count:
with patch.object(database_service.transactions, "get_count") as mock_get_count:
mock_get_count.return_value = 42
result = await database_service.get_transaction_count_from_db(
@@ -167,11 +167,18 @@ class TestDatabaseService:
)
assert result == 42
mock_get_count.assert_called_once_with(account_id="test-account-123")
mock_get_count.assert_called_once_with(
account_id="test-account-123",
date_from=None,
date_to=None,
min_amount=None,
max_amount=None,
search=None,
)
async def test_get_transaction_count_from_db_with_filters(self, database_service):
"""Test getting transaction count with filters."""
with patch.object(database_service, "_get_transaction_count") as mock_get_count:
with patch.object(database_service.transactions, "get_count") as mock_get_count:
mock_get_count.return_value = 15
result = await database_service.get_transaction_count_from_db(
@@ -185,7 +192,9 @@ class TestDatabaseService:
mock_get_count.assert_called_once_with(
account_id="test-account-123",
date_from="2025-09-01",
date_to=None,
min_amount=-100.0,
max_amount=None,
search="Coffee",
)
@@ -201,7 +210,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_error(self, database_service):
"""Test handling error when getting count."""
with patch.object(database_service, "_get_transaction_count") as mock_get_count:
with patch.object(database_service.transactions, "get_count") as mock_get_count:
mock_get_count.side_effect = Exception("Database error")
result = await database_service.get_transaction_count_from_db()
@@ -212,7 +221,9 @@ class TestDatabaseService:
self, database_service, sample_balances_db_format
):
"""Test successful retrieval of balances from database."""
with patch.object(database_service, "_get_balances") as mock_get_balances:
with patch.object(
database_service.balances, "get_balances"
) as mock_get_balances:
mock_get_balances.return_value = sample_balances_db_format
result = await database_service.get_balances_from_db(
@@ -234,7 +245,9 @@ class TestDatabaseService:
async def test_get_balances_from_db_error(self, database_service):
"""Test handling error when getting balances."""
with patch.object(database_service, "_get_balances") as mock_get_balances:
with patch.object(
database_service.balances, "get_balances"
) as mock_get_balances:
mock_get_balances.side_effect = Exception("Database error")
result = await database_service.get_balances_from_db()
@@ -249,7 +262,9 @@ class TestDatabaseService:
"iban": "LT313250081177977789",
}
with patch.object(database_service, "_get_account_summary") as mock_get_summary:
with patch.object(
database_service.transactions, "get_account_summary"
) as mock_get_summary:
mock_get_summary.return_value = mock_summary
result = await database_service.get_account_summary_from_db(
@@ -269,7 +284,9 @@ class TestDatabaseService:
async def test_get_account_summary_from_db_error(self, database_service):
"""Test handling error when getting summary."""
with patch.object(database_service, "_get_account_summary") as mock_get_summary:
with patch.object(
database_service.transactions, "get_account_summary"
) as mock_get_summary:
mock_get_summary.side_effect = Exception("Database error")
result = await database_service.get_account_summary_from_db(
@@ -291,87 +308,87 @@ class TestDatabaseService:
],
}
with patch("sqlite3.connect") as mock_connect:
mock_conn = mock_connect.return_value
mock_cursor = mock_conn.cursor.return_value
with (
patch.object(database_service.balances, "persist") as mock_persist,
patch.object(
database_service.balance_transformer, "transform_to_database_format"
) as mock_transform,
):
mock_transform.return_value = [
(
"test-account-123",
"REVOLUT_REVOLT21",
"active",
"LT313250081177977789",
1000.0,
"EUR",
"interimAvailable",
"2025-09-01T10:00:00",
)
]
await database_service._persist_balance_sqlite(
"test-account-123", balance_data
)
await database_service.persist_balance("test-account-123", balance_data)
# Verify database operations
mock_connect.assert_called()
mock_cursor.execute.assert_called() # Table creation and insert
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
# Verify transformation and persistence were called
mock_transform.assert_called_once_with("test-account-123", balance_data)
mock_persist.assert_called_once()
async def test_persist_balance_sqlite_error(self, database_service):
"""Test handling error during balance persistence."""
balance_data = {"balances": []}
with patch("sqlite3.connect") as mock_connect:
mock_connect.side_effect = Exception("Database error")
with (
patch.object(database_service.balances, "persist") as mock_persist,
patch.object(
database_service.balance_transformer, "transform_to_database_format"
) as mock_transform,
):
mock_persist.side_effect = Exception("Database error")
mock_transform.return_value = []
with pytest.raises(Exception, match="Database error"):
await database_service._persist_balance_sqlite(
"test-account-123", balance_data
)
await database_service.persist_balance("test-account-123", balance_data)
async def test_persist_transactions_sqlite_success(
self, database_service, sample_transactions_db_format
):
"""Test successful transaction persistence."""
with patch("sqlite3.connect") as mock_connect:
mock_conn = mock_connect.return_value
mock_cursor = mock_conn.cursor.return_value
# Mock fetchone to return (0,) indicating transaction doesn't exist yet
mock_cursor.fetchone.return_value = (0,)
with patch.object(database_service.transactions, "persist") as mock_persist:
mock_persist.return_value = sample_transactions_db_format
result = await database_service._persist_transactions_sqlite(
result = await database_service.persist_transactions(
"test-account-123", sample_transactions_db_format
)
# Should return the transactions (assuming no duplicates)
assert len(result) >= 0 # Could be empty if all are duplicates
# Verify database operations
mock_connect.assert_called()
mock_cursor.execute.assert_called()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
# Should return the new transactions
assert len(result) == 2
mock_persist.assert_called_once_with(
"test-account-123", sample_transactions_db_format
)
async def test_persist_transactions_sqlite_duplicate_detection(
self, database_service, sample_transactions_db_format
):
"""Test that existing transactions are not returned as new."""
with patch("sqlite3.connect") as mock_connect:
mock_conn = mock_connect.return_value
mock_cursor = mock_conn.cursor.return_value
# Mock fetchone to return (1,) indicating transaction already exists
mock_cursor.fetchone.return_value = (1,)
with patch.object(database_service.transactions, "persist") as mock_persist:
# Return empty list indicating all were duplicates
mock_persist.return_value = []
result = await database_service._persist_transactions_sqlite(
result = await database_service.persist_transactions(
"test-account-123", sample_transactions_db_format
)
# Should return empty list since all transactions already exist
assert len(result) == 0
# Verify database operations still happened (INSERT OR REPLACE executed)
mock_connect.assert_called()
mock_cursor.execute.assert_called()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
mock_persist.assert_called_once()
async def test_persist_transactions_sqlite_error(self, database_service):
"""Test handling error during transaction persistence."""
with patch("sqlite3.connect") as mock_connect:
mock_connect.side_effect = Exception("Database error")
with patch.object(database_service.transactions, "persist") as mock_persist:
mock_persist.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
await database_service._persist_transactions_sqlite(
"test-account-123", []
)
await database_service.persist_transactions("test-account-123", [])
async def test_process_transactions_booked_and_pending(self, database_service):
"""Test processing transactions with both booked and pending."""

View File

@@ -0,0 +1,244 @@
"""Tests for sync service notification functionality."""
from unittest.mock import patch
import pytest
from leggen.services.sync_service import SyncService
@pytest.mark.unit
class TestSyncNotifications:
"""Test sync service notification functionality."""
@pytest.mark.asyncio
async def test_sync_failure_sends_notification(self):
"""Test that sync failures trigger notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1"],
}
]
}
# Make account details fail
mock_get_details.side_effect = Exception("API Error")
# Execute: Run sync which should fail for the account
await sync_service.sync_all_accounts()
# Assert: Notification should be sent for the failed account
mock_send_notification.assert_called_once()
call_args = mock_send_notification.call_args[0][0]
assert call_args["account_id"] == "account-1"
assert "API Error" in call_args["error"]
assert call_args["type"] == "account_sync_failure"
@pytest.mark.asyncio
async def test_expired_requisition_sends_notification(self):
"""Test that expired requisitions trigger expiry notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.notifications, "send_expiry_notification"
) as mock_send_expiry,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One expired requisition
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-expired",
"institution_id": "EXPIRED_BANK",
"status": "EX",
"accounts": [],
}
]
}
# Execute: Run sync
await sync_service.sync_all_accounts()
# Assert: Expiry notification should be sent
mock_send_expiry.assert_called_once()
call_args = mock_send_expiry.call_args[0][0]
assert call_args["requisition_id"] == "req-expired"
assert call_args["bank"] == "EXPIRED_BANK"
assert call_args["status"] == "expired"
assert call_args["days_left"] == 0
@pytest.mark.asyncio
async def test_multiple_failures_send_multiple_notifications(self):
"""Test that multiple account failures send multiple notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with two accounts that will fail
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1", "account-2"],
}
]
}
# Make all account details fail
mock_get_details.side_effect = Exception("API Error")
# Execute: Run sync
await sync_service.sync_all_accounts()
# Assert: Two notifications should be sent
assert mock_send_notification.call_count == 2
@pytest.mark.asyncio
async def test_successful_sync_no_failure_notification(self):
"""Test that successful syncs don't send failure notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.gocardless, "get_account_balances"
) as mock_get_balances,
patch.object(
sync_service.gocardless, "get_account_transactions"
) as mock_get_transactions,
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.notifications, "send_transaction_notifications"),
patch.object(sync_service.accounts, "persist"),
patch.object(sync_service.balances, "persist"),
patch.object(
sync_service.transaction_processor,
"process_transactions",
return_value=[],
),
patch.object(sync_service.transactions, "persist", return_value=[]),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that succeeds
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1"],
}
]
}
mock_get_details.return_value = {
"id": "account-1",
"institution_id": "TEST_BANK",
"status": "READY",
"iban": "TEST123",
}
mock_get_balances.return_value = {
"balances": [{"balanceAmount": {"amount": "100", "currency": "EUR"}}]
}
mock_get_transactions.return_value = {"transactions": {"booked": []}}
# Execute: Run sync
await sync_service.sync_all_accounts()
# Assert: No failure notification should be sent
mock_send_notification.assert_not_called()
@pytest.mark.asyncio
async def test_notification_failure_does_not_stop_sync(self):
"""Test that notification failures don't stop the sync process."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.notifications, "_send_discord_sync_failure"
) as mock_discord_notification,
patch.object(
sync_service.notifications, "_send_telegram_sync_failure"
) as mock_telegram_notification,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1"],
}
]
}
# Make account details fail
mock_get_details.side_effect = Exception("API Error")
# Make both notification methods fail
mock_discord_notification.side_effect = Exception("Discord Error")
mock_telegram_notification.side_effect = Exception("Telegram Error")
# Execute: Run sync - should not raise exception from notification
result = await sync_service.sync_all_accounts()
# The sync should complete with errors but not crash from notifications
assert result.success is False
assert len(result.errors) > 0
assert "API Error" in result.errors[0]