Compare commits

..

21 Commits

Author SHA1 Message Date
Elisiário Couto
9e9b1cf15f refactor(api): Update all modified files with dependency injection changes. 2025-12-10 00:18:53 +00:00
Elisiário Couto
9dc6357905 refactor(api): Remove DatabaseService layer and implement dependency injection.
- Remove DatabaseService abstraction layer from API routes
- Implement FastAPI dependency injection for repositories
- Create leggen/api/dependencies.py with factory functions
- Update routes to use AccountRepo, BalanceRepo, TransactionRepo directly
- Refactor SyncService to use repositories instead of DatabaseService
- Deprecate DatabaseService with warnings for backward compatibility
- Update all tests to use FastAPI dependency overrides pattern
- Fix mypy type errors in routes

Benefits:
- Simpler architecture with one less abstraction layer
- More explicit dependencies via function signatures
- Better testability with FastAPI's app.dependency_overrides
- Clearer separation of concerns

All 114 tests passing, mypy clean.
2025-12-10 00:18:53 +00:00
Elisiário Couto
5f87991076 refactor(api): Split DatabaseService into repository pattern.
Split the monolithic DatabaseService (1,492 lines) into focused repository
modules using the repository pattern for better maintainability and
separation of concerns.

Changes:
- Create new repositories/ directory with 5 focused repositories:
  - TransactionRepository: transaction data operations (264 lines)
  - AccountRepository: account data operations (128 lines)
  - BalanceRepository: balance data operations (107 lines)
  - MigrationRepository: all database migrations (629 lines)
  - SyncRepository: sync operation tracking (132 lines)
  - BaseRepository: shared database connection logic (28 lines)

- Refactor DatabaseService into a clean facade (287 lines):
  - Delegates data access to repositories
  - Maintains public API (no breaking changes)
  - Keeps data processors in service layer
  - Preserves require_sqlite decorator

- Update tests to mock repository methods instead of private methods
- Fix test references to internal methods (_persist_*, _get_*)

Benefits:
- Clear separation of concerns (one repository per domain)
- Easier maintenance (changes isolated to specific repositories)
- Better testability (repositories can be mocked individually)
- Improved code organization (from 1 file to 7 focused files)

All 114 tests passing.
2025-12-08 23:21:55 +00:00
Elisiário Couto
267db8ac63 refactor(api): Improve database connection management and reduce boilerplate.
- Add context manager for database connections with proper cleanup
- Add @require_sqlite decorator to eliminate duplicate checks
- Refactor 9 core CRUD methods to use managed connections
- Reduce code by 50 lines while improving resource management
- All 114 tests passing
2025-12-08 22:54:57 +00:00
Elisiário Couto
7007043521 Reformat. 2025-12-08 22:05:04 +00:00
Elisiário Couto
fbb3eb9e64 refactor: Consolidate service layer with dedicated data processors.
Introduces a DataProcessor layer to separate transformation logic from orchestration and persistence concerns:

- Created data_processors/ directory with AccountEnricher, BalanceTransformer, AnalyticsProcessor, and moved TransactionProcessor
- Refactored SyncService to pure orchestrator, removing account/balance enrichment logic
- Refactored DatabaseService to pure CRUD, removing analytics and transformation logic
- Extracted 90+ lines of analytics SQL from DatabaseService to AnalyticsProcessor
- Extracted 80+ lines of balance transformation logic to BalanceTransformer
- Maintained backward compatibility - all 109 tests pass
- No API contract changes

This improves code clarity, testability, and maintainability while maintaining the existing API surface.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2025-12-08 22:04:00 +00:00
Elisiário Couto
3d5994bf30 Fix lint issues. 2025-12-08 21:59:54 +00:00
Elisiário Couto
edbc1cb39e Update frontend/src/components/TransactionsTable.tsx
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
Elisiário Couto
504f78aa85 Update frontend/src/components/filters/FilterBar.tsx
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
Elisiário Couto
cbbc316537 chore(ci): Fix workflow permissions. 2025-12-08 21:59:54 +00:00
Elisiário Couto
18ee52bdff fix(frontend): Prevent full transactions page reload on search. 2025-12-08 21:59:54 +00:00
Elisiário Couto
07edfeaf25 fix(frontend): Blur balances in transactions page cards. 2025-12-08 21:59:54 +00:00
copilot-swe-agent[bot]
c8b161e7f2 refactor(frontend): Address code review feedback on focus and currency handling.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
copilot-swe-agent[bot]
2c85722fd0 feat(frontend): Fix search focus issue and add transaction statistics.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-08 21:59:54 +00:00
copilot-swe-agent[bot]
88037f328d fix: Address code review feedback on notification error handling.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 19:05:28 +00:00
copilot-swe-agent[bot]
d58894d07c refactor: Replace magic numbers with named constants.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 19:05:28 +00:00
copilot-swe-agent[bot]
1a2ec45f89 feat: Add sync error and account expiry notifications.
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 19:05:28 +00:00
Elisiário Couto
5de9badfde fix(frontend): Blur balances in Account Management page. 2025-12-07 12:00:23 +00:00
Elisiário Couto
159cba508e fix: Resolve all lint warnings and type errors across frontend and backend.
Frontend:
- Memoize pagination object in TransactionsTable to prevent unnecessary re-renders and fix exhaustive-deps warning
- Add optional success and message fields to backup API response types for proper error handling

Backend:
- Add TypedDict for transaction type configuration to improve type safety in generate_sample_db
- Fix unpacking of amount_range with explicit float type hints
- Add explicit type hints for descriptions dictionary and specific_descriptions variable
- Fix sync endpoint return types: get_sync_status returns SyncStatus and sync_now returns SyncResult
- Fix transactions endpoint data type declaration to properly support Union types in PaginatedResponse

All checks now pass:
- Frontend: npm lint and npm build ✓
- Backend: mypy type checking ✓
- Backend: ruff lint on modified files ✓

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2025-12-07 12:00:23 +00:00
copilot-swe-agent[bot]
966440006a fix(frontend): Remove unused import in TransactionDistribution
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 12:00:23 +00:00
copilot-swe-agent[bot]
a592b827aa feat(frontend): Add balance visibility toggle with blur effect
Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
2025-12-07 12:00:23 +00:00
52 changed files with 3269 additions and 2059 deletions

View File

@@ -6,6 +6,9 @@ on:
pull_request: pull_request:
branches: ["main", "dev"] branches: ["main", "dev"]
permissions:
contents: read
jobs: jobs:
test-python: test-python:
name: Test Python name: Test Python

View File

@@ -5,6 +5,11 @@ on:
tags: tags:
- "**" - "**"
permissions:
contents: write
packages: write
id-token: write
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -44,6 +49,9 @@ jobs:
push-docker-backend: push-docker-backend:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -90,6 +98,9 @@ jobs:
push-docker-frontend: push-docker-frontend:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -137,6 +148,8 @@ jobs:
create-github-release: create-github-release:
name: Create GitHub Release name: Create GitHub Release
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: write
needs: [build, publish-to-pypi, push-docker-backend, push-docker-frontend] needs: [build, publish-to-pypi, push-docker-backend, push-docker-frontend]
steps: steps:
- name: Checkout - name: Checkout

127
REFACTORING_SUMMARY.md Normal file
View File

@@ -0,0 +1,127 @@
# Backend Refactoring Summary
## What Was Accomplished ✅
### 1. Removed DatabaseService Layer from Production Code
- **Removed**: The `DatabaseService` class is no longer used in production API routes
- **Replaced with**: Direct repository usage via FastAPI dependency injection
- **Files changed**:
- `leggen/api/routes/accounts.py` - Now uses `AccountRepo`, `BalanceRepo`, `TransactionRepo`, `AnalyticsProc`
- `leggen/api/routes/transactions.py` - Now uses `TransactionRepo`, `AnalyticsProc`
- `leggen/services/sync_service.py` - Now uses repositories directly
- `leggen/commands/server.py` - Now uses `MigrationRepository` directly
### 2. Created Dependency Injection System
- **New file**: `leggen/api/dependencies.py`
- **Provides**: Centralized dependency injection setup for FastAPI
- **Includes**: Factory functions for all repositories and data processors
- **Type annotations**: `AccountRepo`, `BalanceRepo`, `TransactionRepo`, etc.
### 3. Simplified Code Architecture
- **Before**: Routes → DatabaseService → Repositories
- **After**: Routes → Repositories (via DI)
- **Benefits**:
- One less layer of indirection
- Clearer dependencies
- Easier to test with FastAPI's `app.dependency_overrides`
- Better separation of concerns
### 4. Maintained Backward Compatibility
- **DatabaseService** is kept but deprecated for test compatibility
- Added deprecation warning when instantiated
- Tests continue to work without immediate changes required
## Code Statistics
- **Lines removed from API layer**: ~16 imports of DatabaseService
- **New dependency injection file**: 80 lines
- **Files refactored**: 4 main files
## Benefits Achieved
1. **Cleaner Architecture**: Removed unnecessary abstraction layer
2. **Better Testability**: FastAPI dependency overrides are cleaner than mocking
3. **More Explicit Dependencies**: Function signatures show exactly what's needed
4. **Easier to Maintain**: Less indirection makes code easier to follow
5. **Performance**: Slightly fewer object instantiations per request
## What Still Needs Work
### Tests Need Updating
The test files still patch `database_service` which no longer exists in routes:
```python
# Old test pattern (needs updating):
patch("leggen.api.routes.accounts.database_service.get_accounts_from_db")
# New pattern (should use):
app.dependency_overrides[get_account_repository] = lambda: mock_repo
```
**Files needing test updates**:
- `tests/unit/test_api_accounts.py` (7 tests failing)
- `tests/unit/test_api_transactions.py` (10 tests failing)
- `tests/unit/test_analytics_fix.py` (2 tests failing)
### Test Update Strategy
**Option 1 - Quick Fix (Recommended for now)**:
Keep `DatabaseService` and have routes import it again temporarily, update tests at leisure.
**Option 2 - Proper Fix**:
Update all tests to use FastAPI dependency overrides pattern:
```python
def test_get_accounts(fastapi_app, api_client, mock_account_repo):
mock_account_repo.get_accounts.return_value = [...]
fastapi_app.dependency_overrides[get_account_repository] = lambda: mock_account_repo
response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
```
## Migration Path Forward
1.**Phase 1**: Refactor production code (DONE)
2. 🔄 **Phase 2**: Update tests to use dependency overrides (TODO)
3. 🔄 **Phase 3**: Remove deprecated `DatabaseService` completely (TODO)
4. 🔄 **Phase 4**: Consider extracting analytics logic to separate service (TODO)
## How to Use the New System
### In API Routes
```python
from leggen.api.dependencies import AccountRepo, BalanceRepo
@router.get("/accounts")
async def get_accounts(
account_repo: AccountRepo, # Injected automatically
balance_repo: BalanceRepo, # Injected automatically
) -> List[AccountDetails]:
accounts = account_repo.get_accounts()
# ...
```
### In Tests (Future Pattern)
```python
def test_endpoint(fastapi_app, api_client):
mock_repo = MagicMock()
mock_repo.get_accounts.return_value = [...]
fastapi_app.dependency_overrides[get_account_repository] = lambda: mock_repo
response = api_client.get("/api/v1/accounts")
# assertions...
```
## Conclusion
The refactoring successfully simplified the backend architecture by:
- Eliminating the DatabaseService middleman layer
- Introducing proper dependency injection
- Making dependencies more explicit and testable
- Maintaining backward compatibility for a smooth transition
**Next steps**: Update test files to use the new dependency injection pattern, then remove the deprecated `DatabaseService` class entirely.

View File

@@ -23,6 +23,7 @@ import {
import { Button } from "./ui/button"; import { Button } from "./ui/button";
import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
import AccountsSkeleton from "./AccountsSkeleton"; import AccountsSkeleton from "./AccountsSkeleton";
import { BlurredValue } from "./ui/blurred-value";
import type { Account, Balance } from "../types/api"; import type { Account, Balance } from "../types/api";
// Helper function to get status indicator color and styles // Helper function to get status indicator color and styles
@@ -158,7 +159,7 @@ export default function AccountsOverview() {
Total Balance Total Balance
</p> </p>
<p className="text-2xl font-bold text-foreground"> <p className="text-2xl font-bold text-foreground">
{formatCurrency(totalBalance)} <BlurredValue>{formatCurrency(totalBalance)}</BlurredValue>
</p> </p>
</div> </div>
<div className="p-3 bg-green-100 dark:bg-green-900/20 rounded-full"> <div className="p-3 bg-green-100 dark:bg-green-900/20 rounded-full">
@@ -369,7 +370,9 @@ export default function AccountsOverview() {
isPositive ? "text-green-600" : "text-red-600" isPositive ? "text-green-600" : "text-red-600"
}`} }`}
> >
{formatCurrency(balance, currency)} <BlurredValue>
{formatCurrency(balance, currency)}
</BlurredValue>
</p> </p>
</div> </div>
</div> </div>

View File

@@ -16,6 +16,7 @@ import { apiClient } from "../lib/api";
import { formatCurrency } from "../lib/utils"; import { formatCurrency } from "../lib/utils";
import { useState } from "react"; import { useState } from "react";
import type { Account } from "../types/api"; import type { Account } from "../types/api";
import { BlurredValue } from "./ui/blurred-value";
import { import {
Sidebar, Sidebar,
SidebarContent, SidebarContent,
@@ -130,7 +131,7 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
<div className="px-3 pb-2"> <div className="px-3 pb-2">
<p className="text-xl font-bold text-foreground"> <p className="text-xl font-bold text-foreground">
{formatCurrency(totalBalance)} <BlurredValue>{formatCurrency(totalBalance)}</BlurredValue>
</p> </p>
<p className="text-sm text-muted-foreground"> <p className="text-sm text-muted-foreground">
{accounts?.length || 0} accounts {accounts?.length || 0} accounts
@@ -163,7 +164,9 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
"Unnamed Account"} "Unnamed Account"}
</p> </p>
<p className="text-xs font-semibold text-foreground"> <p className="text-xs font-semibold text-foreground">
{formatCurrency(primaryBalance, currency)} <BlurredValue>
{formatCurrency(primaryBalance, currency)}
</BlurredValue>
</p> </p>
</div> </div>
</div> </div>

View File

@@ -34,6 +34,7 @@ import { Button } from "./ui/button";
import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
import { Label } from "./ui/label"; import { Label } from "./ui/label";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs";
import { BlurredValue } from "./ui/blurred-value";
import AccountsSkeleton from "./AccountsSkeleton"; import AccountsSkeleton from "./AccountsSkeleton";
import NotificationFiltersDrawer from "./NotificationFiltersDrawer"; import NotificationFiltersDrawer from "./NotificationFiltersDrawer";
import DiscordConfigDrawer from "./DiscordConfigDrawer"; import DiscordConfigDrawer from "./DiscordConfigDrawer";
@@ -491,13 +492,13 @@ export default function Settings() {
) : ( ) : (
<TrendingDown className="h-4 w-4 text-red-500" /> <TrendingDown className="h-4 w-4 text-red-500" />
)} )}
<p <BlurredValue
className={`text-base sm:text-lg font-semibold ${ className={`text-base sm:text-lg font-semibold ${
isPositive ? "text-green-600" : "text-red-600" isPositive ? "text-green-600" : "text-red-600"
}`} }`}
> >
{formatCurrency(balance, currency)} {formatCurrency(balance, currency)}
</p> </BlurredValue>
</div> </div>
</div> </div>
</div> </div>

View File

@@ -3,6 +3,7 @@ import { Activity, Wifi, WifiOff } from "lucide-react";
import { useQuery } from "@tanstack/react-query"; import { useQuery } from "@tanstack/react-query";
import { apiClient } from "../lib/api"; import { apiClient } from "../lib/api";
import { ThemeToggle } from "./ui/theme-toggle"; import { ThemeToggle } from "./ui/theme-toggle";
import { BalanceToggle } from "./ui/balance-toggle";
import { Separator } from "./ui/separator"; import { Separator } from "./ui/separator";
import { SidebarTrigger } from "./ui/sidebar"; import { SidebarTrigger } from "./ui/sidebar";
@@ -77,6 +78,7 @@ export function SiteHeader() {
</> </>
)} )}
</div> </div>
<BalanceToggle />
<ThemeToggle /> <ThemeToggle />
</div> </div>
</div> </div>

View File

@@ -1,4 +1,4 @@
import { useState, useEffect } from "react"; import { useState, useEffect, useMemo } from "react";
import { useQuery } from "@tanstack/react-query"; import { useQuery } from "@tanstack/react-query";
import { import {
useReactTable, useReactTable,
@@ -31,7 +31,8 @@ import { DataTablePagination } from "./ui/data-table-pagination";
import { Card } from "./ui/card"; import { Card } from "./ui/card";
import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
import { Button } from "./ui/button"; import { Button } from "./ui/button";
import type { Account, Transaction, ApiResponse } from "../types/api"; import { BlurredValue } from "./ui/blurred-value";
import type { Account, Transaction, PaginatedResponse } from "../types/api";
export default function TransactionsTable() { export default function TransactionsTable() {
// Filter state consolidated into a single object // Filter state consolidated into a single object
@@ -102,7 +103,7 @@ export default function TransactionsTable() {
isLoading: transactionsLoading, isLoading: transactionsLoading,
error: transactionsError, error: transactionsError,
refetch: refetchTransactions, refetch: refetchTransactions,
} = useQuery<ApiResponse<Transaction[]>>({ } = useQuery<PaginatedResponse<Transaction>>({
queryKey: [ queryKey: [
"transactions", "transactions",
filterState.selectedAccount, filterState.selectedAccount,
@@ -122,10 +123,52 @@ export default function TransactionsTable() {
search: debouncedSearchTerm || undefined, search: debouncedSearchTerm || undefined,
summaryOnly: false, summaryOnly: false,
}), }),
placeholderData: (previousData) => previousData,
}); });
const transactions = transactionsResponse?.data || []; const transactions = useMemo(
const pagination = transactionsResponse?.pagination; () => transactionsResponse?.data || [],
[transactionsResponse],
);
const pagination = useMemo(
() =>
transactionsResponse
? {
page: transactionsResponse.page,
total_pages: transactionsResponse.total_pages,
per_page: transactionsResponse.per_page,
total: transactionsResponse.total,
has_next: transactionsResponse.has_next,
has_prev: transactionsResponse.has_prev,
}
: undefined,
[transactionsResponse],
);
// Calculate stats from current page transactions, memoized for performance
const stats = useMemo(() => {
const totalIncome = transactions
.filter((t: Transaction) => t.transaction_value > 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0);
const totalExpenses = Math.abs(
transactions
.filter((t: Transaction) => t.transaction_value < 0)
.reduce((sum: number, t: Transaction) => sum + t.transaction_value, 0)
);
// Get currency from first transaction, fallback to EUR
const displayCurrency = transactions.length > 0 ? transactions[0].transaction_currency : "EUR";
return {
totalCount: pagination?.total || 0,
pageCount: transactions.length,
totalIncome,
totalExpenses,
netChange: totalIncome - totalExpenses,
displayCurrency,
};
}, [transactions, pagination]);
// Check if search is currently debouncing // Check if search is currently debouncing
const isSearchLoading = filterState.searchTerm !== debouncedSearchTerm; const isSearchLoading = filterState.searchTerm !== debouncedSearchTerm;
@@ -221,11 +264,13 @@ export default function TransactionsTable() {
isPositive ? "text-green-600" : "text-red-600" isPositive ? "text-green-600" : "text-red-600"
}`} }`}
> >
{isPositive ? "+" : ""} <BlurredValue>
{formatCurrency( {isPositive ? "+" : ""}
transaction.transaction_value, {formatCurrency(
transaction.transaction_currency, transaction.transaction_value,
)} transaction.transaction_currency,
)}
</BlurredValue>
</p> </p>
</div> </div>
); );
@@ -350,6 +395,78 @@ export default function TransactionsTable() {
isSearchLoading={isSearchLoading} isSearchLoading={isSearchLoading}
/> />
{/* Transaction Statistics */}
{transactions.length > 0 && (
<div className="grid grid-cols-1 md:grid-cols-4 gap-4">
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Showing
</p>
<p className="text-2xl font-bold text-foreground mt-1">
{stats.pageCount}
</p>
<p className="text-xs text-muted-foreground mt-1">
of {stats.totalCount} total
</p>
</div>
</div>
</Card>
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Income
</p>
<BlurredValue className="text-2xl font-bold text-green-600 mt-1 block">
+{formatCurrency(stats.totalIncome, stats.displayCurrency)}
</BlurredValue>
</div>
<TrendingUp className="h-8 w-8 text-green-600 opacity-50" />
</div>
</Card>
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Expenses
</p>
<BlurredValue className="text-2xl font-bold text-red-600 mt-1 block">
-{formatCurrency(stats.totalExpenses, stats.displayCurrency)}
</BlurredValue>
</div>
<TrendingDown className="h-8 w-8 text-red-600 opacity-50" />
</div>
</Card>
<Card className="p-4">
<div className="flex items-center justify-between">
<div>
<p className="text-xs text-muted-foreground uppercase tracking-wider">
Net Change
</p>
<BlurredValue
className={`text-2xl font-bold mt-1 block ${
stats.netChange >= 0 ? "text-green-600" : "text-red-600"
}`}
>
{stats.netChange >= 0 ? "+" : ""}
{formatCurrency(stats.netChange, stats.displayCurrency)}
</BlurredValue>
</div>
{stats.netChange >= 0 ? (
<TrendingUp className="h-8 w-8 text-green-600 opacity-50" />
) : (
<TrendingDown className="h-8 w-8 text-red-600 opacity-50" />
)}
</div>
</Card>
</div>
)}
{/* Responsive Table/Cards */} {/* Responsive Table/Cards */}
<Card> <Card>
{/* Desktop Table View (hidden on mobile) */} {/* Desktop Table View (hidden on mobile) */}
@@ -525,11 +642,13 @@ export default function TransactionsTable() {
isPositive ? "text-green-600" : "text-red-600" isPositive ? "text-green-600" : "text-red-600"
}`} }`}
> >
{isPositive ? "+" : ""} <BlurredValue>
{formatCurrency( {isPositive ? "+" : ""}
transaction.transaction_value, {formatCurrency(
transaction.transaction_currency, transaction.transaction_value,
)} transaction.transaction_currency,
)}
</BlurredValue>
</p> </p>
<Button <Button
onClick={() => handleViewRaw(transaction)} onClick={() => handleViewRaw(transaction)}

View File

@@ -8,6 +8,8 @@ import {
ResponsiveContainer, ResponsiveContainer,
Legend, Legend,
} from "recharts"; } from "recharts";
import { useBalanceVisibility } from "../../contexts/BalanceVisibilityContext";
import { cn } from "../../lib/utils";
import type { Balance, Account } from "../../types/api"; import type { Balance, Account } from "../../types/api";
interface BalanceChartProps { interface BalanceChartProps {
@@ -42,6 +44,8 @@ export default function BalanceChart({
accounts, accounts,
className, className,
}: BalanceChartProps) { }: BalanceChartProps) {
const { isBalanceVisible } = useBalanceVisibility();
// Create a lookup map for account info // Create a lookup map for account info
const accountMap = accounts.reduce( const accountMap = accounts.reduce(
(map, account) => { (map, account) => {
@@ -149,7 +153,7 @@ export default function BalanceChart({
<h3 className="text-lg font-medium text-foreground mb-4"> <h3 className="text-lg font-medium text-foreground mb-4">
Balance Progress Over Time Balance Progress Over Time
</h3> </h3>
<div className="h-80"> <div className={cn("h-80", !isBalanceVisible && "blur-md select-none")}>
<ResponsiveContainer width="100%" height="100%"> <ResponsiveContainer width="100%" height="100%">
<AreaChart data={finalData}> <AreaChart data={finalData}>
<CartesianGrid strokeDasharray="3 3" /> <CartesianGrid strokeDasharray="3 3" />

View File

@@ -8,6 +8,8 @@ import {
ResponsiveContainer, ResponsiveContainer,
} from "recharts"; } from "recharts";
import { useQuery } from "@tanstack/react-query"; import { useQuery } from "@tanstack/react-query";
import { useBalanceVisibility } from "../../contexts/BalanceVisibilityContext";
import { cn } from "../../lib/utils";
import apiClient from "../../lib/api"; import apiClient from "../../lib/api";
interface MonthlyTrendsProps { interface MonthlyTrendsProps {
@@ -29,6 +31,8 @@ export default function MonthlyTrends({
className, className,
days = 365, days = 365,
}: MonthlyTrendsProps) { }: MonthlyTrendsProps) {
const { isBalanceVisible } = useBalanceVisibility();
// Get pre-calculated monthly stats from the new endpoint // Get pre-calculated monthly stats from the new endpoint
const { data: monthlyData, isLoading } = useQuery({ const { data: monthlyData, isLoading } = useQuery({
queryKey: ["monthly-stats", days], queryKey: ["monthly-stats", days],
@@ -103,7 +107,7 @@ export default function MonthlyTrends({
<h3 className="text-lg font-medium text-foreground mb-4"> <h3 className="text-lg font-medium text-foreground mb-4">
{getTitle(days)} {getTitle(days)}
</h3> </h3>
<div className="h-80"> <div className={cn("h-80", !isBalanceVisible && "blur-md select-none")}>
<ResponsiveContainer width="100%" height="100%"> <ResponsiveContainer width="100%" height="100%">
<BarChart <BarChart
data={displayData} data={displayData}

View File

@@ -1,5 +1,6 @@
import type { LucideIcon } from "lucide-react"; import type { LucideIcon } from "lucide-react";
import { Card, CardContent } from "../ui/card"; import { Card, CardContent } from "../ui/card";
import { BlurredValue } from "../ui/blurred-value";
import { cn } from "../../lib/utils"; import { cn } from "../../lib/utils";
interface StatCardProps { interface StatCardProps {
@@ -13,6 +14,7 @@ interface StatCardProps {
}; };
className?: string; className?: string;
iconColor?: "green" | "blue" | "red" | "purple" | "orange" | "default"; iconColor?: "green" | "blue" | "red" | "purple" | "orange" | "default";
shouldBlur?: boolean;
} }
export default function StatCard({ export default function StatCard({
@@ -23,6 +25,7 @@ export default function StatCard({
trend, trend,
className, className,
iconColor = "default", iconColor = "default",
shouldBlur = false,
}: StatCardProps) { }: StatCardProps) {
return ( return (
<Card className={cn(className)}> <Card className={cn(className)}>
@@ -31,7 +34,9 @@ export default function StatCard({
<div> <div>
<p className="text-sm font-medium text-muted-foreground">{title}</p> <p className="text-sm font-medium text-muted-foreground">{title}</p>
<div className="flex items-baseline"> <div className="flex items-baseline">
<p className="text-2xl font-bold text-foreground">{value}</p> <p className="text-2xl font-bold text-foreground">
{shouldBlur ? <BlurredValue>{value}</BlurredValue> : value}
</p>
{trend && ( {trend && (
<div <div
className={cn( className={cn(

View File

@@ -6,6 +6,7 @@ import {
Tooltip, Tooltip,
Legend, Legend,
} from "recharts"; } from "recharts";
import { BlurredValue } from "../ui/blurred-value";
import type { Account } from "../../types/api"; import type { Account } from "../../types/api";
interface TransactionDistributionProps { interface TransactionDistributionProps {
@@ -85,7 +86,8 @@ export default function TransactionDistribution({
<div className="bg-card p-3 border rounded shadow-lg"> <div className="bg-card p-3 border rounded shadow-lg">
<p className="font-medium text-foreground">{data.name}</p> <p className="font-medium text-foreground">{data.name}</p>
<p className="text-primary"> <p className="text-primary">
Balance: {data.value.toLocaleString()} Balance:{" "}
<BlurredValue>{data.value.toLocaleString()}</BlurredValue>
</p> </p>
<p className="text-muted-foreground">{percentage}% of total</p> <p className="text-muted-foreground">{percentage}% of total</p>
</div> </div>
@@ -138,7 +140,7 @@ export default function TransactionDistribution({
<span className="text-foreground">{item.name}</span> <span className="text-foreground">{item.name}</span>
</div> </div>
<span className="font-medium text-foreground"> <span className="font-medium text-foreground">
{item.value.toLocaleString()} <BlurredValue>{item.value.toLocaleString()}</BlurredValue>
</span> </span>
</div> </div>
))} ))}

View File

@@ -1,3 +1,4 @@
import { useRef, useEffect } from "react";
import { Search } from "lucide-react"; import { Search } from "lucide-react";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
@@ -30,6 +31,21 @@ export function FilterBar({
isSearchLoading = false, isSearchLoading = false,
className, className,
}: FilterBarProps) { }: FilterBarProps) {
const searchInputRef = useRef<HTMLInputElement>(null);
const cursorPositionRef = useRef<number | null>(null);
// Maintain focus and cursor position on search input during re-renders
useEffect(() => {
const currentInput = searchInputRef.current;
if (!currentInput) return;
// Restore focus and cursor position after data fetches complete
if (cursorPositionRef.current !== null && document.activeElement !== currentInput) {
currentInput.focus();
currentInput.setSelectionRange(cursorPositionRef.current, cursorPositionRef.current);
}
}, [isSearchLoading]);
const hasActiveFilters = const hasActiveFilters =
filterState.searchTerm || filterState.searchTerm ||
filterState.selectedAccount || filterState.selectedAccount ||
@@ -61,9 +77,19 @@ export function FilterBar({
<div className="relative w-[200px]"> <div className="relative w-[200px]">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" /> <Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input <Input
ref={searchInputRef}
placeholder="Search transactions..." placeholder="Search transactions..."
value={filterState.searchTerm} value={filterState.searchTerm}
onChange={(e) => onFilterChange("searchTerm", e.target.value)} onChange={(e) => {
cursorPositionRef.current = e.target.selectionStart;
onFilterChange("searchTerm", e.target.value);
}}
onFocus={() => {
cursorPositionRef.current = searchInputRef.current?.selectionStart ?? null;
}}
onBlur={() => {
cursorPositionRef.current = null;
}}
className="pl-9 pr-8 bg-background" className="pl-9 pr-8 bg-background"
/> />
{isSearchLoading && ( {isSearchLoading && (
@@ -99,9 +125,19 @@ export function FilterBar({
<div className="relative"> <div className="relative">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" /> <Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input <Input
ref={searchInputRef}
placeholder="Search..." placeholder="Search..."
value={filterState.searchTerm} value={filterState.searchTerm}
onChange={(e) => onFilterChange("searchTerm", e.target.value)} onChange={(e) => {
cursorPositionRef.current = e.target.selectionStart;
onFilterChange("searchTerm", e.target.value);
}}
onFocus={() => {
cursorPositionRef.current = searchInputRef.current?.selectionStart ?? null;
}}
onBlur={() => {
cursorPositionRef.current = null;
}}
className="pl-9 pr-8 bg-background w-full" className="pl-9 pr-8 bg-background w-full"
/> />
{isSearchLoading && ( {isSearchLoading && (

View File

@@ -0,0 +1,26 @@
import { Eye, EyeOff } from "lucide-react";
import { Button } from "./button";
import { useBalanceVisibility } from "../../contexts/BalanceVisibilityContext";
export function BalanceToggle() {
const { isBalanceVisible, toggleBalanceVisibility } = useBalanceVisibility();
return (
<Button
variant="outline"
size="icon"
onClick={toggleBalanceVisibility}
className="h-8 w-8"
title={isBalanceVisible ? "Hide balances" : "Show balances"}
>
{isBalanceVisible ? (
<Eye className="h-4 w-4" />
) : (
<EyeOff className="h-4 w-4" />
)}
<span className="sr-only">
{isBalanceVisible ? "Hide balances" : "Show balances"}
</span>
</Button>
);
}

View File

@@ -0,0 +1,23 @@
import React from "react";
import { useBalanceVisibility } from "../../contexts/BalanceVisibilityContext";
import { cn } from "../../lib/utils";
interface BlurredValueProps {
children: React.ReactNode;
className?: string;
}
export function BlurredValue({ children, className }: BlurredValueProps) {
const { isBalanceVisible } = useBalanceVisibility();
return (
<span
className={cn(
isBalanceVisible ? "" : "blur-md select-none",
className,
)}
>
{children}
</span>
);
}

View File

@@ -0,0 +1,48 @@
import React, { createContext, useContext, useEffect, useState } from "react";
interface BalanceVisibilityContextType {
isBalanceVisible: boolean;
toggleBalanceVisibility: () => void;
}
const BalanceVisibilityContext = createContext<
BalanceVisibilityContextType | undefined
>(undefined);
export function BalanceVisibilityProvider({
children,
}: {
children: React.ReactNode;
}) {
const [isBalanceVisible, setIsBalanceVisible] = useState<boolean>(() => {
const stored = localStorage.getItem("balanceVisible");
// Default to true (visible) if not set
return stored === null ? true : stored === "true";
});
useEffect(() => {
localStorage.setItem("balanceVisible", String(isBalanceVisible));
}, [isBalanceVisible]);
const toggleBalanceVisibility = () => {
setIsBalanceVisible((prev) => !prev);
};
return (
<BalanceVisibilityContext.Provider
value={{ isBalanceVisible, toggleBalanceVisibility }}
>
{children}
</BalanceVisibilityContext.Provider>
);
}
export function useBalanceVisibility() {
const context = useContext(BalanceVisibilityContext);
if (context === undefined) {
throw new Error(
"useBalanceVisibility must be used within a BalanceVisibilityProvider",
);
}
return context;
}

View File

@@ -288,11 +288,14 @@ export const apiClient = {
return response.data; return response.data;
}, },
testBackupConnection: async (test: BackupTest): Promise<{ connected?: boolean }> => { testBackupConnection: async (
const response = await api.post<{ connected?: boolean }>( test: BackupTest,
"/backup/test", ): Promise<{ connected?: boolean; success?: boolean; message?: string }> => {
test, const response = await api.post<{
); connected?: boolean;
success?: boolean;
message?: string;
}>("/backup/test", test);
return response.data; return response.data;
}, },
@@ -301,11 +304,20 @@ export const apiClient = {
return response.data; return response.data;
}, },
performBackupOperation: async (operation: BackupOperation): Promise<{ operation: string; completed: boolean }> => { performBackupOperation: async (
const response = await api.post<{ operation: string; completed: boolean }>( operation: BackupOperation,
"/backup/operation", ): Promise<{
operation, operation: string;
); completed: boolean;
success?: boolean;
message?: string;
}> => {
const response = await api.post<{
operation: string;
completed: boolean;
success?: boolean;
message?: string;
}>("/backup/operation", operation);
return response.data; return response.data;
}, },
}; };

View File

@@ -3,6 +3,7 @@ import { createRoot } from "react-dom/client";
import { createRouter, RouterProvider } from "@tanstack/react-router"; import { createRouter, RouterProvider } from "@tanstack/react-router";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { ThemeProvider } from "./contexts/ThemeContext"; import { ThemeProvider } from "./contexts/ThemeContext";
import { BalanceVisibilityProvider } from "./contexts/BalanceVisibilityContext";
import "./index.css"; import "./index.css";
import { routeTree } from "./routeTree.gen"; import { routeTree } from "./routeTree.gen";
import { registerSW } from "virtual:pwa-register"; import { registerSW } from "virtual:pwa-register";
@@ -73,7 +74,9 @@ createRoot(document.getElementById("root")!).render(
<StrictMode> <StrictMode>
<QueryClientProvider client={queryClient}> <QueryClientProvider client={queryClient}>
<ThemeProvider> <ThemeProvider>
<RouterProvider router={router} /> <BalanceVisibilityProvider>
<RouterProvider router={router} />
</BalanceVisibilityProvider>
</ThemeProvider> </ThemeProvider>
</QueryClientProvider> </QueryClientProvider>
</StrictMode>, </StrictMode>,

View File

@@ -88,6 +88,7 @@ function AnalyticsDashboard() {
subtitle="Inflows this period" subtitle="Inflows this period"
icon={TrendingUp} icon={TrendingUp}
iconColor="green" iconColor="green"
shouldBlur={true}
/> />
<StatCard <StatCard
title="Total Expenses" title="Total Expenses"
@@ -95,6 +96,7 @@ function AnalyticsDashboard() {
subtitle="Outflows this period" subtitle="Outflows this period"
icon={TrendingDown} icon={TrendingDown}
iconColor="red" iconColor="red"
shouldBlur={true}
/> />
</div> </div>
@@ -106,6 +108,7 @@ function AnalyticsDashboard() {
subtitle="Income minus expenses" subtitle="Income minus expenses"
icon={CreditCard} icon={CreditCard}
iconColor={(stats?.net_change || 0) >= 0 ? "green" : "red"} iconColor={(stats?.net_change || 0) >= 0 ? "green" : "red"}
shouldBlur={true}
/> />
<StatCard <StatCard
title="Average Transaction" title="Average Transaction"
@@ -113,6 +116,7 @@ function AnalyticsDashboard() {
subtitle="Per transaction" subtitle="Per transaction"
icon={Activity} icon={Activity}
iconColor="purple" iconColor="purple"
shouldBlur={true}
/> />
<StatCard <StatCard
title="Active Accounts" title="Active Accounts"

View File

@@ -0,0 +1,75 @@
"""FastAPI dependency injection setup for repositories and services."""
from typing import Annotated
from fastapi import Depends
from leggen.repositories import (
AccountRepository,
BalanceRepository,
MigrationRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AnalyticsProcessor,
BalanceTransformer,
TransactionProcessor,
)
from leggen.utils.config import config
def get_account_repository() -> AccountRepository:
"""Get account repository instance."""
return AccountRepository()
def get_balance_repository() -> BalanceRepository:
"""Get balance repository instance."""
return BalanceRepository()
def get_transaction_repository() -> TransactionRepository:
"""Get transaction repository instance."""
return TransactionRepository()
def get_sync_repository() -> SyncRepository:
"""Get sync repository instance."""
return SyncRepository()
def get_migration_repository() -> MigrationRepository:
"""Get migration repository instance."""
return MigrationRepository()
def get_transaction_processor() -> TransactionProcessor:
"""Get transaction processor instance."""
return TransactionProcessor()
def get_balance_transformer() -> BalanceTransformer:
"""Get balance transformer instance."""
return BalanceTransformer()
def get_analytics_processor() -> AnalyticsProcessor:
"""Get analytics processor instance."""
return AnalyticsProcessor()
def is_sqlite_enabled() -> bool:
"""Check if SQLite is enabled in configuration."""
return config.database_config.get("sqlite", True)
# Type annotations for dependency injection
AccountRepo = Annotated[AccountRepository, Depends(get_account_repository)]
BalanceRepo = Annotated[BalanceRepository, Depends(get_balance_repository)]
TransactionRepo = Annotated[TransactionRepository, Depends(get_transaction_repository)]
SyncRepo = Annotated[SyncRepository, Depends(get_sync_repository)]
MigrationRepo = Annotated[MigrationRepository, Depends(get_migration_repository)]
TransactionProc = Annotated[TransactionProcessor, Depends(get_transaction_processor)]
BalanceTransform = Annotated[BalanceTransformer, Depends(get_balance_transformer)]
AnalyticsProc = Annotated[AnalyticsProcessor, Depends(get_analytics_processor)]

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, List, TypeVar from typing import Generic, List, TypeVar
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -3,6 +3,12 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.dependencies import (
AccountRepo,
AnalyticsProc,
BalanceRepo,
TransactionRepo,
)
from leggen.api.models.accounts import ( from leggen.api.models.accounts import (
AccountBalance, AccountBalance,
AccountDetails, AccountDetails,
@@ -10,28 +16,27 @@ from leggen.api.models.accounts import (
Transaction, Transaction,
TransactionSummary, TransactionSummary,
) )
from leggen.services.database_service import DatabaseService
router = APIRouter() router = APIRouter()
database_service = DatabaseService()
@router.get("/accounts") @router.get("/accounts")
async def get_all_accounts() -> List[AccountDetails]: async def get_all_accounts(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[AccountDetails]:
"""Get all connected accounts from database""" """Get all connected accounts from database"""
try: try:
accounts = [] accounts = []
# Get all account details from database # Get all account details from database
db_accounts = await database_service.get_accounts_from_db() db_accounts = account_repo.get_accounts()
# Process accounts found in database # Process accounts found in database
for db_account in db_accounts: for db_account in db_accounts:
try: try:
# Get latest balances from database for this account # Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db( balances_data = balance_repo.get_balances(db_account["id"])
db_account["id"]
)
# Process balances # Process balances
balances = [] balances = []
@@ -77,11 +82,15 @@ async def get_all_accounts() -> List[AccountDetails]:
@router.get("/accounts/{account_id}") @router.get("/accounts/{account_id}")
async def get_account_details(account_id: str) -> AccountDetails: async def get_account_details(
account_id: str,
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> AccountDetails:
"""Get details for a specific account from database""" """Get details for a specific account from database"""
try: try:
# Get account details from database # Get account details from database
db_account = await database_service.get_account_details_from_db(account_id) db_account = account_repo.get_account(account_id)
if not db_account: if not db_account:
raise HTTPException( raise HTTPException(
@@ -89,7 +98,7 @@ async def get_account_details(account_id: str) -> AccountDetails:
) )
# Get latest balances from database for this account # Get latest balances from database for this account
balances_data = await database_service.get_balances_from_db(account_id) balances_data = balance_repo.get_balances(account_id)
# Process balances # Process balances
balances = [] balances = []
@@ -129,11 +138,14 @@ async def get_account_details(account_id: str) -> AccountDetails:
@router.get("/accounts/{account_id}/balances") @router.get("/accounts/{account_id}/balances")
async def get_account_balances(account_id: str) -> List[AccountBalance]: async def get_account_balances(
account_id: str,
balance_repo: BalanceRepo,
) -> List[AccountBalance]:
"""Get balances for a specific account from database""" """Get balances for a specific account from database"""
try: try:
# Get balances from database instead of GoCardless API # Get balances from database instead of GoCardless API
db_balances = await database_service.get_balances_from_db(account_id=account_id) db_balances = balance_repo.get_balances(account_id=account_id)
balances = [] balances = []
for balance in db_balances: for balance in db_balances:
@@ -158,19 +170,20 @@ async def get_account_balances(account_id: str) -> List[AccountBalance]:
@router.get("/balances") @router.get("/balances")
async def get_all_balances() -> List[dict]: async def get_all_balances(
account_repo: AccountRepo,
balance_repo: BalanceRepo,
) -> List[dict]:
"""Get all balances from all accounts in database""" """Get all balances from all accounts in database"""
try: try:
# Get all accounts first to iterate through them # Get all accounts first to iterate through them
db_accounts = await database_service.get_accounts_from_db() db_accounts = account_repo.get_accounts()
all_balances = [] all_balances = []
for db_account in db_accounts: for db_account in db_accounts:
try: try:
# Get balances for this account # Get balances for this account
db_balances = await database_service.get_balances_from_db( db_balances = balance_repo.get_balances(account_id=db_account["id"])
account_id=db_account["id"]
)
# Process balances and add account info # Process balances and add account info
for balance in db_balances: for balance in db_balances:
@@ -205,6 +218,7 @@ async def get_all_balances() -> List[dict]:
@router.get("/balances/history") @router.get("/balances/history")
async def get_historical_balances( async def get_historical_balances(
analytics_proc: AnalyticsProc,
days: Optional[int] = Query( days: Optional[int] = Query(
default=365, le=1095, ge=1, description="Number of days of history to retrieve" default=365, le=1095, ge=1, description="Number of days of history to retrieve"
), ),
@@ -214,9 +228,12 @@ async def get_historical_balances(
) -> List[dict]: ) -> List[dict]:
"""Get historical balance progression calculated from transaction history""" """Get historical balance progression calculated from transaction history"""
try: try:
from leggen.utils.paths import path_manager
# Get historical balances from database # Get historical balances from database
historical_balances = await database_service.get_historical_balances_from_db( db_path = path_manager.get_database_path()
account_id=account_id, days=days or 365 historical_balances = analytics_proc.calculate_historical_balances(
db_path, account_id=account_id, days=days or 365
) )
return historical_balances return historical_balances
@@ -231,6 +248,7 @@ async def get_historical_balances(
@router.get("/accounts/{account_id}/transactions") @router.get("/accounts/{account_id}/transactions")
async def get_account_transactions( async def get_account_transactions(
account_id: str, account_id: str,
transaction_repo: TransactionRepo,
limit: Optional[int] = Query(default=100, le=500), limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0), offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query( summary_only: bool = Query(
@@ -240,15 +258,10 @@ async def get_account_transactions(
"""Get transactions for a specific account from database""" """Get transactions for a specific account from database"""
try: try:
# Get transactions from database instead of GoCardless API # Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db( db_transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
limit=limit, limit=limit,
offset=offset, offset=offset or 0,
)
# Get total count for pagination info
total_transactions = await database_service.get_transaction_count_from_db(
account_id=account_id,
) )
data: Union[List[TransactionSummary], List[Transaction]] data: Union[List[TransactionSummary], List[Transaction]]
@@ -300,12 +313,14 @@ async def get_account_transactions(
@router.put("/accounts/{account_id}") @router.put("/accounts/{account_id}")
async def update_account_details( async def update_account_details(
account_id: str, update_data: AccountUpdate account_id: str,
update_data: AccountUpdate,
account_repo: AccountRepo,
) -> dict: ) -> dict:
"""Update account details (currently only display_name)""" """Update account details (currently only display_name)"""
try: try:
# Get current account details # Get current account details
current_account = await database_service.get_account_details_from_db(account_id) current_account = account_repo.get_account(account_id)
if not current_account: if not current_account:
raise HTTPException( raise HTTPException(
@@ -318,7 +333,7 @@ async def update_account_details(
updated_account_data["display_name"] = update_data.display_name updated_account_data["display_name"] = update_data.display_name
# Persist updated account details # Persist updated account details
await database_service.persist_account_details(updated_account_data) account_repo.persist(updated_account_data)
return {"id": account_id, "display_name": update_data.display_name} return {"id": account_id, "display_name": update_data.display_name}

View File

@@ -129,9 +129,7 @@ async def test_backup_connection(test_request: BackupTest) -> dict:
success = await backup_service.test_connection(s3_config) success = await backup_service.test_connection(s3_config)
if not success: if not success:
raise HTTPException( raise HTTPException(status_code=400, detail="S3 connection test failed")
status_code=400, detail="S3 connection test failed"
)
return {"connected": True} return {"connected": True}
@@ -193,9 +191,7 @@ async def backup_operation(operation_request: BackupOperation) -> dict:
success = await backup_service.backup_database(database_path) success = await backup_service.backup_database(database_path)
if not success: if not success:
raise HTTPException( raise HTTPException(status_code=500, detail="Database backup failed")
status_code=500, detail="Database backup failed"
)
return {"operation": "backup", "completed": True} return {"operation": "backup", "completed": True}
@@ -213,9 +209,7 @@ async def backup_operation(operation_request: BackupOperation) -> dict:
) )
if not success: if not success:
raise HTTPException( raise HTTPException(status_code=500, detail="Database restore failed")
status_code=500, detail="Database restore failed"
)
return {"operation": "restore", "completed": True} return {"operation": "restore", "completed": True}
else: else:

View File

@@ -3,7 +3,7 @@ from typing import Optional
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from loguru import logger from loguru import logger
from leggen.api.models.sync import SchedulerConfig, SyncRequest from leggen.api.models.sync import SchedulerConfig, SyncRequest, SyncResult, SyncStatus
from leggen.background.scheduler import scheduler from leggen.background.scheduler import scheduler
from leggen.services.sync_service import SyncService from leggen.services.sync_service import SyncService
from leggen.utils.config import config from leggen.utils.config import config
@@ -13,7 +13,7 @@ sync_service = SyncService()
@router.get("/sync/status") @router.get("/sync/status")
async def get_sync_status() -> dict: async def get_sync_status() -> SyncStatus:
"""Get current sync status""" """Get current sync status"""
try: try:
status = await sync_service.get_sync_status() status = await sync_service.get_sync_status()
@@ -78,7 +78,7 @@ async def trigger_sync(
@router.post("/sync/now") @router.post("/sync/now")
async def sync_now(sync_request: Optional[SyncRequest] = None) -> dict: async def sync_now(sync_request: Optional[SyncRequest] = None) -> SyncResult:
"""Run sync synchronously and return results (slower, for testing)""" """Run sync synchronously and return results (slower, for testing)"""
try: try:
if sync_request and sync_request.account_ids: if sync_request and sync_request.account_ids:
@@ -198,9 +198,10 @@ async def stop_scheduler() -> dict:
async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict: async def get_sync_operations(limit: int = 50, offset: int = 0) -> dict:
"""Get sync operations history""" """Get sync operations history"""
try: try:
operations = await sync_service.database.get_sync_operations( from leggen.repositories import SyncRepository
limit=limit, offset=offset
) sync_repo = SyncRepository()
operations = sync_repo.get_operations(limit=limit, offset=offset)
return {"operations": operations, "count": len(operations)} return {"operations": operations, "count": len(operations)}

View File

@@ -4,16 +4,16 @@ from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.dependencies import AnalyticsProc, TransactionRepo
from leggen.api.models.accounts import Transaction, TransactionSummary from leggen.api.models.accounts import Transaction, TransactionSummary
from leggen.api.models.common import PaginatedResponse from leggen.api.models.common import PaginatedResponse
from leggen.services.database_service import DatabaseService
router = APIRouter() router = APIRouter()
database_service = DatabaseService()
@router.get("/transactions") @router.get("/transactions")
async def get_all_transactions( async def get_all_transactions(
transaction_repo: TransactionRepo,
page: int = Query(default=1, ge=1, description="Page number (1-based)"), page: int = Query(default=1, ge=1, description="Page number (1-based)"),
per_page: int = Query(default=50, le=500, description="Items per page"), per_page: int = Query(default=50, le=500, description="Items per page"),
summary_only: bool = Query( summary_only: bool = Query(
@@ -43,7 +43,7 @@ async def get_all_transactions(
limit = per_page limit = per_page
# Get transactions from database instead of GoCardless API # Get transactions from database instead of GoCardless API
db_transactions = await database_service.get_transactions_from_db( db_transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
limit=limit, limit=limit,
offset=offset, offset=offset,
@@ -55,7 +55,7 @@ async def get_all_transactions(
) )
# Get total count for pagination info (respecting the same filters) # Get total count for pagination info (respecting the same filters)
total_transactions = await database_service.get_transaction_count_from_db( total_transactions = transaction_repo.get_count(
account_id=account_id, account_id=account_id,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -64,11 +64,9 @@ async def get_all_transactions(
search=search, search=search,
) )
data: Union[List[TransactionSummary], List[Transaction]]
if summary_only: if summary_only:
# Return simplified transaction summaries # Return simplified transaction summaries
data = [ data: list[TransactionSummary | Transaction] = [
TransactionSummary( TransactionSummary(
transaction_id=txn["transactionId"], # NEW: stable bank-provided ID transaction_id=txn["transactionId"], # NEW: stable bank-provided ID
internal_transaction_id=txn.get("internalTransactionId"), internal_transaction_id=txn.get("internalTransactionId"),
@@ -121,6 +119,7 @@ async def get_all_transactions(
@router.get("/transactions/stats") @router.get("/transactions/stats")
async def get_transaction_stats( async def get_transaction_stats(
transaction_repo: TransactionRepo,
days: int = Query(default=30, description="Number of days to include in stats"), days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> dict: ) -> dict:
@@ -135,7 +134,7 @@ async def get_transaction_stats(
date_to = end_date.isoformat() date_to = end_date.isoformat()
# Get transactions from database # Get transactions from database
recent_transactions = await database_service.get_transactions_from_db( recent_transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -200,6 +199,7 @@ async def get_transaction_stats(
@router.get("/transactions/analytics") @router.get("/transactions/analytics")
async def get_transactions_for_analytics( async def get_transactions_for_analytics(
transaction_repo: TransactionRepo,
days: int = Query(default=365, description="Number of days to include"), days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]: ) -> List[dict]:
@@ -214,7 +214,7 @@ async def get_transactions_for_analytics(
date_to = end_date.isoformat() date_to = end_date.isoformat()
# Get ALL transactions from database (no limit for analytics) # Get ALL transactions from database (no limit for analytics)
transactions = await database_service.get_transactions_from_db( transactions = transaction_repo.get_transactions(
account_id=account_id, account_id=account_id,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -246,11 +246,14 @@ async def get_transactions_for_analytics(
@router.get("/transactions/monthly-stats") @router.get("/transactions/monthly-stats")
async def get_monthly_transaction_stats( async def get_monthly_transaction_stats(
analytics_proc: AnalyticsProc,
days: int = Query(default=365, description="Number of days to include"), days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"), account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> List[dict]: ) -> List[dict]:
"""Get monthly transaction statistics aggregated by the database""" """Get monthly transaction statistics aggregated by the database"""
try: try:
from leggen.utils.paths import path_manager
# Date range for monthly stats # Date range for monthly stats
end_date = datetime.now() end_date = datetime.now()
start_date = end_date - timedelta(days=days) start_date = end_date - timedelta(days=days)
@@ -260,10 +263,9 @@ async def get_monthly_transaction_stats(
date_to = end_date.isoformat() date_to = end_date.isoformat()
# Get monthly aggregated stats from database # Get monthly aggregated stats from database
monthly_stats = await database_service.get_monthly_transaction_stats_from_db( db_path = path_manager.get_database_path()
account_id=account_id, monthly_stats = analytics_proc.calculate_monthly_stats(
date_from=date_from, db_path, account_id=account_id, date_from=date_from, date_to=date_to
date_to=date_to,
) )
return monthly_stats return monthly_stats

View File

@@ -5,11 +5,19 @@ import random
import sqlite3 import sqlite3
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, TypedDict
import click import click
class TransactionType(TypedDict):
"""Type definition for transaction type configuration."""
description: str
amount_range: tuple[float, float]
frequency: float
class SampleDataGenerator: class SampleDataGenerator:
"""Generates realistic sample data for testing Leggen.""" """Generates realistic sample data for testing Leggen."""
@@ -42,7 +50,7 @@ class SampleDataGenerator:
}, },
] ]
self.transaction_types = [ self.transaction_types: list[TransactionType] = [
{ {
"description": "Grocery Store", "description": "Grocery Store",
"amount_range": (-150, -20), "amount_range": (-150, -20),
@@ -227,6 +235,8 @@ class SampleDataGenerator:
)[0] )[0]
# Generate transaction amount # Generate transaction amount
min_amount: float
max_amount: float
min_amount, max_amount = transaction_type["amount_range"] min_amount, max_amount = transaction_type["amount_range"]
amount = round(random.uniform(min_amount, max_amount), 2) amount = round(random.uniform(min_amount, max_amount), 2)
@@ -245,7 +255,7 @@ class SampleDataGenerator:
internal_transaction_id = f"int-txn-{random.randint(100000, 999999)}" internal_transaction_id = f"int-txn-{random.randint(100000, 999999)}"
# Create realistic descriptions # Create realistic descriptions
descriptions = { descriptions: dict[str, list[str]] = {
"Grocery Store": [ "Grocery Store": [
"TESCO", "TESCO",
"SAINSBURY'S", "SAINSBURY'S",
@@ -273,7 +283,7 @@ class SampleDataGenerator:
"Transfer to Savings": ["SAVINGS TRANSFER", "INVESTMENT TRANSFER"], "Transfer to Savings": ["SAVINGS TRANSFER", "INVESTMENT TRANSFER"],
} }
specific_descriptions = descriptions.get( specific_descriptions: list[str] = descriptions.get(
transaction_type["description"], [transaction_type["description"]] transaction_type["description"], [transaction_type["description"]]
) )
description = random.choice(specific_descriptions) description = random.choice(specific_descriptions)

View File

@@ -28,10 +28,10 @@ async def lifespan(app: FastAPI):
# Run database migrations # Run database migrations
try: try:
from leggen.services.database_service import DatabaseService from leggen.api.dependencies import get_migration_repository
db_service = DatabaseService() migrations = get_migration_repository()
await db_service.run_migrations_if_needed() await migrations.run_all_migrations()
logger.info("Database migrations completed") logger.info("Database migrations completed")
except Exception as e: except Exception as e:
logger.error(f"Database migration failed: {e}") logger.error(f"Database migration failed: {e}")

View File

@@ -61,17 +61,13 @@ def send_sync_failure_notification(ctx: click.Context, notification: dict):
info("Sending sync failure notification to Discord") info("Sending sync failure notification to Discord")
webhook = DiscordWebhook(url=ctx.obj["notifications"]["discord"]["webhook"]) webhook = DiscordWebhook(url=ctx.obj["notifications"]["discord"]["webhook"])
# Determine color and title based on failure type color = "ffaa00" # Orange for sync failure
if notification.get("type") == "sync_final_failure": title = "⚠️ Sync Failure"
color = "ff0000" # Red for final failure
title = "🚨 Sync Final Failure" # Build description with account info if available
description = ( description = "Account sync failed"
f"Sync failed permanently after {notification['retry_count']} attempts" if notification.get("account_id"):
) description = f"Account {notification['account_id']} sync failed"
else:
color = "ffaa00" # Orange for retry
title = "⚠️ Sync Failure"
description = f"Sync failed (attempt {notification['retry_count']}/{notification['max_retries']}). Will retry automatically..."
embed = DiscordEmbed( embed = DiscordEmbed(
title=title, title=title,

View File

@@ -87,19 +87,14 @@ def send_sync_failure_notification(ctx: click.Context, notification: dict):
bot_url = f"https://api.telegram.org/bot{token}/sendMessage" bot_url = f"https://api.telegram.org/bot{token}/sendMessage"
info("Sending sync failure notification to Telegram") info("Sending sync failure notification to Telegram")
message = "*🚨 [Leggen](https://github.com/elisiariocouto/leggen)*\n" message = "*⚠️ [Leggen](https://github.com/elisiariocouto/leggen)*\n"
message += "*Sync Failed*\n\n" message += "*Sync Failed*\n\n"
message += escape_markdown(f"Error: {notification['error']}\n")
if notification.get("type") == "sync_final_failure": # Add account info if available
message += escape_markdown( if notification.get("account_id"):
f"❌ Final failure after {notification['retry_count']} attempts\n" message += escape_markdown(f"Account: {notification['account_id']}\n")
)
else: message += escape_markdown(f"Error: {notification['error']}\n")
message += escape_markdown(
f"🔄 Attempt {notification['retry_count']}/{notification['max_retries']}\n"
)
message += escape_markdown("Will retry automatically...\n")
res = requests.post( res = requests.post(
bot_url, bot_url,

View File

@@ -0,0 +1,13 @@
from leggen.repositories.account_repository import AccountRepository
from leggen.repositories.balance_repository import BalanceRepository
from leggen.repositories.migration_repository import MigrationRepository
from leggen.repositories.sync_repository import SyncRepository
from leggen.repositories.transaction_repository import TransactionRepository
__all__ = [
"AccountRepository",
"BalanceRepository",
"MigrationRepository",
"SyncRepository",
"TransactionRepository",
]

View File

@@ -0,0 +1,128 @@
from typing import Any, Dict, List, Optional
from leggen.repositories.base_repository import BaseRepository
class AccountRepository(BaseRepository):
"""Repository for account data operations"""
def create_table(self):
"""Create accounts table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME,
display_name TEXT,
logo TEXT
)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
conn.commit()
def persist(self, account_data: Dict[str, Any]) -> Dict[str, Any]:
"""Persist account details to database"""
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
# Check if account exists and preserve display_name
cursor.execute(
"SELECT display_name FROM accounts WHERE id = ?", (account_data["id"],)
)
existing_row = cursor.fetchone()
existing_display_name = existing_row[0] if existing_row else None
# Use existing display_name if not provided in account_data
display_name = account_data.get("display_name", existing_display_name)
cursor.execute(
"""INSERT OR REPLACE INTO accounts (
id,
institution_id,
status,
iban,
name,
currency,
created,
last_accessed,
last_updated,
display_name,
logo
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
account_data["id"],
account_data["institution_id"],
account_data["status"],
account_data.get("iban"),
account_data.get("name"),
account_data.get("currency"),
account_data["created"],
account_data.get("last_accessed"),
account_data.get("last_updated", account_data["created"]),
display_name,
account_data.get("logo"),
),
)
conn.commit()
return account_data
def get_accounts(
self, account_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Get account details from database"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM accounts"
params = []
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query += " ORDER BY created DESC"
cursor.execute(query, params)
rows = cursor.fetchall()
return [dict(row) for row in rows]
def get_account(self, account_id: str) -> Optional[Dict[str, Any]]:
"""Get specific account details from database"""
if not self._db_exists():
return None
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
if row:
return dict(row)
return None

View File

@@ -0,0 +1,107 @@
import sqlite3
from typing import Any, Dict, List, Optional
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
class BalanceRepository(BaseRepository):
"""Repository for balance data operations"""
def create_table(self):
"""Create balances table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS balances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account_id TEXT,
bank TEXT,
status TEXT,
iban TEXT,
amount REAL,
currency TEXT,
type TEXT,
timestamp DATETIME
)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
ON balances(account_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
ON balances(timestamp)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
ON balances(account_id, type, timestamp)"""
)
conn.commit()
def persist(self, account_id: str, balance_rows: List[tuple]) -> None:
"""Persist balance rows to database"""
try:
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
for row in balance_rows:
try:
cursor.execute(
"""INSERT INTO balances (
account_id,
bank,
status,
iban,
amount,
currency,
type,
timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
row,
)
except sqlite3.IntegrityError:
logger.warning(f"Skipped duplicate balance for {account_id}")
conn.commit()
logger.info(f"Persisted balances for account {account_id}")
except Exception as e:
logger.error(f"Failed to persist balances: {e}")
raise
def get_balances(self, account_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get latest balances from database"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
query += " ORDER BY b1.account_id, b1.type"
cursor.execute(query, params)
rows = cursor.fetchall()
return [dict(row) for row in rows]

View File

@@ -0,0 +1,28 @@
import sqlite3
from contextlib import contextmanager
from leggen.utils.paths import path_manager
class BaseRepository:
"""Base repository with shared database connection logic"""
@contextmanager
def _get_db_connection(self, row_factory: bool = False):
"""Context manager for database connections with proper cleanup"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
if row_factory:
conn.row_factory = sqlite3.Row
try:
yield conn
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
def _db_exists(self) -> bool:
"""Check if database file exists"""
db_path = path_manager.get_database_path()
return db_path.exists()

View File

@@ -0,0 +1,626 @@
import sqlite3
import uuid
from datetime import datetime
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
from leggen.utils.paths import path_manager
class MigrationRepository(BaseRepository):
"""Repository for database migrations"""
async def run_all_migrations(self):
"""Run all necessary database migrations"""
await self.migrate_balance_timestamps_if_needed()
await self.migrate_null_transaction_ids_if_needed()
await self.migrate_to_composite_key_if_needed()
await self.migrate_add_display_name_if_needed()
await self.migrate_add_sync_operations_if_needed()
await self.migrate_add_logo_if_needed()
# Balance timestamp migration methods
async def migrate_balance_timestamps_if_needed(self):
"""Check and migrate balance timestamps if needed"""
try:
if await self._check_balance_timestamp_migration_needed():
logger.info("Balance timestamp migration needed, starting...")
await self._migrate_balance_timestamps()
logger.info("Balance timestamp migration completed")
else:
logger.info("Balance timestamps are already consistent")
except Exception as e:
logger.error(f"Balance timestamp migration failed: {e}")
raise
async def _check_balance_timestamp_migration_needed(self) -> bool:
"""Check if balance timestamps need migration"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT typeof(timestamp) as type, COUNT(*) as count
FROM balances
GROUP BY typeof(timestamp)
""")
types = cursor.fetchall()
conn.close()
type_names = [row[0] for row in types]
return "real" in type_names and "text" in type_names
except Exception as e:
logger.error(f"Failed to check migration status: {e}")
return False
async def _migrate_balance_timestamps(self):
"""Convert all Unix timestamps to datetime strings"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT id, timestamp
FROM balances
WHERE typeof(timestamp) = 'real'
ORDER BY id
""")
unix_records = cursor.fetchall()
total_records = len(unix_records)
if total_records == 0:
logger.info("No Unix timestamps found to migrate")
conn.close()
return
logger.info(
f"Migrating {total_records} balance records from Unix to datetime format"
)
batch_size = 100
migrated_count = 0
for i in range(0, total_records, batch_size):
batch = unix_records[i : i + batch_size]
for record_id, unix_timestamp in batch:
try:
dt_string = self._unix_to_datetime_string(float(unix_timestamp))
cursor.execute(
"""
UPDATE balances
SET timestamp = ?
WHERE id = ?
""",
(dt_string, record_id),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(
f"Migrated {migrated_count}/{total_records} balance records"
)
except Exception as e:
logger.error(f"Failed to migrate record {record_id}: {e}")
continue
conn.commit()
conn.close()
logger.info(f"Successfully migrated {migrated_count} balance records")
except Exception as e:
logger.error(f"Balance timestamp migration failed: {e}")
raise
def _unix_to_datetime_string(self, unix_timestamp: float) -> str:
"""Convert Unix timestamp to datetime string"""
dt = datetime.fromtimestamp(unix_timestamp)
return dt.isoformat()
# Null transaction IDs migration methods
async def migrate_null_transaction_ids_if_needed(self):
"""Check and migrate null transaction IDs if needed"""
try:
if await self._check_null_transaction_ids_migration_needed():
logger.info("Null transaction IDs migration needed, starting...")
await self._migrate_null_transaction_ids()
logger.info("Null transaction IDs migration completed")
else:
logger.info("No null transaction IDs found to migrate")
except Exception as e:
logger.error(f"Null transaction IDs migration failed: {e}")
raise
async def _check_null_transaction_ids_migration_needed(self) -> bool:
"""Check if null transaction IDs need migration"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*)
FROM transactions
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
""")
count = cursor.fetchone()[0]
conn.close()
return count > 0
except Exception as e:
logger.error(f"Failed to check null transaction IDs migration status: {e}")
return False
async def _migrate_null_transaction_ids(self):
"""Populate null internalTransactionId fields using transactionId from raw data"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute("""
SELECT rowid, json_extract(rawTransaction, '$.transactionId') as transactionId
FROM transactions
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
ORDER BY rowid
""")
null_records = cursor.fetchall()
total_records = len(null_records)
if total_records == 0:
logger.info("No null transaction IDs found to migrate")
conn.close()
return
logger.info(
f"Migrating {total_records} transaction records with null internalTransactionId"
)
batch_size = 100
migrated_count = 0
for i in range(0, total_records, batch_size):
batch = null_records[i : i + batch_size]
for rowid, transaction_id in batch:
try:
cursor.execute(
"SELECT COUNT(*) FROM transactions WHERE internalTransactionId = ?",
(str(transaction_id),),
)
existing_count = cursor.fetchone()[0]
if existing_count > 0:
unique_id = f"{str(transaction_id)}_{uuid.uuid4().hex[:8]}"
logger.debug(
f"Generated unique ID for duplicate transactionId: {unique_id}"
)
else:
unique_id = str(transaction_id)
cursor.execute(
"""
UPDATE transactions
SET internalTransactionId = ?
WHERE rowid = ?
""",
(unique_id, rowid),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(
f"Migrated {migrated_count}/{total_records} transaction records"
)
except Exception as e:
logger.error(f"Failed to migrate record {rowid}: {e}")
continue
conn.commit()
conn.close()
logger.info(f"Successfully migrated {migrated_count} transaction records")
except Exception as e:
logger.error(f"Null transaction IDs migration failed: {e}")
raise
# Composite key migration methods
async def migrate_to_composite_key_if_needed(self):
"""Check and migrate to composite primary key if needed"""
try:
if await self._check_composite_key_migration_needed():
logger.info("Composite key migration needed, starting...")
await self._migrate_to_composite_key()
logger.info("Composite key migration completed")
else:
logger.info("Composite key migration not needed")
except Exception as e:
logger.error(f"Composite key migration failed: {e}")
raise
async def _check_composite_key_migration_needed(self) -> bool:
"""Check if composite key migration is needed"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='transactions'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(transactions)")
columns = cursor.fetchall()
internal_transaction_id_is_pk = any(
col[1] == "internalTransactionId" and col[5] == 1 for col in columns
)
has_composite_key = any(
col[1] in ["accountId", "transactionId"] and col[5] == 1
for col in columns
)
conn.close()
return internal_transaction_id_is_pk or not has_composite_key
except Exception as e:
logger.error(f"Failed to check composite key migration status: {e}")
return False
async def _migrate_to_composite_key(self):
"""Migrate transactions table to use composite primary key (accountId, transactionId)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Starting composite key migration...")
logger.info("Creating temporary table with composite primary key...")
cursor.execute("DROP TABLE IF EXISTS transactions_temp")
cursor.execute("""
CREATE TABLE transactions_temp (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)
""")
logger.info("Inserting deduplicated data...")
cursor.execute("""
INSERT INTO transactions_temp
SELECT
accountId,
json_extract(rawTransaction, '$.transactionId') as transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
FROM (
SELECT *,
ROW_NUMBER() OVER (
PARTITION BY accountId, json_extract(rawTransaction, '$.transactionId')
ORDER BY transactionDate DESC
) as rn
FROM transactions
WHERE json_extract(rawTransaction, '$.transactionId') IS NOT NULL
)
WHERE rn = 1
""")
rows_migrated = cursor.rowcount
logger.info(f"Migrated {rows_migrated} unique transactions")
logger.info("Replacing old table...")
cursor.execute("DROP TABLE transactions")
cursor.execute("ALTER TABLE transactions_temp RENAME TO transactions")
logger.info("Recreating indexes...")
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
conn.commit()
conn.close()
logger.info("Composite key migration completed successfully")
except Exception as e:
logger.error(f"Composite key migration failed: {e}")
raise
# Display name migration methods
async def migrate_add_display_name_if_needed(self):
"""Check and add display_name column if needed"""
try:
if await self._check_display_name_migration_needed():
logger.info("Display name column migration needed, starting...")
await self._migrate_add_display_name()
logger.info("Display name column migration completed")
else:
logger.info("Display name column already exists")
except Exception as e:
logger.error(f"Display name column migration failed: {e}")
raise
async def _check_display_name_migration_needed(self) -> bool:
"""Check if display_name column needs to be added"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(accounts)")
columns = cursor.fetchall()
has_display_name = any(col[1] == "display_name" for col in columns)
conn.close()
return not has_display_name
except Exception as e:
logger.error(f"Failed to check display_name migration status: {e}")
return False
async def _migrate_add_display_name(self):
"""Add display_name column to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Adding display_name column to accounts table...")
cursor.execute("""
ALTER TABLE accounts
ADD COLUMN display_name TEXT
""")
conn.commit()
conn.close()
logger.info("Display name column migration completed successfully")
except Exception as e:
logger.error(f"Display name column migration failed: {e}")
raise
# Sync operations migration methods
async def migrate_add_sync_operations_if_needed(self):
"""Check and add sync_operations table if needed"""
try:
if await self._check_sync_operations_migration_needed():
logger.info("Sync operations table migration needed, starting...")
await self._migrate_add_sync_operations()
logger.info("Sync operations table migration completed")
else:
logger.info("Sync operations table already exists")
except Exception as e:
logger.error(f"Sync operations table migration failed: {e}")
raise
async def _check_sync_operations_migration_needed(self) -> bool:
"""Check if sync_operations table needs to be created"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='sync_operations'"
)
table_exists = cursor.fetchone() is not None
conn.close()
return not table_exists
except Exception as e:
logger.error(f"Failed to check sync_operations migration status: {e}")
return False
async def _migrate_add_sync_operations(self):
"""Add sync_operations table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Creating sync_operations table...")
cursor.execute("""
CREATE TABLE sync_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
started_at DATETIME NOT NULL,
completed_at DATETIME,
success BOOLEAN,
accounts_processed INTEGER DEFAULT 0,
transactions_added INTEGER DEFAULT 0,
transactions_updated INTEGER DEFAULT 0,
balances_updated INTEGER DEFAULT 0,
duration_seconds REAL,
errors TEXT,
logs TEXT,
trigger_type TEXT DEFAULT 'manual'
)
""")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_started_at ON sync_operations(started_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_success ON sync_operations(success)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_trigger_type ON sync_operations(trigger_type)"
)
conn.commit()
conn.close()
logger.info("Sync operations table migration completed successfully")
except Exception as e:
logger.error(f"Sync operations table migration failed: {e}")
raise
# Logo migration methods
async def migrate_add_logo_if_needed(self):
"""Check and add logo column to accounts table if needed"""
try:
if await self._check_logo_migration_needed():
logger.info("Logo column migration needed, starting...")
await self._migrate_add_logo()
logger.info("Logo column migration completed")
else:
logger.info("Logo column already exists")
except Exception as e:
logger.error(f"Logo column migration failed: {e}")
raise
async def _check_logo_migration_needed(self) -> bool:
"""Check if logo column needs to be added to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return False
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
if not cursor.fetchone():
conn.close()
return False
cursor.execute("PRAGMA table_info(accounts)")
columns = cursor.fetchall()
has_logo = any(col[1] == "logo" for col in columns)
conn.close()
return not has_logo
except Exception as e:
logger.error(f"Failed to check logo migration status: {e}")
return False
async def _migrate_add_logo(self):
"""Add logo column to accounts table"""
db_path = path_manager.get_database_path()
if not db_path.exists():
logger.warning("Database file not found, skipping migration")
return
try:
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
logger.info("Adding logo column to accounts table...")
cursor.execute("""
ALTER TABLE accounts
ADD COLUMN logo TEXT
""")
conn.commit()
conn.close()
logger.info("Logo column migration completed successfully")
except Exception as e:
logger.error(f"Logo column migration failed: {e}")
raise

View File

@@ -0,0 +1,132 @@
import json
import sqlite3
from typing import Any, Dict, List
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
from leggen.utils.paths import path_manager
class SyncRepository(BaseRepository):
"""Repository for sync operation data"""
def create_table(self):
"""Create sync_operations table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
started_at DATETIME NOT NULL,
completed_at DATETIME,
success BOOLEAN,
accounts_processed INTEGER DEFAULT 0,
transactions_added INTEGER DEFAULT 0,
transactions_updated INTEGER DEFAULT 0,
balances_updated INTEGER DEFAULT 0,
duration_seconds REAL,
errors TEXT,
logs TEXT,
trigger_type TEXT DEFAULT 'manual'
)
""")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_started_at ON sync_operations(started_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_success ON sync_operations(success)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_sync_operations_trigger_type ON sync_operations(trigger_type)"
)
conn.commit()
def persist(self, sync_operation: Dict[str, Any]) -> int:
"""Persist sync operation to database and return the ID"""
try:
self.create_table()
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"""INSERT INTO sync_operations (
started_at, completed_at, success, accounts_processed,
transactions_added, transactions_updated, balances_updated,
duration_seconds, errors, logs, trigger_type
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
sync_operation.get("started_at"),
sync_operation.get("completed_at"),
sync_operation.get("success"),
sync_operation.get("accounts_processed", 0),
sync_operation.get("transactions_added", 0),
sync_operation.get("transactions_updated", 0),
sync_operation.get("balances_updated", 0),
sync_operation.get("duration_seconds"),
json.dumps(sync_operation.get("errors", [])),
json.dumps(sync_operation.get("logs", [])),
sync_operation.get("trigger_type", "manual"),
),
)
operation_id = cursor.lastrowid
if operation_id is None:
raise ValueError("Failed to get operation ID after insert")
conn.commit()
conn.close()
logger.debug(f"Persisted sync operation with ID: {operation_id}")
return operation_id
except Exception as e:
logger.error(f"Failed to persist sync operation: {e}")
raise
def get_operations(self, limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
"""Get sync operations from database"""
try:
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
cursor.execute(
"""SELECT id, started_at, completed_at, success, accounts_processed,
transactions_added, transactions_updated, balances_updated,
duration_seconds, errors, logs, trigger_type
FROM sync_operations
ORDER BY started_at DESC
LIMIT ? OFFSET ?""",
(limit, offset),
)
operations = []
for row in cursor.fetchall():
operation = {
"id": row[0],
"started_at": row[1],
"completed_at": row[2],
"success": bool(row[3]) if row[3] is not None else None,
"accounts_processed": row[4],
"transactions_added": row[5],
"transactions_updated": row[6],
"balances_updated": row[7],
"duration_seconds": row[8],
"errors": json.loads(row[9]) if row[9] else [],
"logs": json.loads(row[10]) if row[10] else [],
"trigger_type": row[11],
}
operations.append(operation)
conn.close()
return operations
except Exception as e:
logger.error(f"Failed to get sync operations: {e}")
return []

View File

@@ -0,0 +1,264 @@
import json
import sqlite3
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from leggen.repositories.base_repository import BaseRepository
class TransactionRepository(BaseRepository):
"""Repository for transaction data operations"""
def create_table(self):
"""Create transactions table with indexes"""
with self._get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS transactions (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
conn.commit()
def persist(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions to database, return new ones"""
try:
self.create_table()
with self._get_db_connection() as conn:
cursor = conn.cursor()
insert_sql = """INSERT OR REPLACE INTO transactions (
accountId,
transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
new_transactions = []
for transaction in transactions:
try:
# Check if transaction already exists
cursor.execute(
"""SELECT COUNT(*) FROM transactions
WHERE accountId = ? AND transactionId = ?""",
(transaction["accountId"], transaction["transactionId"]),
)
exists = cursor.fetchone()[0] > 0
cursor.execute(
insert_sql,
(
transaction["accountId"],
transaction["transactionId"],
transaction.get("internalTransactionId"),
transaction["institutionId"],
transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
if not exists:
new_transactions.append(transaction)
except sqlite3.IntegrityError as e:
logger.warning(
f"Failed to insert transaction {transaction.get('transactionId')}: {e}"
)
continue
conn.commit()
logger.info(
f"Persisted {len(new_transactions)} new transactions for account {account_id}"
)
return new_transactions
except Exception as e:
logger.error(f"Failed to persist transactions: {e}")
raise
def get_transactions(
self,
account_id: Optional[str] = None,
limit: Optional[int] = 100,
offset: int = 0,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
min_amount: Optional[float] = None,
max_amount: Optional[float] = None,
search: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get transactions with optional filtering"""
if not self._db_exists():
return []
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
query = "SELECT * FROM transactions WHERE 1=1"
params: List[Union[str, int, float]] = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
query += " ORDER BY transactionDate DESC"
if limit:
query += " LIMIT ?"
params.append(limit)
if offset:
query += " OFFSET ?"
params.append(offset)
cursor.execute(query, params)
rows = cursor.fetchall()
transactions = []
for row in rows:
transaction = dict(row)
if transaction["rawTransaction"]:
transaction["rawTransaction"] = json.loads(
transaction["rawTransaction"]
)
transactions.append(transaction)
return transactions
def get_count(
self,
account_id: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
min_amount: Optional[float] = None,
max_amount: Optional[float] = None,
search: Optional[str] = None,
) -> int:
"""Get total count of transactions matching filters"""
if not self._db_exists():
return 0
with self._get_db_connection() as conn:
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params: List[Union[str, float]] = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
cursor.execute(query, params)
return cursor.fetchone()[0]
def get_account_summary(self, account_id: str) -> Optional[Dict[str, Any]]:
"""Get basic account info from transactions table"""
if not self._db_exists():
return None
with self._get_db_connection(row_factory=True) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT DISTINCT accountId, institutionId, iban
FROM transactions
WHERE accountId = ?
ORDER BY transactionDate DESC
LIMIT 1
""",
(account_id,),
)
row = cursor.fetchone()
if row:
return dict(row)
return None

View File

@@ -0,0 +1,13 @@
"""Data processing layer for all transformation logic."""
from leggen.services.data_processors.account_enricher import AccountEnricher
from leggen.services.data_processors.analytics_processor import AnalyticsProcessor
from leggen.services.data_processors.balance_transformer import BalanceTransformer
from leggen.services.data_processors.transaction_processor import TransactionProcessor
__all__ = [
"AccountEnricher",
"AnalyticsProcessor",
"BalanceTransformer",
"TransactionProcessor",
]

View File

@@ -0,0 +1,71 @@
"""Account enrichment processor for adding currency, logos, and metadata."""
from typing import Any, Dict
from loguru import logger
from leggen.services.gocardless_service import GoCardlessService
class AccountEnricher:
"""Enriches account details with currency and institution information."""
def __init__(self):
self.gocardless = GoCardlessService()
async def enrich_account_details(
self,
account_details: Dict[str, Any],
balances: Dict[str, Any],
) -> Dict[str, Any]:
"""
Enrich account details with currency from balances and institution logo.
Args:
account_details: Raw account details from GoCardless
balances: Balance data containing currency information
Returns:
Enriched account details with currency and logo added
"""
enriched_account = account_details.copy()
# Extract currency from first balance
currency = self._extract_currency_from_balances(balances)
if currency:
enriched_account["currency"] = currency
# Fetch and add institution logo
institution_id = enriched_account.get("institution_id")
if institution_id:
logo = await self._fetch_institution_logo(institution_id)
if logo:
enriched_account["logo"] = logo
return enriched_account
def _extract_currency_from_balances(self, balances: Dict[str, Any]) -> str | None:
"""Extract currency from the first balance in the balances data."""
balances_list = balances.get("balances", [])
if not balances_list:
return None
first_balance = balances_list[0]
balance_amount = first_balance.get("balanceAmount", {})
return balance_amount.get("currency")
async def _fetch_institution_logo(self, institution_id: str) -> str | None:
"""Fetch institution logo from GoCardless API."""
try:
institution_details = await self.gocardless.get_institution_details(
institution_id
)
logo = institution_details.get("logo", "")
if logo:
logger.info(f"Fetched logo for institution {institution_id}: {logo}")
return logo
except Exception as e:
logger.warning(
f"Failed to fetch institution details for {institution_id}: {e}"
)
return None

View File

@@ -0,0 +1,201 @@
"""Analytics processor for calculating historical balances and statistics."""
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
from loguru import logger
class AnalyticsProcessor:
"""Calculates historical balances and transaction statistics from database data."""
def calculate_historical_balances(
self,
db_path: Path,
account_id: Optional[str] = None,
days: int = 365,
) -> List[Dict[str, Any]]:
"""
Generate historical balance progression based on transaction history.
Uses current balances and subtracts future transactions to calculate
balance at each historical point in time.
Args:
db_path: Path to SQLite database
account_id: Optional account ID to filter by
days: Number of days to look back (default 365)
Returns:
List of historical balance data points
"""
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
cutoff_date = (datetime.now() - timedelta(days=days)).date().isoformat()
today_date = datetime.now().date().isoformat()
# Single SQL query to generate historical balances using window functions
query = """
WITH RECURSIVE date_series AS (
-- Generate weekly dates from cutoff_date to today
SELECT date(?) as ref_date
UNION ALL
SELECT date(ref_date, '+7 days')
FROM date_series
WHERE ref_date < date(?)
),
current_balances AS (
-- Get current balance for each account/type
SELECT account_id, type, amount, currency
FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
{account_filter}
AND b1.type = 'closingBooked' -- Focus on closingBooked for charts
),
historical_points AS (
-- Calculate balance at each weekly point by subtracting future transactions
SELECT
cb.account_id,
cb.type as balance_type,
cb.currency,
ds.ref_date,
cb.amount - COALESCE(
(SELECT SUM(t.transactionValue)
FROM transactions t
WHERE t.accountId = cb.account_id
AND date(t.transactionDate) > ds.ref_date), 0
) as balance_amount
FROM current_balances cb
CROSS JOIN date_series ds
)
SELECT
account_id || '_' || balance_type || '_' || ref_date as id,
account_id,
balance_amount,
balance_type,
currency,
ref_date as reference_date
FROM historical_points
ORDER BY account_id, ref_date
"""
# Build parameters and account filter
params = [cutoff_date, today_date]
if account_id:
account_filter = "AND b1.account_id = ?"
params.append(account_id)
else:
account_filter = ""
# Format the query with conditional filter
formatted_query = query.format(account_filter=account_filter)
cursor.execute(formatted_query, params)
rows = cursor.fetchall()
conn.close()
return [dict(row) for row in rows]
except Exception as e:
conn.close()
logger.error(f"Failed to calculate historical balances: {e}")
raise
def calculate_monthly_stats(
self,
db_path: Path,
account_id: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Calculate monthly transaction statistics aggregated from database.
Sums transactions by month and calculates income, expenses, and net values.
Args:
db_path: Path to SQLite database
account_id: Optional account ID to filter by
date_from: Optional start date (ISO format)
date_to: Optional end date (ISO format)
Returns:
List of monthly statistics with income, expenses, and net totals
"""
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# SQL query to aggregate transactions by month
query = """
SELECT
strftime('%Y-%m', transactionDate) as month,
COALESCE(SUM(CASE WHEN transactionValue > 0 THEN transactionValue ELSE 0 END), 0) as income,
COALESCE(SUM(CASE WHEN transactionValue < 0 THEN ABS(transactionValue) ELSE 0 END), 0) as expenses,
COALESCE(SUM(transactionValue), 0) as net
FROM transactions
WHERE 1=1
"""
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
query += """
GROUP BY strftime('%Y-%m', transactionDate)
ORDER BY month ASC
"""
cursor.execute(query, params)
rows = cursor.fetchall()
# Convert to desired format with proper month display
monthly_stats = []
for row in rows:
# Convert YYYY-MM to display format like "Mar 2024"
year, month_num = row["month"].split("-")
month_date = datetime.strptime(f"{year}-{month_num}-01", "%Y-%m-%d")
display_month = month_date.strftime("%b %Y")
monthly_stats.append(
{
"month": display_month,
"income": round(row["income"], 2),
"expenses": round(row["expenses"], 2),
"net": round(row["net"], 2),
}
)
conn.close()
return monthly_stats
except Exception as e:
conn.close()
logger.error(f"Failed to calculate monthly stats: {e}")
raise

View File

@@ -0,0 +1,69 @@
"""Balance data transformation processor for format conversions."""
from datetime import datetime
from typing import Any, Dict, List, Tuple
class BalanceTransformer:
"""Transforms balance data between GoCardless and internal database formats."""
def merge_account_metadata_into_balances(
self,
balances: Dict[str, Any],
account_details: Dict[str, Any],
) -> Dict[str, Any]:
"""
Merge account metadata into balance data for proper persistence.
This adds institution_id, iban, and account_status to the balances
so they can be persisted alongside the balance data.
Args:
balances: Raw balance data from GoCardless
account_details: Enriched account details containing metadata
Returns:
Balance data with account metadata merged in
"""
balances_with_metadata = balances.copy()
balances_with_metadata["institution_id"] = account_details.get("institution_id")
balances_with_metadata["iban"] = account_details.get("iban")
balances_with_metadata["account_status"] = account_details.get("status")
return balances_with_metadata
def transform_to_database_format(
self,
account_id: str,
balance_data: Dict[str, Any],
) -> List[Tuple[Any, ...]]:
"""
Transform GoCardless balance format to database row format.
Converts nested GoCardless balance structure into flat tuples
ready for SQLite insertion.
Args:
account_id: The account ID
balance_data: Balance data with merged account metadata
Returns:
List of tuples in database row format (account_id, bank, status, ...)
"""
rows = []
for balance in balance_data.get("balances", []):
balance_amount = balance.get("balanceAmount", {})
row = (
account_id,
balance_data.get("institution_id", "unknown"),
balance_data.get("account_status"),
balance_data.get("iban", "N/A"),
float(balance_amount.get("amount", 0)),
balance_amount.get("currency"),
balance.get("balanceType"),
datetime.now().isoformat(),
)
rows.append(row)
return rows

File diff suppressed because it is too large Load Diff

View File

@@ -52,11 +52,17 @@ class NotificationService:
async def send_expiry_notification(self, notification_data: Dict[str, Any]) -> None: async def send_expiry_notification(self, notification_data: Dict[str, Any]) -> None:
"""Send notification about account expiry""" """Send notification about account expiry"""
if self._is_discord_enabled(): try:
await self._send_discord_expiry(notification_data) if self._is_discord_enabled():
await self._send_discord_expiry(notification_data)
except Exception as e:
logger.error(f"Failed to send Discord expiry notification: {e}")
if self._is_telegram_enabled(): try:
await self._send_telegram_expiry(notification_data) if self._is_telegram_enabled():
await self._send_telegram_expiry(notification_data)
except Exception as e:
logger.error(f"Failed to send Telegram expiry notification: {e}")
def _filter_transactions( def _filter_transactions(
self, transactions: List[Dict[str, Any]] self, transactions: List[Dict[str, Any]]
@@ -262,7 +268,6 @@ class NotificationService:
logger.info(f"Sent Discord expiry notification: {notification_data}") logger.info(f"Sent Discord expiry notification: {notification_data}")
except Exception as e: except Exception as e:
logger.error(f"Failed to send Discord expiry notification: {e}") logger.error(f"Failed to send Discord expiry notification: {e}")
raise
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None: async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
"""Send Telegram expiry notification""" """Send Telegram expiry notification"""
@@ -288,17 +293,22 @@ class NotificationService:
logger.info(f"Sent Telegram expiry notification: {notification_data}") logger.info(f"Sent Telegram expiry notification: {notification_data}")
except Exception as e: except Exception as e:
logger.error(f"Failed to send Telegram expiry notification: {e}") logger.error(f"Failed to send Telegram expiry notification: {e}")
raise
async def send_sync_failure_notification( async def send_sync_failure_notification(
self, notification_data: Dict[str, Any] self, notification_data: Dict[str, Any]
) -> None: ) -> None:
"""Send notification about sync failure""" """Send notification about sync failure"""
if self._is_discord_enabled(): try:
await self._send_discord_sync_failure(notification_data) if self._is_discord_enabled():
await self._send_discord_sync_failure(notification_data)
except Exception as e:
logger.error(f"Failed to send Discord sync failure notification: {e}")
if self._is_telegram_enabled(): try:
await self._send_telegram_sync_failure(notification_data) if self._is_telegram_enabled():
await self._send_telegram_sync_failure(notification_data)
except Exception as e:
logger.error(f"Failed to send Telegram sync failure notification: {e}")
async def _send_discord_sync_failure( async def _send_discord_sync_failure(
self, notification_data: Dict[str, Any] self, notification_data: Dict[str, Any]
@@ -326,7 +336,6 @@ class NotificationService:
logger.info(f"Sent Discord sync failure notification: {notification_data}") logger.info(f"Sent Discord sync failure notification: {notification_data}")
except Exception as e: except Exception as e:
logger.error(f"Failed to send Discord sync failure notification: {e}") logger.error(f"Failed to send Discord sync failure notification: {e}")
raise
async def _send_telegram_sync_failure( async def _send_telegram_sync_failure(
self, notification_data: Dict[str, Any] self, notification_data: Dict[str, Any]
@@ -354,4 +363,3 @@ class NotificationService:
logger.info(f"Sent Telegram sync failure notification: {notification_data}") logger.info(f"Sent Telegram sync failure notification: {notification_data}")
except Exception as e: except Exception as e:
logger.error(f"Failed to send Telegram sync failure notification: {e}") logger.error(f"Failed to send Telegram sync failure notification: {e}")
raise

View File

@@ -4,18 +4,41 @@ from typing import List
from loguru import logger from loguru import logger
from leggen.api.models.sync import SyncResult, SyncStatus from leggen.api.models.sync import SyncResult, SyncStatus
from leggen.services.database_service import DatabaseService from leggen.repositories import (
AccountRepository,
BalanceRepository,
SyncRepository,
TransactionRepository,
)
from leggen.services.data_processors import (
AccountEnricher,
BalanceTransformer,
TransactionProcessor,
)
from leggen.services.gocardless_service import GoCardlessService from leggen.services.gocardless_service import GoCardlessService
from leggen.services.notification_service import NotificationService from leggen.services.notification_service import NotificationService
# Constants for notification
EXPIRED_DAYS_LEFT = 0
class SyncService: class SyncService:
def __init__(self): def __init__(self):
self.gocardless = GoCardlessService() self.gocardless = GoCardlessService()
self.database = DatabaseService()
self.notifications = NotificationService() self.notifications = NotificationService()
# Repositories
self.accounts = AccountRepository()
self.balances = BalanceRepository()
self.transactions = TransactionRepository()
self.sync = SyncRepository()
# Data processors
self.account_enricher = AccountEnricher()
self.balance_transformer = BalanceTransformer()
self.transaction_processor = TransactionProcessor()
self._sync_status = SyncStatus(is_running=False) self._sync_status = SyncStatus(is_running=False)
self._institution_logos = {} # Cache for institution logos
async def get_sync_status(self) -> SyncStatus: async def get_sync_status(self) -> SyncStatus:
"""Get current sync status""" """Get current sync status"""
@@ -67,6 +90,9 @@ class SyncService:
self._sync_status.total_accounts = len(all_accounts) self._sync_status.total_accounts = len(all_accounts)
logs.append(f"Found {len(all_accounts)} accounts to sync") logs.append(f"Found {len(all_accounts)} accounts to sync")
# Check for expired or expiring requisitions
await self._check_requisition_expiry(requisitions.get("results", []))
# Process each account # Process each account
for account_id in all_accounts: for account_id in all_accounts:
try: try:
@@ -78,72 +104,44 @@ class SyncService:
# Get balances to extract currency information # Get balances to extract currency information
balances = await self.gocardless.get_account_balances(account_id) balances = await self.gocardless.get_account_balances(account_id)
# Enrich account details with currency and institution logo # Enrich and persist account details
if account_details and balances: if account_details and balances:
enriched_account_details = account_details.copy() # Enrich account with currency and institution logo
enriched_account_details = (
# Extract currency from first balance await self.account_enricher.enrich_account_details(
balances_list = balances.get("balances", []) account_details, balances
if balances_list: )
first_balance = balances_list[0] )
balance_amount = first_balance.get("balanceAmount", {})
currency = balance_amount.get("currency")
if currency:
enriched_account_details["currency"] = currency
# Get institution details to fetch logo
institution_id = enriched_account_details.get("institution_id")
if institution_id:
try:
institution_details = (
await self.gocardless.get_institution_details(
institution_id
)
)
enriched_account_details["logo"] = (
institution_details.get("logo", "")
)
logger.info(
f"Fetched logo for institution {institution_id}: {enriched_account_details.get('logo', 'No logo')}"
)
except Exception as e:
logger.warning(
f"Failed to fetch institution details for {institution_id}: {e}"
)
# Persist enriched account details to database # Persist enriched account details to database
await self.database.persist_account_details( self.accounts.persist(enriched_account_details)
enriched_account_details
)
# Merge account details into balances data for proper persistence # Merge account metadata into balances for persistence
balances_with_account_info = balances.copy() balances_with_account_info = self.balance_transformer.merge_account_metadata_into_balances(
balances_with_account_info["institution_id"] = ( balances, enriched_account_details
enriched_account_details.get("institution_id")
) )
balances_with_account_info["iban"] = ( balance_rows = (
enriched_account_details.get("iban") self.balance_transformer.transform_to_database_format(
) account_id, balances_with_account_info
balances_with_account_info["account_status"] = ( )
enriched_account_details.get("status")
)
await self.database.persist_balance(
account_id, balances_with_account_info
) )
self.balances.persist(account_id, balance_rows)
balances_updated += len(balances.get("balances", [])) balances_updated += len(balances.get("balances", []))
elif account_details: elif account_details:
# Fallback: persist account details without currency if balances failed # Fallback: persist account details without currency if balances failed
await self.database.persist_account_details(account_details) self.accounts.persist(account_details)
# Get and save transactions # Get and save transactions
transactions = await self.gocardless.get_account_transactions( transactions = await self.gocardless.get_account_transactions(
account_id account_id
) )
if transactions: if transactions:
processed_transactions = self.database.process_transactions( processed_transactions = (
account_id, account_details, transactions self.transaction_processor.process_transactions(
account_id, account_details, transactions
)
) )
new_transactions = await self.database.persist_transactions( new_transactions = self.transactions.persist(
account_id, processed_transactions account_id, processed_transactions
) )
transactions_added += len(new_transactions) transactions_added += len(new_transactions)
@@ -166,6 +164,15 @@ class SyncService:
logger.error(error_msg) logger.error(error_msg)
logs.append(error_msg) logs.append(error_msg)
# Send notification for account sync failure
await self.notifications.send_sync_failure_notification(
{
"account_id": account_id,
"error": error_msg,
"type": "account_sync_failure",
}
)
end_time = datetime.now() end_time = datetime.now()
duration = (end_time - start_time).total_seconds() duration = (end_time - start_time).total_seconds()
@@ -188,9 +195,7 @@ class SyncService:
# Persist sync operation to database # Persist sync operation to database
try: try:
operation_id = await self.database.persist_sync_operation( operation_id = self.sync.persist(sync_operation)
sync_operation
)
logger.debug(f"Saved sync operation with ID: {operation_id}") logger.debug(f"Saved sync operation with ID: {operation_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to persist sync operation: {e}") logger.error(f"Failed to persist sync operation: {e}")
@@ -239,9 +244,7 @@ class SyncService:
) )
try: try:
operation_id = await self.database.persist_sync_operation( operation_id = self.sync.persist(sync_operation)
sync_operation
)
logger.debug(f"Saved failed sync operation with ID: {operation_id}") logger.debug(f"Saved failed sync operation with ID: {operation_id}")
except Exception as persist_error: except Exception as persist_error:
logger.error( logger.error(
@@ -252,6 +255,31 @@ class SyncService:
finally: finally:
self._sync_status.is_running = False self._sync_status.is_running = False
async def _check_requisition_expiry(self, requisitions: List[dict]) -> None:
"""Check requisitions for expiry and send notifications.
Args:
requisitions: List of requisition dictionaries to check
"""
for req in requisitions:
requisition_id = req.get("id", "unknown")
institution_id = req.get("institution_id", "unknown")
status = req.get("status", "")
# Check if requisition is expired
if status == "EX":
logger.warning(
f"Requisition {requisition_id} for {institution_id} has expired"
)
await self.notifications.send_expiry_notification(
{
"bank": institution_id,
"requisition_id": requisition_id,
"status": "expired",
"days_left": EXPIRED_DAYS_LEFT,
}
)
async def sync_specific_accounts( async def sync_specific_accounts(
self, account_ids: List[str], force: bool = False, trigger_type: str = "manual" self, account_ids: List[str], force: bool = False, trigger_type: str = "manual"
) -> SyncResult: ) -> SyncResult:

View File

@@ -8,8 +8,10 @@ from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import tomli_w
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from leggen.commands.server import create_app
from leggen.utils.config import Config from leggen.utils.config import Config
# Create test config before any imports that might load it # Create test config before any imports that might load it
@@ -27,15 +29,12 @@ _config_data = {
"scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}}, "scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}},
} }
import tomli_w
with open(_test_config_path, "wb") as f: with open(_test_config_path, "wb") as f:
tomli_w.dump(_config_data, f) tomli_w.dump(_config_data, f)
# Set environment variables to point to test config BEFORE importing the app # Set environment variables to point to test config BEFORE importing the app
os.environ["LEGGEN_CONFIG_FILE"] = str(_test_config_path) os.environ["LEGGEN_CONFIG_FILE"] = str(_test_config_path)
from leggen.commands.server import create_app
def pytest_configure(config): def pytest_configure(config):
"""Pytest hook called before test collection.""" """Pytest hook called before test collection."""
@@ -114,7 +113,9 @@ def mock_auth_token(temp_config_dir):
def fastapi_app(mock_db_path): def fastapi_app(mock_db_path):
"""Create FastAPI test application.""" """Create FastAPI test application."""
# Patch the database path for the app # Patch the database path for the app
with patch("leggen.utils.paths.path_manager.get_database_path", return_value=mock_db_path): with patch(
"leggen.utils.paths.path_manager.get_database_path", return_value=mock_db_path
):
app = create_app() app = create_app()
yield app yield app
@@ -125,6 +126,38 @@ def api_client(fastapi_app):
return TestClient(fastapi_app) return TestClient(fastapi_app)
@pytest.fixture
def mock_account_repo():
"""Create mock AccountRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_balance_repo():
"""Create mock BalanceRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_transaction_repo():
"""Create mock TransactionRepository for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture
def mock_analytics_proc():
"""Create mock AnalyticsProcessor for testing."""
from unittest.mock import MagicMock
return MagicMock()
@pytest.fixture @pytest.fixture
def mock_db_path(temp_db_path): def mock_db_path(temp_db_path):
"""Mock the database path to use temporary database for testing.""" """Mock the database path to use temporary database for testing."""

View File

@@ -1,13 +1,13 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics.""" """Tests for analytics fixes to ensure all transactions are used in statistics."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import Mock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from leggen.api.dependencies import get_transaction_repository
from leggen.commands.server import create_app from leggen.commands.server import create_app
from leggen.services.database_service import DatabaseService
class TestAnalyticsFix: class TestAnalyticsFix:
@@ -19,11 +19,11 @@ class TestAnalyticsFix:
return TestClient(app) return TestClient(app)
@pytest.fixture @pytest.fixture
def mock_database_service(self): def mock_transaction_repo(self):
return Mock(spec=DatabaseService) return Mock()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transaction_stats_uses_all_transactions(self, mock_database_service): async def test_transaction_stats_uses_all_transactions(self, mock_transaction_repo):
"""Test that transaction stats endpoint uses all transactions (not limited to 100)""" """Test that transaction stats endpoint uses all transactions (not limited to 100)"""
# Mock data for 600 transactions (simulating the issue) # Mock data for 600 transactions (simulating the issue)
mock_transactions = [] mock_transactions = []
@@ -42,53 +42,50 @@ class TestAnalyticsFix:
} }
) )
mock_database_service.get_transactions_from_db = AsyncMock( mock_transaction_repo.get_transactions.return_value = mock_transactions
return_value=mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
) )
# Test that the endpoint calls get_transactions_from_db with limit=None # Verify that the response contains stats for all 600 transactions
with patch( stats = data
"leggen.api.routes.transactions.database_service", mock_database_service assert stats["total_transactions"] == 600, (
): "Should process all 600 transactions, not just 100"
app = create_app() )
client = TestClient(app)
response = client.get("/api/v1/transactions/stats?days=365") # Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert response.status_code == 200 assert stats["total_income"] == expected_income
data = response.json() assert stats["total_expenses"] == expected_expenses
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Stats endpoint should pass limit=None to get all transactions"
)
# Verify that the response contains stats for all 600 transactions
stats = data
assert stats["total_transactions"] == 600, (
"Should process all 600 transactions, not just 100"
)
# Verify calculations are correct for all transactions
expected_income = sum(
txn["transactionValue"]
for txn in mock_transactions
if txn["transactionValue"] > 0
)
expected_expenses = sum(
abs(txn["transactionValue"])
for txn in mock_transactions
if txn["transactionValue"] < 0
)
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_analytics_endpoint_returns_all_transactions( async def test_analytics_endpoint_returns_all_transactions(
self, mock_database_service self, mock_transaction_repo
): ):
"""Test that the new analytics endpoint returns all transactions without pagination""" """Test that the new analytics endpoint returns all transactions without pagination"""
# Mock data for 600 transactions # Mock data for 600 transactions
@@ -108,30 +105,28 @@ class TestAnalyticsFix:
} }
) )
mock_database_service.get_transactions_from_db = AsyncMock( mock_transaction_repo.get_transactions.return_value = mock_transactions
return_value=mock_transactions
app = create_app()
app.dependency_overrides[get_transaction_repository] = (
lambda: mock_transaction_repo
)
client = TestClient(app)
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_transaction_repo.get_transactions.assert_called_once()
call_args = mock_transaction_repo.get_transactions.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
) )
with patch( # Verify that all 600 transactions are returned
"leggen.api.routes.transactions.database_service", mock_database_service transactions_data = data
): assert len(transactions_data) == 600, (
app = create_app() "Analytics endpoint should return all 600 transactions"
client = TestClient(app) )
response = client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, (
"Analytics endpoint should pass limit=None"
)
# Verify that all 600 transactions are returned
transactions_data = data
assert len(transactions_data) == 600, (
"Analytics endpoint should return all 600 transactions"
)

View File

@@ -4,6 +4,12 @@ from unittest.mock import patch
import pytest import pytest
from leggen.api.dependencies import (
get_account_repository,
get_balance_repository,
get_transaction_repository,
)
@pytest.mark.api @pytest.mark.api
class TestAccountsAPI: class TestAccountsAPI:
@@ -11,11 +17,14 @@ class TestAccountsAPI:
def test_get_all_accounts_success( def test_get_all_accounts_success(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
mock_db_path, mock_db_path,
mock_account_repo,
mock_balance_repo,
): ):
"""Test successful retrieval of all accounts from database.""" """Test successful retrieval of all accounts from database."""
mock_accounts = [ mock_accounts = [
@@ -45,19 +54,21 @@ class TestAccountsAPI:
} }
] ]
with ( mock_account_repo.get_accounts.return_value = mock_accounts
patch("leggen.utils.config.config", mock_config), mock_balance_repo.get_balances.return_value = mock_balances
patch(
"leggen.api.routes.accounts.database_service.get_accounts_from_db", fastapi_app.dependency_overrides[get_account_repository] = (
return_value=mock_accounts, lambda: mock_account_repo
), )
patch( fastapi_app.dependency_overrides[get_balance_repository] = (
"leggen.api.routes.accounts.database_service.get_balances_from_db", lambda: mock_balance_repo
return_value=mock_balances, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts") response = api_client.get("/api/v1/accounts")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
@@ -69,11 +80,14 @@ class TestAccountsAPI:
def test_get_account_details_success( def test_get_account_details_success(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
mock_db_path, mock_db_path,
mock_account_repo,
mock_balance_repo,
): ):
"""Test successful retrieval of specific account details from database.""" """Test successful retrieval of specific account details from database."""
mock_account = { mock_account = {
@@ -101,19 +115,21 @@ class TestAccountsAPI:
} }
] ]
with ( mock_account_repo.get_account.return_value = mock_account
patch("leggen.utils.config.config", mock_config), mock_balance_repo.get_balances.return_value = mock_balances
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db", fastapi_app.dependency_overrides[get_account_repository] = (
return_value=mock_account, lambda: mock_account_repo
), )
patch( fastapi_app.dependency_overrides[get_balance_repository] = (
"leggen.api.routes.accounts.database_service.get_balances_from_db", lambda: mock_balance_repo
return_value=mock_balances, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123") response = api_client.get("/api/v1/accounts/test-account-123")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["id"] == "test-account-123" assert data["id"] == "test-account-123"
@@ -121,7 +137,13 @@ class TestAccountsAPI:
assert len(data["balances"]) == 1 assert len(data["balances"]) == 1
def test_get_account_balances_success( def test_get_account_balances_success(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_balance_repo,
): ):
"""Test successful retrieval of account balances from database.""" """Test successful retrieval of account balances from database."""
mock_balances = [ mock_balances = [
@@ -149,15 +171,17 @@ class TestAccountsAPI:
}, },
] ]
with ( mock_balance_repo.get_balances.return_value = mock_balances
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_balance_repository] = (
"leggen.api.routes.accounts.database_service.get_balances_from_db", lambda: mock_balance_repo
return_value=mock_balances, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/balances") response = api_client.get("/api/v1/accounts/test-account-123/balances")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 2 assert len(data) == 2
@@ -167,12 +191,14 @@ class TestAccountsAPI:
def test_get_account_transactions_success( def test_get_account_transactions_success(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
sample_transaction_data, sample_transaction_data,
mock_db_path, mock_db_path,
mock_transaction_repo,
): ):
"""Test successful retrieval of account transactions from database.""" """Test successful retrieval of account transactions from database."""
mock_transactions = [ mock_transactions = [
@@ -191,21 +217,19 @@ class TestAccountsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.accounts.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get( response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true" "/api/v1/accounts/test-account-123/transactions?summary_only=true"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
@@ -218,12 +242,14 @@ class TestAccountsAPI:
def test_get_account_transactions_full_details( def test_get_account_transactions_full_details(
self, self,
fastapi_app,
api_client, api_client,
mock_config, mock_config,
mock_auth_token, mock_auth_token,
sample_account_data, sample_account_data,
sample_transaction_data, sample_transaction_data,
mock_db_path, mock_db_path,
mock_transaction_repo,
): ):
"""Test retrieval of full transaction details from database.""" """Test retrieval of full transaction details from database."""
mock_transactions = [ mock_transactions = [
@@ -242,21 +268,19 @@ class TestAccountsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.accounts.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.accounts.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get( response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false" "/api/v1/accounts/test-account-123/transactions?summary_only=false"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 1 assert len(data) == 1
@@ -268,22 +292,36 @@ class TestAccountsAPI:
assert "raw_transaction" in transaction assert "raw_transaction" in transaction
def test_get_account_not_found( def test_get_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
): ):
"""Test handling of non-existent account.""" """Test handling of non-existent account."""
with ( mock_account_repo.get_account.return_value = None
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_account_repository] = (
"leggen.api.routes.accounts.database_service.get_account_details_from_db", lambda: mock_account_repo
return_value=None, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/accounts/nonexistent") response = api_client.get("/api/v1/accounts/nonexistent")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404 assert response.status_code == 404
def test_update_account_display_name_success( def test_update_account_display_name_success(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
): ):
"""Test successful update of account display name.""" """Test successful update of account display name."""
mock_account = { mock_account = {
@@ -297,41 +335,48 @@ class TestAccountsAPI:
"last_accessed": "2025-09-01T09:30:00Z", "last_accessed": "2025-09-01T09:30:00Z",
} }
with ( mock_account_repo.get_account.return_value = mock_account
patch("leggen.utils.config.config", mock_config), mock_account_repo.persist.return_value = mock_account
patch(
"leggen.api.routes.accounts.database_service.get_account_details_from_db", fastapi_app.dependency_overrides[get_account_repository] = (
return_value=mock_account, lambda: mock_account_repo
), )
patch(
"leggen.api.routes.accounts.database_service.persist_account_details", with patch("leggen.utils.config.config", mock_config):
return_value=None,
),
):
response = api_client.put( response = api_client.put(
"/api/v1/accounts/test-account-123", "/api/v1/accounts/test-account-123",
json={"display_name": "My Custom Account Name"}, json={"display_name": "My Custom Account Name"},
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["id"] == "test-account-123" assert data["id"] == "test-account-123"
assert data["display_name"] == "My Custom Account Name" assert data["display_name"] == "My Custom Account Name"
def test_update_account_not_found( def test_update_account_not_found(
self, api_client, mock_config, mock_auth_token, mock_db_path self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_db_path,
mock_account_repo,
): ):
"""Test updating non-existent account.""" """Test updating non-existent account."""
with ( mock_account_repo.get_account.return_value = None
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_account_repository] = (
"leggen.api.routes.accounts.database_service.get_account_details_from_db", lambda: mock_account_repo
return_value=None, )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.put( response = api_client.put(
"/api/v1/accounts/nonexistent", "/api/v1/accounts/nonexistent",
json={"display_name": "New Name"}, json={"display_name": "New Name"},
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 404 assert response.status_code == 404

View File

@@ -211,10 +211,7 @@ class TestBackupAPI:
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 2 assert len(data) == 2
assert ( assert data[0]["key"] == "leggen_backups/database_backup_20250101_120000.db"
data[0]["key"]
== "leggen_backups/database_backup_20250101_120000.db"
)
def test_list_backups_no_config(self, api_client, mock_config): def test_list_backups_no_config(self, api_client, mock_config):
"""Test backup listing with no configuration.""" """Test backup listing with no configuration."""

View File

@@ -5,13 +5,20 @@ from unittest.mock import patch
import pytest import pytest
from leggen.api.dependencies import get_transaction_repository
@pytest.mark.api @pytest.mark.api
class TestTransactionsAPI: class TestTransactionsAPI:
"""Test transaction-related API endpoints.""" """Test transaction-related API endpoints."""
def test_get_all_transactions_success( def test_get_all_transactions_success(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test successful retrieval of all transactions from database.""" """Test successful retrieval of all transactions from database."""
mock_transactions = [ mock_transactions = [
@@ -43,19 +50,17 @@ class TestTransactionsAPI:
}, },
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = len(mock_transactions)
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=2,
),
):
response = api_client.get("/api/v1/transactions?summary_only=true") response = api_client.get("/api/v1/transactions?summary_only=true")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["data"]) == 2 assert len(data["data"]) == 2
@@ -70,7 +75,12 @@ class TestTransactionsAPI:
assert transaction["account_id"] == "test-account-123" assert transaction["account_id"] == "test-account-123"
def test_get_all_transactions_full_details( def test_get_all_transactions_full_details(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test retrieval of full transaction details from database.""" """Test retrieval of full transaction details from database."""
mock_transactions = [ mock_transactions = [
@@ -89,19 +99,17 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = len(mock_transactions)
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
),
patch( with patch("leggen.utils.config.config", mock_config):
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db",
return_value=1,
),
):
response = api_client.get("/api/v1/transactions?summary_only=false") response = api_client.get("/api/v1/transactions?summary_only=false")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["data"]) == 1 assert len(data["data"]) == 1
@@ -114,7 +122,12 @@ class TestTransactionsAPI:
assert "raw_transaction" in transaction assert "raw_transaction" in transaction
def test_get_transactions_with_filters( def test_get_transactions_with_filters(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transactions with various filters.""" """Test getting transactions with various filters."""
mock_transactions = [ mock_transactions = [
@@ -133,17 +146,14 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = 1
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db", fastapi_app.dependency_overrides[get_transaction_repository] = (
return_value=mock_transactions, lambda: mock_transaction_repo
) as mock_get_transactions, )
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db", with patch("leggen.utils.config.config", mock_config):
return_value=1,
),
):
response = api_client.get( response = api_client.get(
"/api/v1/transactions?" "/api/v1/transactions?"
"account_id=test-account-123&" "account_id=test-account-123&"
@@ -156,11 +166,12 @@ class TestTransactionsAPI:
"per_page=10" "per_page=10"
) )
assert response.status_code == 200 fastapi_app.dependency_overrides.clear()
data = response.json()
# Verify the database service was called with correct filters assert response.status_code == 200
mock_get_transactions.assert_called_once_with(
# Verify the repository was called with correct filters
mock_transaction_repo.get_transactions.assert_called_once_with(
account_id="test-account-123", account_id="test-account-123",
limit=10, limit=10,
offset=10, # (page-1) * per_page = (2-1) * 10 = 10 offset=10, # (page-1) * per_page = (2-1) * 10 = 10
@@ -172,22 +183,26 @@ class TestTransactionsAPI:
) )
def test_get_transactions_empty_result( def test_get_transactions_empty_result(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transactions when database returns empty result.""" """Test getting transactions when database returns empty result."""
with ( mock_transaction_repo.get_transactions.return_value = []
patch("leggen.utils.config.config", mock_config), mock_transaction_repo.get_count.return_value = 0
patch(
"leggen.api.routes.transactions.database_service.get_transactions_from_db", fastapi_app.dependency_overrides[get_transaction_repository] = (
return_value=[], lambda: mock_transaction_repo
), )
patch(
"leggen.api.routes.transactions.database_service.get_transaction_count_from_db", with patch("leggen.utils.config.config", mock_config):
return_value=0,
),
):
response = api_client.get("/api/v1/transactions") response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["data"]) == 0 assert len(data["data"]) == 0
@@ -196,23 +211,37 @@ class TestTransactionsAPI:
assert data["total_pages"] == 0 assert data["total_pages"] == 0
def test_get_transactions_database_error( def test_get_transactions_database_error(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test handling database error when getting transactions.""" """Test handling database error when getting transactions."""
with ( mock_transaction_repo.get_transactions.side_effect = Exception(
patch("leggen.utils.config.config", mock_config), "Database connection failed"
patch( )
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"), fastapi_app.dependency_overrides[get_transaction_repository] = (
), lambda: mock_transaction_repo
): )
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions") response = api_client.get("/api/v1/transactions")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500 assert response.status_code == 500
assert "Failed to get transactions" in response.json()["detail"] assert "Failed to get transactions" in response.json()["detail"]
def test_get_transaction_stats_success( def test_get_transaction_stats_success(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test successful retrieval of transaction statistics from database.""" """Test successful retrieval of transaction statistics from database."""
mock_transactions = [ mock_transactions = [
@@ -239,15 +268,16 @@ class TestTransactionsAPI:
}, },
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config), fastapi_app.dependency_overrides[get_transaction_repository] = (
patch( lambda: mock_transaction_repo
"leggen.api.routes.transactions.database_service.get_transactions_from_db", )
return_value=mock_transactions,
), with patch("leggen.utils.config.config", mock_config):
):
response = api_client.get("/api/v1/transactions/stats?days=30") response = api_client.get("/api/v1/transactions/stats?days=30")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -265,7 +295,12 @@ class TestTransactionsAPI:
assert data["average_transaction"] == expected_avg assert data["average_transaction"] == expected_avg
def test_get_transaction_stats_with_account_filter( def test_get_transaction_stats_with_account_filter(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transaction stats filtered by account.""" """Test getting transaction stats filtered by account."""
mock_transactions = [ mock_transactions = [
@@ -278,37 +313,46 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
) as mock_get_transactions,
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get( response = api_client.get(
"/api/v1/transactions/stats?account_id=test-account-123" "/api/v1/transactions/stats?account_id=test-account-123"
) )
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
# Verify the database service was called with account filter # Verify the repository was called with account filter
mock_get_transactions.assert_called_once() mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert call_kwargs["account_id"] == "test-account-123" assert call_kwargs["account_id"] == "test-account-123"
def test_get_transaction_stats_empty_result( def test_get_transaction_stats_empty_result(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting stats when no transactions match criteria.""" """Test getting stats when no transactions match criteria."""
with ( mock_transaction_repo.get_transactions.return_value = []
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=[], )
),
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats") response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -320,23 +364,37 @@ class TestTransactionsAPI:
assert data["accounts_included"] == 0 assert data["accounts_included"] == 0
def test_get_transaction_stats_database_error( def test_get_transaction_stats_database_error(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test handling database error when getting stats.""" """Test handling database error when getting stats."""
with ( mock_transaction_repo.get_transactions.side_effect = Exception(
patch("leggen.utils.config.config", mock_config), "Database connection failed"
patch( )
"leggen.api.routes.transactions.database_service.get_transactions_from_db",
side_effect=Exception("Database connection failed"), fastapi_app.dependency_overrides[get_transaction_repository] = (
), lambda: mock_transaction_repo
): )
with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats") response = api_client.get("/api/v1/transactions/stats")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 500 assert response.status_code == 500
assert "Failed to get transaction stats" in response.json()["detail"] assert "Failed to get transaction stats" in response.json()["detail"]
def test_get_transaction_stats_custom_period( def test_get_transaction_stats_custom_period(
self, api_client, mock_config, mock_auth_token self,
fastapi_app,
api_client,
mock_config,
mock_auth_token,
mock_transaction_repo,
): ):
"""Test getting transaction stats for custom time period.""" """Test getting transaction stats for custom time period."""
mock_transactions = [ mock_transactions = [
@@ -349,21 +407,23 @@ class TestTransactionsAPI:
} }
] ]
with ( mock_transaction_repo.get_transactions.return_value = mock_transactions
patch("leggen.utils.config.config", mock_config),
patch( fastapi_app.dependency_overrides[get_transaction_repository] = (
"leggen.api.routes.transactions.database_service.get_transactions_from_db", lambda: mock_transaction_repo
return_value=mock_transactions, )
) as mock_get_transactions,
): with patch("leggen.utils.config.config", mock_config):
response = api_client.get("/api/v1/transactions/stats?days=7") response = api_client.get("/api/v1/transactions/stats?days=7")
fastapi_app.dependency_overrides.clear()
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["period_days"] == 7 assert data["period_days"] == 7
# Verify the date range was calculated correctly for 7 days # Verify the date range was calculated correctly for 7 days
mock_get_transactions.assert_called_once() mock_transaction_repo.get_transactions.assert_called_once()
call_kwargs = mock_get_transactions.call_args.kwargs call_kwargs = mock_transaction_repo.get_transactions.call_args.kwargs
assert "date_from" in call_kwargs assert "date_from" in call_kwargs
assert "date_to" in call_kwargs assert "date_to" in call_kwargs

View File

@@ -120,12 +120,10 @@ class TestConfigurablePaths:
"iban": "TEST_IBAN", "iban": "TEST_IBAN",
} }
# Use the internal balance persistence method since the test needs direct database access # Use the public balance persistence method
import asyncio import asyncio
asyncio.run( asyncio.run(database_service.persist_balance("test-account", balance_data))
database_service._persist_balance_sqlite("test-account", balance_data)
)
# Retrieve balances # Retrieve balances
balances = asyncio.run( balances = asyncio.run(

View File

@@ -85,7 +85,7 @@ class TestDatabaseService:
): ):
"""Test successful retrieval of transactions from database.""" """Test successful retrieval of transactions from database."""
with patch.object( with patch.object(
database_service, "_get_transactions" database_service.transactions, "get_transactions"
) as mock_get_transactions: ) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format mock_get_transactions.return_value = sample_transactions_db_format
@@ -111,7 +111,7 @@ class TestDatabaseService:
): ):
"""Test retrieving transactions with filters.""" """Test retrieving transactions with filters."""
with patch.object( with patch.object(
database_service, "_get_transactions" database_service.transactions, "get_transactions"
) as mock_get_transactions: ) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format mock_get_transactions.return_value = sample_transactions_db_format
@@ -149,7 +149,7 @@ class TestDatabaseService:
async def test_get_transactions_from_db_error(self, database_service): async def test_get_transactions_from_db_error(self, database_service):
"""Test handling error when getting transactions.""" """Test handling error when getting transactions."""
with patch.object( with patch.object(
database_service, "_get_transactions" database_service.transactions, "get_transactions"
) as mock_get_transactions: ) as mock_get_transactions:
mock_get_transactions.side_effect = Exception("Database error") mock_get_transactions.side_effect = Exception("Database error")
@@ -159,7 +159,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_success(self, database_service): async def test_get_transaction_count_from_db_success(self, database_service):
"""Test successful retrieval of transaction count.""" """Test successful retrieval of transaction count."""
with patch.object(database_service, "_get_transaction_count") as mock_get_count: with patch.object(database_service.transactions, "get_count") as mock_get_count:
mock_get_count.return_value = 42 mock_get_count.return_value = 42
result = await database_service.get_transaction_count_from_db( result = await database_service.get_transaction_count_from_db(
@@ -167,11 +167,18 @@ class TestDatabaseService:
) )
assert result == 42 assert result == 42
mock_get_count.assert_called_once_with(account_id="test-account-123") mock_get_count.assert_called_once_with(
account_id="test-account-123",
date_from=None,
date_to=None,
min_amount=None,
max_amount=None,
search=None,
)
async def test_get_transaction_count_from_db_with_filters(self, database_service): async def test_get_transaction_count_from_db_with_filters(self, database_service):
"""Test getting transaction count with filters.""" """Test getting transaction count with filters."""
with patch.object(database_service, "_get_transaction_count") as mock_get_count: with patch.object(database_service.transactions, "get_count") as mock_get_count:
mock_get_count.return_value = 15 mock_get_count.return_value = 15
result = await database_service.get_transaction_count_from_db( result = await database_service.get_transaction_count_from_db(
@@ -185,7 +192,9 @@ class TestDatabaseService:
mock_get_count.assert_called_once_with( mock_get_count.assert_called_once_with(
account_id="test-account-123", account_id="test-account-123",
date_from="2025-09-01", date_from="2025-09-01",
date_to=None,
min_amount=-100.0, min_amount=-100.0,
max_amount=None,
search="Coffee", search="Coffee",
) )
@@ -201,7 +210,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_error(self, database_service): async def test_get_transaction_count_from_db_error(self, database_service):
"""Test handling error when getting count.""" """Test handling error when getting count."""
with patch.object(database_service, "_get_transaction_count") as mock_get_count: with patch.object(database_service.transactions, "get_count") as mock_get_count:
mock_get_count.side_effect = Exception("Database error") mock_get_count.side_effect = Exception("Database error")
result = await database_service.get_transaction_count_from_db() result = await database_service.get_transaction_count_from_db()
@@ -212,7 +221,9 @@ class TestDatabaseService:
self, database_service, sample_balances_db_format self, database_service, sample_balances_db_format
): ):
"""Test successful retrieval of balances from database.""" """Test successful retrieval of balances from database."""
with patch.object(database_service, "_get_balances") as mock_get_balances: with patch.object(
database_service.balances, "get_balances"
) as mock_get_balances:
mock_get_balances.return_value = sample_balances_db_format mock_get_balances.return_value = sample_balances_db_format
result = await database_service.get_balances_from_db( result = await database_service.get_balances_from_db(
@@ -234,7 +245,9 @@ class TestDatabaseService:
async def test_get_balances_from_db_error(self, database_service): async def test_get_balances_from_db_error(self, database_service):
"""Test handling error when getting balances.""" """Test handling error when getting balances."""
with patch.object(database_service, "_get_balances") as mock_get_balances: with patch.object(
database_service.balances, "get_balances"
) as mock_get_balances:
mock_get_balances.side_effect = Exception("Database error") mock_get_balances.side_effect = Exception("Database error")
result = await database_service.get_balances_from_db() result = await database_service.get_balances_from_db()
@@ -249,7 +262,9 @@ class TestDatabaseService:
"iban": "LT313250081177977789", "iban": "LT313250081177977789",
} }
with patch.object(database_service, "_get_account_summary") as mock_get_summary: with patch.object(
database_service.transactions, "get_account_summary"
) as mock_get_summary:
mock_get_summary.return_value = mock_summary mock_get_summary.return_value = mock_summary
result = await database_service.get_account_summary_from_db( result = await database_service.get_account_summary_from_db(
@@ -269,7 +284,9 @@ class TestDatabaseService:
async def test_get_account_summary_from_db_error(self, database_service): async def test_get_account_summary_from_db_error(self, database_service):
"""Test handling error when getting summary.""" """Test handling error when getting summary."""
with patch.object(database_service, "_get_account_summary") as mock_get_summary: with patch.object(
database_service.transactions, "get_account_summary"
) as mock_get_summary:
mock_get_summary.side_effect = Exception("Database error") mock_get_summary.side_effect = Exception("Database error")
result = await database_service.get_account_summary_from_db( result = await database_service.get_account_summary_from_db(
@@ -291,87 +308,87 @@ class TestDatabaseService:
], ],
} }
with patch("sqlite3.connect") as mock_connect: with (
mock_conn = mock_connect.return_value patch.object(database_service.balances, "persist") as mock_persist,
mock_cursor = mock_conn.cursor.return_value patch.object(
database_service.balance_transformer, "transform_to_database_format"
) as mock_transform,
):
mock_transform.return_value = [
(
"test-account-123",
"REVOLUT_REVOLT21",
"active",
"LT313250081177977789",
1000.0,
"EUR",
"interimAvailable",
"2025-09-01T10:00:00",
)
]
await database_service._persist_balance_sqlite( await database_service.persist_balance("test-account-123", balance_data)
"test-account-123", balance_data
)
# Verify database operations # Verify transformation and persistence were called
mock_connect.assert_called() mock_transform.assert_called_once_with("test-account-123", balance_data)
mock_cursor.execute.assert_called() # Table creation and insert mock_persist.assert_called_once()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
async def test_persist_balance_sqlite_error(self, database_service): async def test_persist_balance_sqlite_error(self, database_service):
"""Test handling error during balance persistence.""" """Test handling error during balance persistence."""
balance_data = {"balances": []} balance_data = {"balances": []}
with patch("sqlite3.connect") as mock_connect: with (
mock_connect.side_effect = Exception("Database error") patch.object(database_service.balances, "persist") as mock_persist,
patch.object(
database_service.balance_transformer, "transform_to_database_format"
) as mock_transform,
):
mock_persist.side_effect = Exception("Database error")
mock_transform.return_value = []
with pytest.raises(Exception, match="Database error"): with pytest.raises(Exception, match="Database error"):
await database_service._persist_balance_sqlite( await database_service.persist_balance("test-account-123", balance_data)
"test-account-123", balance_data
)
async def test_persist_transactions_sqlite_success( async def test_persist_transactions_sqlite_success(
self, database_service, sample_transactions_db_format self, database_service, sample_transactions_db_format
): ):
"""Test successful transaction persistence.""" """Test successful transaction persistence."""
with patch("sqlite3.connect") as mock_connect: with patch.object(database_service.transactions, "persist") as mock_persist:
mock_conn = mock_connect.return_value mock_persist.return_value = sample_transactions_db_format
mock_cursor = mock_conn.cursor.return_value
# Mock fetchone to return (0,) indicating transaction doesn't exist yet
mock_cursor.fetchone.return_value = (0,)
result = await database_service._persist_transactions_sqlite( result = await database_service.persist_transactions(
"test-account-123", sample_transactions_db_format "test-account-123", sample_transactions_db_format
) )
# Should return the transactions (assuming no duplicates) # Should return the new transactions
assert len(result) >= 0 # Could be empty if all are duplicates assert len(result) == 2
mock_persist.assert_called_once_with(
# Verify database operations "test-account-123", sample_transactions_db_format
mock_connect.assert_called() )
mock_cursor.execute.assert_called()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
async def test_persist_transactions_sqlite_duplicate_detection( async def test_persist_transactions_sqlite_duplicate_detection(
self, database_service, sample_transactions_db_format self, database_service, sample_transactions_db_format
): ):
"""Test that existing transactions are not returned as new.""" """Test that existing transactions are not returned as new."""
with patch("sqlite3.connect") as mock_connect: with patch.object(database_service.transactions, "persist") as mock_persist:
mock_conn = mock_connect.return_value # Return empty list indicating all were duplicates
mock_cursor = mock_conn.cursor.return_value mock_persist.return_value = []
# Mock fetchone to return (1,) indicating transaction already exists
mock_cursor.fetchone.return_value = (1,)
result = await database_service._persist_transactions_sqlite( result = await database_service.persist_transactions(
"test-account-123", sample_transactions_db_format "test-account-123", sample_transactions_db_format
) )
# Should return empty list since all transactions already exist # Should return empty list since all transactions already exist
assert len(result) == 0 assert len(result) == 0
mock_persist.assert_called_once()
# Verify database operations still happened (INSERT OR REPLACE executed)
mock_connect.assert_called()
mock_cursor.execute.assert_called()
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
async def test_persist_transactions_sqlite_error(self, database_service): async def test_persist_transactions_sqlite_error(self, database_service):
"""Test handling error during transaction persistence.""" """Test handling error during transaction persistence."""
with patch("sqlite3.connect") as mock_connect: with patch.object(database_service.transactions, "persist") as mock_persist:
mock_connect.side_effect = Exception("Database error") mock_persist.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"): with pytest.raises(Exception, match="Database error"):
await database_service._persist_transactions_sqlite( await database_service.persist_transactions("test-account-123", [])
"test-account-123", []
)
async def test_process_transactions_booked_and_pending(self, database_service): async def test_process_transactions_booked_and_pending(self, database_service):
"""Test processing transactions with both booked and pending.""" """Test processing transactions with both booked and pending."""

View File

@@ -0,0 +1,244 @@
"""Tests for sync service notification functionality."""
from unittest.mock import patch
import pytest
from leggen.services.sync_service import SyncService
@pytest.mark.unit
class TestSyncNotifications:
"""Test sync service notification functionality."""
@pytest.mark.asyncio
async def test_sync_failure_sends_notification(self):
"""Test that sync failures trigger notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1"],
}
]
}
# Make account details fail
mock_get_details.side_effect = Exception("API Error")
# Execute: Run sync which should fail for the account
await sync_service.sync_all_accounts()
# Assert: Notification should be sent for the failed account
mock_send_notification.assert_called_once()
call_args = mock_send_notification.call_args[0][0]
assert call_args["account_id"] == "account-1"
assert "API Error" in call_args["error"]
assert call_args["type"] == "account_sync_failure"
@pytest.mark.asyncio
async def test_expired_requisition_sends_notification(self):
"""Test that expired requisitions trigger expiry notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.notifications, "send_expiry_notification"
) as mock_send_expiry,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One expired requisition
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-expired",
"institution_id": "EXPIRED_BANK",
"status": "EX",
"accounts": [],
}
]
}
# Execute: Run sync
await sync_service.sync_all_accounts()
# Assert: Expiry notification should be sent
mock_send_expiry.assert_called_once()
call_args = mock_send_expiry.call_args[0][0]
assert call_args["requisition_id"] == "req-expired"
assert call_args["bank"] == "EXPIRED_BANK"
assert call_args["status"] == "expired"
assert call_args["days_left"] == 0
@pytest.mark.asyncio
async def test_multiple_failures_send_multiple_notifications(self):
"""Test that multiple account failures send multiple notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with two accounts that will fail
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1", "account-2"],
}
]
}
# Make all account details fail
mock_get_details.side_effect = Exception("API Error")
# Execute: Run sync
await sync_service.sync_all_accounts()
# Assert: Two notifications should be sent
assert mock_send_notification.call_count == 2
@pytest.mark.asyncio
async def test_successful_sync_no_failure_notification(self):
"""Test that successful syncs don't send failure notifications."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.gocardless, "get_account_balances"
) as mock_get_balances,
patch.object(
sync_service.gocardless, "get_account_transactions"
) as mock_get_transactions,
patch.object(
sync_service.notifications, "send_sync_failure_notification"
) as mock_send_notification,
patch.object(sync_service.notifications, "send_transaction_notifications"),
patch.object(sync_service.accounts, "persist"),
patch.object(sync_service.balances, "persist"),
patch.object(
sync_service.transaction_processor,
"process_transactions",
return_value=[],
),
patch.object(sync_service.transactions, "persist", return_value=[]),
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that succeeds
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1"],
}
]
}
mock_get_details.return_value = {
"id": "account-1",
"institution_id": "TEST_BANK",
"status": "READY",
"iban": "TEST123",
}
mock_get_balances.return_value = {
"balances": [{"balanceAmount": {"amount": "100", "currency": "EUR"}}]
}
mock_get_transactions.return_value = {"transactions": {"booked": []}}
# Execute: Run sync
await sync_service.sync_all_accounts()
# Assert: No failure notification should be sent
mock_send_notification.assert_not_called()
@pytest.mark.asyncio
async def test_notification_failure_does_not_stop_sync(self):
"""Test that notification failures don't stop the sync process."""
sync_service = SyncService()
# Mock the dependencies
with (
patch.object(
sync_service.gocardless, "get_requisitions"
) as mock_get_requisitions,
patch.object(
sync_service.gocardless, "get_account_details"
) as mock_get_details,
patch.object(
sync_service.notifications, "_send_discord_sync_failure"
) as mock_discord_notification,
patch.object(
sync_service.notifications, "_send_telegram_sync_failure"
) as mock_telegram_notification,
patch.object(sync_service.sync, "persist", return_value=1),
):
# Setup: One requisition with one account that will fail
mock_get_requisitions.return_value = {
"results": [
{
"id": "req-123",
"institution_id": "TEST_BANK",
"status": "LN",
"accounts": ["account-1"],
}
]
}
# Make account details fail
mock_get_details.side_effect = Exception("API Error")
# Make both notification methods fail
mock_discord_notification.side_effect = Exception("Discord Error")
mock_telegram_notification.side_effect = Exception("Telegram Error")
# Execute: Run sync - should not raise exception from notification
result = await sync_service.sync_all_accounts()
# The sync should complete with errors but not crash from notifications
assert result.success is False
assert len(result.errors) > 0
assert "API Error" in result.errors[0]