Compare commits

...

5 Commits

Author SHA1 Message Date
Elisiário Couto
da98b7b2b7 chore: Check import order using ruff. 2025-09-14 21:12:47 +01:00
Elisiário Couto
2467cb2f5a chore: Sort imports, fix deprecated pydantic option. 2025-09-14 21:11:01 +01:00
Elisiário Couto
5ae3a51d81 refactor: Consolidate database layer and eliminate wrapper complexity.
- Merge leggen/database/sqlite.py functionality directly into DatabaseService
- Extract transaction processing logic to separate TransactionProcessor class
- Remove leggen/utils/database.py and leggen/database/ directory entirely
- Update all tests to use new consolidated structure
- Reduce codebase by ~300 lines while maintaining full functionality
- Improve separation of concerns: data processing vs persistence vs CLI

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-14 21:01:16 +01:00
Elisiário Couto
d09cf6d04c fix(config): Fix example config file. 2025-09-14 20:31:49 +01:00
Elisiário Couto
2c6e099596 fix(config): Add Pydantic validation and fix telegram config field mappings.
* Add Pydantic models for configuration validation in leggen/models/config.py
* Fix telegram config field aliases (api-key -> token, chat-id -> chat_id)
* Update config.py to use Pydantic validation with proper error handling
* Fix TOML serialization by excluding None values with exclude_none=True
* Update notification service to use correct telegram field names
* Enhance notification service with actual Discord/Telegram implementations
* Fix all failing configuration tests to work with Pydantic validation
* Add pydantic dependency to pyproject.toml

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-14 20:31:49 +01:00
41 changed files with 909 additions and 1355 deletions

View File

@@ -20,8 +20,8 @@ enabled = true
# Optional: Telegram notifications # Optional: Telegram notifications
[notifications.telegram] [notifications.telegram]
token = "your-bot-token" api-key = "your-bot-token"
chat_id = 12345 chat-id = 12345
enabled = true enabled = true
# Optional: Transaction filters for notifications # Optional: Transaction filters for notifications

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional, Dict, Any from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,4 +1,4 @@
from typing import Optional, List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,15 +1,16 @@
from typing import Optional, List, Union from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.models.common import APIResponse
from leggen.api.models.accounts import ( from leggen.api.models.accounts import (
AccountDetails,
AccountBalance, AccountBalance,
AccountDetails,
AccountUpdate,
Transaction, Transaction,
TransactionSummary, TransactionSummary,
AccountUpdate,
) )
from leggen.api.models.common import APIResponse
from leggen.services.database_service import DatabaseService from leggen.services.database_service import DatabaseService
router = APIRouter() router = APIRouter()

View File

@@ -1,13 +1,13 @@
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.models.common import APIResponse
from leggen.api.models.banks import ( from leggen.api.models.banks import (
BankInstitution,
BankConnectionRequest, BankConnectionRequest,
BankRequisition,
BankConnectionStatus, BankConnectionStatus,
BankInstitution,
BankRequisition,
) )
from leggen.api.models.common import APIResponse
from leggen.services.gocardless_service import GoCardlessService from leggen.services.gocardless_service import GoCardlessService
from leggen.utils.gocardless import REQUISITION_STATUS from leggen.utils.gocardless import REQUISITION_STATUS

View File

@@ -1,14 +1,15 @@
from typing import Dict, Any from typing import Any, Dict
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from loguru import logger from loguru import logger
from leggen.api.models.common import APIResponse from leggen.api.models.common import APIResponse
from leggen.api.models.notifications import ( from leggen.api.models.notifications import (
DiscordConfig,
NotificationFilters,
NotificationSettings, NotificationSettings,
NotificationTest, NotificationTest,
DiscordConfig,
TelegramConfig, TelegramConfig,
NotificationFilters,
) )
from leggen.services.notification_service import NotificationService from leggen.services.notification_service import NotificationService
from leggen.utils.config import config from leggen.utils.config import config

View File

@@ -1,11 +1,12 @@
from typing import Optional from typing import Optional
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, HTTPException
from loguru import logger from loguru import logger
from leggen.api.models.common import APIResponse from leggen.api.models.common import APIResponse
from leggen.api.models.sync import SyncRequest, SchedulerConfig from leggen.api.models.sync import SchedulerConfig, SyncRequest
from leggen.services.sync_service import SyncService
from leggen.background.scheduler import scheduler from leggen.background.scheduler import scheduler
from leggen.services.sync_service import SyncService
from leggen.utils.config import config from leggen.utils.config import config
router = APIRouter() router = APIRouter()

View File

@@ -1,10 +1,11 @@
from typing import Optional, List, Union
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional, Union
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggen.api.models.common import APIResponse, PaginatedResponse
from leggen.api.models.accounts import Transaction, TransactionSummary from leggen.api.models.accounts import Transaction, TransactionSummary
from leggen.api.models.common import APIResponse, PaginatedResponse
from leggen.services.database_service import DatabaseService from leggen.services.database_service import DatabaseService
router = APIRouter() router = APIRouter()

View File

@@ -1,8 +1,9 @@
import os import os
import requests from typing import Any, Dict, List, Optional, Union
from typing import Dict, Any, Optional, List, Union
from urllib.parse import urljoin from urllib.parse import urljoin
import requests
from leggen.utils.text import error from leggen.utils.text import error

View File

@@ -2,9 +2,9 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from loguru import logger from loguru import logger
from leggen.utils.config import config
from leggen.services.sync_service import SyncService
from leggen.services.notification_service import NotificationService from leggen.services.notification_service import NotificationService
from leggen.services.sync_service import SyncService
from leggen.utils.config import config
class BackgroundScheduler: class BackgroundScheduler:

View File

@@ -1,7 +1,7 @@
import click import click
from leggen.main import cli
from leggen.api_client import LeggenAPIClient from leggen.api_client import LeggenAPIClient
from leggen.main import cli
from leggen.utils.text import datefmt, print_table from leggen.utils.text import datefmt, print_table

View File

@@ -1,9 +1,9 @@
import click import click
from leggen.main import cli
from leggen.api_client import LeggenAPIClient from leggen.api_client import LeggenAPIClient
from leggen.main import cli
from leggen.utils.disk import save_file from leggen.utils.disk import save_file
from leggen.utils.text import info, print_table, warning, success from leggen.utils.text import info, print_table, success, warning
@cli.command() @cli.command()

View File

@@ -1,8 +1,9 @@
"""Generate sample database command.""" """Generate sample database command."""
import click
from pathlib import Path from pathlib import Path
import click
@click.command() @click.command()
@click.option( @click.option(
@@ -34,8 +35,8 @@ def generate_sample_db(
"""Generate a sample database with realistic financial data for testing.""" """Generate a sample database with realistic financial data for testing."""
# Import here to avoid circular imports # Import here to avoid circular imports
import sys
import subprocess import subprocess
import sys
from pathlib import Path as PathlibPath from pathlib import Path as PathlibPath
# Get the script path # Get the script path

View File

@@ -7,7 +7,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from loguru import logger from loguru import logger
from leggen.api.routes import banks, accounts, sync, notifications, transactions from leggen.api.routes import accounts, banks, notifications, sync, transactions
from leggen.background.scheduler import scheduler from leggen.background.scheduler import scheduler
from leggen.utils.config import config from leggen.utils.config import config
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager

View File

@@ -1,7 +1,7 @@
import click import click
from leggen.main import cli
from leggen.api_client import LeggenAPIClient from leggen.api_client import LeggenAPIClient
from leggen.main import cli
from leggen.utils.text import datefmt, echo, info, print_table from leggen.utils.text import datefmt, echo, info, print_table

View File

@@ -1,7 +1,7 @@
import click import click
from leggen.main import cli
from leggen.api_client import LeggenAPIClient from leggen.api_client import LeggenAPIClient
from leggen.main import cli
from leggen.utils.text import error, info, success from leggen.utils.text import error, info, success

View File

@@ -1,7 +1,7 @@
import click import click
from leggen.main import cli
from leggen.api_client import LeggenAPIClient from leggen.api_client import LeggenAPIClient
from leggen.main import cli
from leggen.utils.text import datefmt, info, print_table from leggen.utils.text import datefmt, info, print_table

View File

@@ -1,658 +0,0 @@
import json
import sqlite3
from sqlite3 import IntegrityError
import click
from leggen.utils.text import success, warning
from leggen.utils.paths import path_manager
def persist_balances(ctx: click.Context, balance: dict):
# Connect to SQLite database
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Create the accounts table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME
)"""
)
# Create indexes for accounts table
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
# Create the balances table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS balances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account_id TEXT,
bank TEXT,
status TEXT,
iban TEXT,
amount REAL,
currency TEXT,
type TEXT,
timestamp DATETIME
)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
ON balances(account_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
ON balances(timestamp)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
ON balances(account_id, type, timestamp)"""
)
# Insert balance into SQLite database
try:
cursor.execute(
"""INSERT INTO balances (
account_id,
bank,
status,
iban,
amount,
currency,
type,
timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
balance["account_id"],
balance["bank"],
balance["status"],
balance["iban"],
balance["amount"],
balance["currency"],
balance["type"],
balance["timestamp"],
),
)
except IntegrityError:
warning(f"[{balance['account_id']}] Skipped duplicate balance")
# Commit changes and close the connection
conn.commit()
conn.close()
success(f"[{balance['account_id']}] Inserted balance of type {balance['type']}")
return balance
def persist_transactions(ctx: click.Context, account: str, transactions: list) -> list:
# Connect to SQLite database
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Create the transactions table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS transactions (
accountId TEXT NOT NULL,
transactionId TEXT NOT NULL,
internalTransactionId TEXT,
institutionId TEXT,
iban TEXT,
transactionDate DATETIME,
description TEXT,
transactionValue REAL,
transactionCurrency TEXT,
transactionStatus TEXT,
rawTransaction JSON,
PRIMARY KEY (accountId, transactionId)
)"""
)
# Create indexes for better performance
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
ON transactions(internalTransactionId)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
ON transactions(transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
ON transactions(accountId, transactionDate)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
ON transactions(transactionValue)"""
)
# Insert transactions into SQLite database
duplicates_count = 0
# Prepare an SQL statement for inserting data
insert_sql = """INSERT OR REPLACE INTO transactions (
accountId,
transactionId,
internalTransactionId,
institutionId,
iban,
transactionDate,
description,
transactionValue,
transactionCurrency,
transactionStatus,
rawTransaction
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
new_transactions = []
for transaction in transactions:
try:
cursor.execute(
insert_sql,
(
transaction["accountId"],
transaction["transactionId"],
transaction.get("internalTransactionId"),
transaction["institutionId"],
transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
new_transactions.append(transaction)
except IntegrityError:
# A transaction with the same ID already exists, indicating a duplicate
duplicates_count += 1
# Commit changes and close the connection
conn.commit()
conn.close()
success(f"[{account}] Inserted {len(new_transactions)} new transactions")
if duplicates_count:
warning(f"[{account}] Skipped {duplicates_count} duplicate transactions")
return new_transactions
def get_transactions(
account_id=None,
limit=100,
offset=0,
date_from=None,
date_to=None,
min_amount=None,
max_amount=None,
search=None,
):
"""Get transactions from SQLite database with optional filtering"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row # Enable dict-like access
cursor = conn.cursor()
# Build query with filters
query = "SELECT * FROM transactions WHERE 1=1"
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
# Add ordering and pagination
query += " ORDER BY transactionDate DESC"
if limit:
query += " LIMIT ?"
params.append(limit)
if offset:
query += " OFFSET ?"
params.append(offset)
try:
cursor.execute(query, params)
rows = cursor.fetchall()
# Convert to list of dicts and parse JSON fields
transactions = []
for row in rows:
transaction = dict(row)
if transaction["rawTransaction"]:
transaction["rawTransaction"] = json.loads(
transaction["rawTransaction"]
)
transactions.append(transaction)
conn.close()
return transactions
except Exception as e:
conn.close()
raise e
def get_balances(account_id=None):
"""Get latest balances from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
query += " ORDER BY b1.account_id, b1.type"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
balances = [dict(row) for row in rows]
conn.close()
return balances
except Exception as e:
conn.close()
raise e
def get_account_summary(account_id):
"""Get basic account info from transactions table (avoids GoCardless API call)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# Get account info from most recent transaction
cursor.execute(
"""
SELECT DISTINCT accountId, institutionId, iban
FROM transactions
WHERE accountId = ?
ORDER BY transactionDate DESC
LIMIT 1
""",
(account_id,),
)
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
def get_transaction_count(account_id=None, **filters):
"""Get total count of transactions matching filters"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return 0
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
# Add same filters as get_transactions
if filters.get("date_from"):
query += " AND transactionDate >= ?"
params.append(filters["date_from"])
if filters.get("date_to"):
query += " AND transactionDate <= ?"
params.append(filters["date_to"])
if filters.get("min_amount") is not None:
query += " AND transactionValue >= ?"
params.append(filters["min_amount"])
if filters.get("max_amount") is not None:
query += " AND transactionValue <= ?"
params.append(filters["max_amount"])
if filters.get("search"):
query += " AND description LIKE ?"
params.append(f"%{filters['search']}%")
try:
cursor.execute(query, params)
count = cursor.fetchone()[0]
conn.close()
return count
except Exception as e:
conn.close()
raise e
def persist_account(account_data: dict):
"""Persist account details to SQLite database"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Create the accounts table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME
)"""
)
# Create indexes for accounts table
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
try:
# Insert or replace account data
cursor.execute(
"""INSERT OR REPLACE INTO accounts (
id,
institution_id,
status,
iban,
name,
currency,
created,
last_accessed,
last_updated
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
account_data["id"],
account_data["institution_id"],
account_data["status"],
account_data.get("iban"),
account_data.get("name"),
account_data.get("currency"),
account_data["created"],
account_data.get("last_accessed"),
account_data.get("last_updated", account_data["created"]),
),
)
conn.commit()
conn.close()
success(f"[{account_data['id']}] Account details persisted to database")
return account_data
except Exception as e:
conn.close()
raise e
def get_accounts(account_ids=None):
"""Get account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
query = "SELECT * FROM accounts"
params = []
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query += " ORDER BY created DESC"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
accounts = [dict(row) for row in rows]
conn.close()
return accounts
except Exception as e:
conn.close()
raise e
def get_account(account_id: str):
"""Get specific account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
def get_historical_balances(account_id=None, days=365):
"""Get historical balance progression based on transaction history"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# Get current balance for each account/type to use as the final balance
current_balances_query = """
SELECT account_id, type, amount, currency
FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
current_balances_query += " AND b1.account_id = ?"
params.append(account_id)
cursor.execute(current_balances_query, params)
current_balances = {
(row["account_id"], row["type"]): {
"amount": row["amount"],
"currency": row["currency"],
}
for row in cursor.fetchall()
}
# Get transactions for the specified period, ordered by date descending
from datetime import datetime, timedelta
cutoff_date = (datetime.now() - timedelta(days=days)).isoformat()
transactions_query = """
SELECT accountId, transactionDate, transactionValue
FROM transactions
WHERE transactionDate >= ?
"""
if account_id:
transactions_query += " AND accountId = ?"
params = [cutoff_date, account_id]
else:
params = [cutoff_date]
transactions_query += " ORDER BY transactionDate DESC"
cursor.execute(transactions_query, params)
transactions = cursor.fetchall()
# Calculate historical balances by working backwards from current balance
historical_balances = []
account_running_balances: dict[str, dict[str, float]] = {}
# Initialize running balances with current balances
for (acc_id, balance_type), balance_info in current_balances.items():
if acc_id not in account_running_balances:
account_running_balances[acc_id] = {}
account_running_balances[acc_id][balance_type] = balance_info["amount"]
# Group transactions by date
from collections import defaultdict
transactions_by_date = defaultdict(list)
for txn in transactions:
date_str = txn["transactionDate"][:10] # Extract just the date part
transactions_by_date[date_str].append(txn)
# Generate historical balance points
# Start from today and work backwards
current_date = datetime.now().date()
for day_offset in range(0, days, 7): # Sample every 7 days for performance
target_date = current_date - timedelta(days=day_offset)
target_date_str = target_date.isoformat()
# For each account, create balance entries
for acc_id in account_running_balances:
for balance_type in [
"closingBooked"
]: # Focus on closingBooked for the chart
if balance_type in account_running_balances[acc_id]:
balance_amount = account_running_balances[acc_id][balance_type]
currency = current_balances.get((acc_id, balance_type), {}).get(
"currency", "EUR"
)
historical_balances.append(
{
"id": f"{acc_id}_{balance_type}_{target_date_str}",
"account_id": acc_id,
"balance_amount": balance_amount,
"balance_type": balance_type,
"currency": currency,
"reference_date": target_date_str,
"created_at": None,
"updated_at": None,
}
)
# Subtract transactions that occurred on this date and later dates
# to simulate going back in time
for date_str in list(transactions_by_date.keys()):
if date_str >= target_date_str:
for txn in transactions_by_date[date_str]:
acc_id = txn["accountId"]
amount = txn["transactionValue"]
if acc_id in account_running_balances:
for balance_type in account_running_balances[acc_id]:
account_running_balances[acc_id][balance_type] -= amount
# Remove processed transactions to avoid double-processing
del transactions_by_date[date_str]
conn.close()
# Sort by date for proper chronological order
historical_balances.sort(key=lambda x: x["reference_date"])
return historical_balances
except Exception as e:
conn.close()
raise e

View File

@@ -6,8 +6,8 @@ from pathlib import Path
import click import click
from leggen.utils.config import load_config from leggen.utils.config import load_config
from leggen.utils.text import error
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager
from leggen.utils.text import error
cmd_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "commands")) cmd_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "commands"))

65
leggen/models/config.py Normal file
View File

@@ -0,0 +1,65 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class GoCardlessConfig(BaseModel):
key: str = Field(..., description="GoCardless API key")
secret: str = Field(..., description="GoCardless API secret")
url: str = Field(
default="https://bankaccountdata.gocardless.com/api/v2",
description="GoCardless API URL",
)
class DatabaseConfig(BaseModel):
sqlite: bool = Field(default=True, description="Enable SQLite database")
class DiscordNotificationConfig(BaseModel):
webhook: str = Field(..., description="Discord webhook URL")
enabled: bool = Field(default=True, description="Enable Discord notifications")
class TelegramNotificationConfig(BaseModel):
token: str = Field(..., alias="api-key", description="Telegram bot token")
chat_id: int = Field(..., alias="chat-id", description="Telegram chat ID")
enabled: bool = Field(default=True, description="Enable Telegram notifications")
class NotificationConfig(BaseModel):
discord: Optional[DiscordNotificationConfig] = None
telegram: Optional[TelegramNotificationConfig] = None
class FilterConfig(BaseModel):
case_insensitive: Optional[List[str]] = Field(
default_factory=list, alias="case-insensitive"
)
case_sensitive: Optional[List[str]] = Field(
default_factory=list, alias="case-sensitive"
)
class SyncScheduleConfig(BaseModel):
enabled: bool = Field(default=True, description="Enable sync scheduling")
hour: int = Field(default=3, ge=0, le=23, description="Hour to run sync (0-23)")
minute: int = Field(default=0, ge=0, le=59, description="Minute to run sync (0-59)")
cron: Optional[str] = Field(
default=None, description="Custom cron expression (overrides hour/minute)"
)
class SchedulerConfig(BaseModel):
sync: SyncScheduleConfig = Field(default_factory=SyncScheduleConfig)
class Config(BaseModel):
gocardless: GoCardlessConfig
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
notifications: Optional[NotificationConfig] = None
filters: Optional[FilterConfig] = None
scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
class Config:
validate_by_name = True

View File

@@ -1,11 +1,13 @@
from datetime import datetime import json
from typing import List, Dict, Any, Optional
import sqlite3 import sqlite3
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from loguru import logger from loguru import logger
from leggen.services.transaction_processor import TransactionProcessor
from leggen.utils.config import config from leggen.utils.config import config
import leggen.database.sqlite as sqlite_db
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager
@@ -13,6 +15,7 @@ class DatabaseService:
def __init__(self): def __init__(self):
self.db_config = config.database_config self.db_config = config.database_config
self.sqlite_enabled = self.db_config.get("sqlite", True) self.sqlite_enabled = self.db_config.get("sqlite", True)
self.transaction_processor = TransactionProcessor()
async def persist_balance( async def persist_balance(
self, account_id: str, balance_data: Dict[str, Any] self, account_id: str, balance_data: Dict[str, Any]
@@ -41,79 +44,9 @@ class DatabaseService:
transaction_data: Dict[str, Any], transaction_data: Dict[str, Any],
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Process raw transaction data into standardized format""" """Process raw transaction data into standardized format"""
transactions = [] return self.transaction_processor.process_transactions(
account_id, account_info, transaction_data
# Process booked transactions
for transaction in transaction_data.get("transactions", {}).get("booked", []):
processed = self._process_single_transaction(
account_id, account_info, transaction, "booked"
)
transactions.append(processed)
# Process pending transactions
for transaction in transaction_data.get("transactions", {}).get("pending", []):
processed = self._process_single_transaction(
account_id, account_info, transaction, "pending"
)
transactions.append(processed)
return transactions
def _process_single_transaction(
self,
account_id: str,
account_info: Dict[str, Any],
transaction: Dict[str, Any],
status: str,
) -> Dict[str, Any]:
"""Process a single transaction into standardized format"""
# Extract dates
booked_date = transaction.get("bookingDateTime") or transaction.get(
"bookingDate"
) )
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
if booked_date and value_date:
min_date = min(
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
)
else:
date_str = booked_date or value_date
if not date_str:
raise ValueError("No valid date found in transaction")
min_date = datetime.fromisoformat(date_str)
# Extract amount and currency
transaction_amount = transaction.get("transactionAmount", {})
amount = float(transaction_amount.get("amount", 0))
currency = transaction_amount.get("currency", "")
# Extract description
description = transaction.get(
"remittanceInformationUnstructured",
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
)
# Extract transaction IDs - transactionId is now primary, internalTransactionId is reference
transaction_id = transaction.get("transactionId")
internal_transaction_id = transaction.get("internalTransactionId")
if not transaction_id:
raise ValueError("Transaction missing required transactionId field")
return {
"accountId": account_id,
"transactionId": transaction_id,
"internalTransactionId": internal_transaction_id,
"institutionId": account_info["institution_id"],
"iban": account_info.get("iban", "N/A"),
"transactionDate": min_date,
"description": description,
"transactionValue": amount,
"transactionCurrency": currency,
"transactionStatus": status,
"rawTransaction": transaction,
}
async def get_transactions_from_db( async def get_transactions_from_db(
self, self,
@@ -132,7 +65,7 @@ class DatabaseService:
return [] return []
try: try:
transactions = sqlite_db.get_transactions( transactions = self._get_transactions(
account_id=account_id, account_id=account_id,
limit=limit, # Pass limit as-is, None means no limit limit=limit, # Pass limit as-is, None means no limit
offset=offset or 0, offset=offset or 0,
@@ -172,7 +105,7 @@ class DatabaseService:
# Remove None values # Remove None values
filters = {k: v for k, v in filters.items() if v is not None} filters = {k: v for k, v in filters.items() if v is not None}
count = sqlite_db.get_transaction_count(account_id=account_id, **filters) count = self._get_transaction_count(account_id=account_id, **filters)
logger.debug(f"Total transaction count: {count}") logger.debug(f"Total transaction count: {count}")
return count return count
except Exception as e: except Exception as e:
@@ -188,7 +121,7 @@ class DatabaseService:
return [] return []
try: try:
balances = sqlite_db.get_balances(account_id=account_id) balances = self._get_balances(account_id=account_id)
logger.debug(f"Retrieved {len(balances)} balances from database") logger.debug(f"Retrieved {len(balances)} balances from database")
return balances return balances
except Exception as e: except Exception as e:
@@ -204,9 +137,7 @@ class DatabaseService:
return [] return []
try: try:
balances = sqlite_db.get_historical_balances( balances = self._get_historical_balances(account_id=account_id, days=days)
account_id=account_id, days=days
)
logger.debug( logger.debug(
f"Retrieved {len(balances)} historical balance points from database" f"Retrieved {len(balances)} historical balance points from database"
) )
@@ -223,7 +154,7 @@ class DatabaseService:
return None return None
try: try:
summary = sqlite_db.get_account_summary(account_id) summary = self._get_account_summary(account_id)
if summary: if summary:
logger.debug( logger.debug(
f"Retrieved account summary from database for {account_id}" f"Retrieved account summary from database for {account_id}"
@@ -250,7 +181,7 @@ class DatabaseService:
return [] return []
try: try:
accounts = sqlite_db.get_accounts(account_ids=account_ids) accounts = self._get_accounts(account_ids=account_ids)
logger.debug(f"Retrieved {len(accounts)} accounts from database") logger.debug(f"Retrieved {len(accounts)} accounts from database")
return accounts return accounts
except Exception as e: except Exception as e:
@@ -266,7 +197,7 @@ class DatabaseService:
return None return None
try: try:
account = sqlite_db.get_account(account_id) account = self._get_account(account_id)
if account: if account:
logger.debug( logger.debug(
f"Retrieved account details from database for {account_id}" f"Retrieved account details from database for {account_id}"
@@ -790,8 +721,8 @@ class DatabaseService:
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Persist transactions to SQLite""" """Persist transactions to SQLite"""
try: try:
import sqlite3
import json import json
import sqlite3
db_path = path_manager.get_database_path() db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path)) conn = sqlite3.connect(str(db_path))
@@ -893,7 +824,7 @@ class DatabaseService:
"""Persist account details to SQLite""" """Persist account details to SQLite"""
try: try:
# Use the sqlite_db module function # Use the sqlite_db module function
sqlite_db.persist_account(account_data) self._persist_account(account_data)
logger.info( logger.info(
f"Persisted account details to SQLite for account {account_data['id']}" f"Persisted account details to SQLite for account {account_data['id']}"
@@ -901,3 +832,453 @@ class DatabaseService:
except Exception as e: except Exception as e:
logger.error(f"Failed to persist account details to SQLite: {e}") logger.error(f"Failed to persist account details to SQLite: {e}")
raise raise
def _get_transactions(
self,
account_id=None,
limit=100,
offset=0,
date_from=None,
date_to=None,
min_amount=None,
max_amount=None,
search=None,
):
"""Get transactions from SQLite database with optional filtering"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row # Enable dict-like access
cursor = conn.cursor()
# Build query with filters
query = "SELECT * FROM transactions WHERE 1=1"
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
if date_from:
query += " AND transactionDate >= ?"
params.append(date_from)
if date_to:
query += " AND transactionDate <= ?"
params.append(date_to)
if min_amount is not None:
query += " AND transactionValue >= ?"
params.append(min_amount)
if max_amount is not None:
query += " AND transactionValue <= ?"
params.append(max_amount)
if search:
query += " AND description LIKE ?"
params.append(f"%{search}%")
# Add ordering and pagination
query += " ORDER BY transactionDate DESC"
if limit:
query += " LIMIT ?"
params.append(limit)
if offset:
query += " OFFSET ?"
params.append(offset)
try:
cursor.execute(query, params)
rows = cursor.fetchall()
# Convert to list of dicts and parse JSON fields
transactions = []
for row in rows:
transaction = dict(row)
if transaction["rawTransaction"]:
transaction["rawTransaction"] = json.loads(
transaction["rawTransaction"]
)
transactions.append(transaction)
conn.close()
return transactions
except Exception as e:
conn.close()
raise e
def _get_balances(self, account_id=None):
"""Get latest balances from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# Get latest balance for each account_id and type combination
query = """
SELECT * FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
query += " AND b1.account_id = ?"
params.append(account_id)
query += " ORDER BY b1.account_id, b1.type"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
balances = [dict(row) for row in rows]
conn.close()
return balances
except Exception as e:
conn.close()
raise e
def _get_account_summary(self, account_id):
"""Get basic account info from transactions table (avoids GoCardless API call)"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# Get account info from most recent transaction
cursor.execute(
"""
SELECT DISTINCT accountId, institutionId, iban
FROM transactions
WHERE accountId = ?
ORDER BY transactionDate DESC
LIMIT 1
""",
(account_id,),
)
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
def _get_transaction_count(self, account_id=None, **filters):
"""Get total count of transactions matching filters"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return 0
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
query = "SELECT COUNT(*) FROM transactions WHERE 1=1"
params = []
if account_id:
query += " AND accountId = ?"
params.append(account_id)
# Add same filters as get_transactions
if filters.get("date_from"):
query += " AND transactionDate >= ?"
params.append(filters["date_from"])
if filters.get("date_to"):
query += " AND transactionDate <= ?"
params.append(filters["date_to"])
if filters.get("min_amount") is not None:
query += " AND transactionValue >= ?"
params.append(filters["min_amount"])
if filters.get("max_amount") is not None:
query += " AND transactionValue <= ?"
params.append(filters["max_amount"])
if filters.get("search"):
query += " AND description LIKE ?"
params.append(f"%{filters['search']}%")
try:
cursor.execute(query, params)
count = cursor.fetchone()[0]
conn.close()
return count
except Exception as e:
conn.close()
raise e
def _persist_account(self, account_data: dict):
"""Persist account details to SQLite database"""
db_path = path_manager.get_database_path()
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Create the accounts table if it doesn't exist
cursor.execute(
"""CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
institution_id TEXT,
status TEXT,
iban TEXT,
name TEXT,
currency TEXT,
created DATETIME,
last_accessed DATETIME,
last_updated DATETIME
)"""
)
# Create indexes for accounts table
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_institution_id
ON accounts(institution_id)"""
)
cursor.execute(
"""CREATE INDEX IF NOT EXISTS idx_accounts_status
ON accounts(status)"""
)
try:
# Insert or replace account data
cursor.execute(
"""INSERT OR REPLACE INTO accounts (
id,
institution_id,
status,
iban,
name,
currency,
created,
last_accessed,
last_updated
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
account_data["id"],
account_data["institution_id"],
account_data["status"],
account_data.get("iban"),
account_data.get("name"),
account_data.get("currency"),
account_data["created"],
account_data.get("last_accessed"),
account_data.get("last_updated", account_data["created"]),
),
)
conn.commit()
conn.close()
return account_data
except Exception as e:
conn.close()
raise e
def _get_accounts(self, account_ids=None):
"""Get account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
query = "SELECT * FROM accounts"
params = []
if account_ids:
placeholders = ",".join("?" * len(account_ids))
query += f" WHERE id IN ({placeholders})"
params.extend(account_ids)
query += " ORDER BY created DESC"
try:
cursor.execute(query, params)
rows = cursor.fetchall()
accounts = [dict(row) for row in rows]
conn.close()
return accounts
except Exception as e:
conn.close()
raise e
def _get_account(self, account_id: str):
"""Get specific account details from SQLite database"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
conn.close()
if row:
return dict(row)
return None
except Exception as e:
conn.close()
raise e
def _get_historical_balances(self, account_id=None, days=365):
"""Get historical balance progression based on transaction history"""
db_path = path_manager.get_database_path()
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# Get current balance for each account/type to use as the final balance
current_balances_query = """
SELECT account_id, type, amount, currency
FROM balances b1
WHERE b1.timestamp = (
SELECT MAX(b2.timestamp)
FROM balances b2
WHERE b2.account_id = b1.account_id AND b2.type = b1.type
)
"""
params = []
if account_id:
current_balances_query += " AND b1.account_id = ?"
params.append(account_id)
cursor.execute(current_balances_query, params)
current_balances = {
(row["account_id"], row["type"]): {
"amount": row["amount"],
"currency": row["currency"],
}
for row in cursor.fetchall()
}
# Get transactions for the specified period, ordered by date descending
cutoff_date = (datetime.now() - timedelta(days=days)).isoformat()
transactions_query = """
SELECT accountId, transactionDate, transactionValue
FROM transactions
WHERE transactionDate >= ?
"""
if account_id:
transactions_query += " AND accountId = ?"
params = [cutoff_date, account_id]
else:
params = [cutoff_date]
transactions_query += " ORDER BY transactionDate DESC"
cursor.execute(transactions_query, params)
transactions = cursor.fetchall()
# Calculate historical balances by working backwards from current balance
historical_balances = []
account_running_balances: dict[str, dict[str, float]] = {}
# Initialize running balances with current balances
for (acc_id, balance_type), balance_info in current_balances.items():
if acc_id not in account_running_balances:
account_running_balances[acc_id] = {}
account_running_balances[acc_id][balance_type] = balance_info["amount"]
# Group transactions by date
transactions_by_date = defaultdict(list)
for txn in transactions:
date_str = txn["transactionDate"][:10] # Extract just the date part
transactions_by_date[date_str].append(txn)
# Generate historical balance points
# Start from today and work backwards
current_date = datetime.now().date()
for day_offset in range(0, days, 7): # Sample every 7 days for performance
target_date = current_date - timedelta(days=day_offset)
target_date_str = target_date.isoformat()
# For each account, create balance entries
for acc_id in account_running_balances:
for balance_type in [
"closingBooked"
]: # Focus on closingBooked for the chart
if balance_type in account_running_balances[acc_id]:
balance_amount = account_running_balances[acc_id][
balance_type
]
currency = current_balances.get(
(acc_id, balance_type), {}
).get("currency", "EUR")
historical_balances.append(
{
"id": f"{acc_id}_{balance_type}_{target_date_str}",
"account_id": acc_id,
"balance_amount": balance_amount,
"balance_type": balance_type,
"currency": currency,
"reference_date": target_date_str,
"created_at": None,
"updated_at": None,
}
)
# Subtract transactions that occurred on this date and later dates
# to simulate going back in time
for date_str in list(transactions_by_date.keys()):
if date_str >= target_date_str:
for txn in transactions_by_date[date_str]:
acc_id = txn["accountId"]
amount = txn["transactionValue"]
if acc_id in account_running_balances:
for balance_type in account_running_balances[acc_id]:
account_running_balances[acc_id][balance_type] -= (
amount
)
# Remove processed transactions to avoid double-processing
del transactions_by_date[date_str]
conn.close()
# Sort by date for proper chronological order
historical_balances.sort(key=lambda x: x["reference_date"])
return historical_balances
except Exception as e:
conn.close()
raise e

View File

@@ -1,8 +1,8 @@
import json import json
import httpx
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List from typing import Any, Dict, List
import httpx
from loguru import logger from loguru import logger
from leggen.utils.config import config from leggen.utils.config import config

View File

@@ -1,4 +1,4 @@
from typing import List, Dict, Any from typing import Any, Dict, List
from loguru import logger from loguru import logger
@@ -109,33 +109,78 @@ class NotificationService:
"""Check if Telegram notifications are enabled""" """Check if Telegram notifications are enabled"""
telegram_config = self.notifications_config.get("telegram", {}) telegram_config = self.notifications_config.get("telegram", {})
return bool( return bool(
telegram_config.get("api-key") telegram_config.get("token")
and telegram_config.get("chat-id") and telegram_config.get("chat_id")
and telegram_config.get("enabled", True) and telegram_config.get("enabled", True)
) )
async def _send_discord_notifications( async def _send_discord_notifications(
self, transactions: List[Dict[str, Any]] self, transactions: List[Dict[str, Any]]
) -> None: ) -> None:
"""Send Discord notifications - placeholder implementation""" """Send Discord notifications for transactions"""
# Would import and use leggen.notifications.discord try:
logger.info(f"Sending {len(transactions)} transaction notifications to Discord") import click
from leggen.notifications.discord import send_transactions_message
# Create a mock context with the webhook
ctx = click.Context(click.Command("notifications"))
ctx.obj = {
"notifications": {
"discord": {
"webhook": self.notifications_config.get("discord", {}).get(
"webhook"
)
}
}
}
# Send transaction notifications using the actual implementation
send_transactions_message(ctx, transactions)
logger.info(
f"Sent {len(transactions)} transaction notifications to Discord"
)
except Exception as e:
logger.error(f"Failed to send Discord transaction notifications: {e}")
raise
async def _send_telegram_notifications( async def _send_telegram_notifications(
self, transactions: List[Dict[str, Any]] self, transactions: List[Dict[str, Any]]
) -> None: ) -> None:
"""Send Telegram notifications - placeholder implementation""" """Send Telegram notifications for transactions"""
# Would import and use leggen.notifications.telegram try:
logger.info( import click
f"Sending {len(transactions)} transaction notifications to Telegram"
) from leggen.notifications.telegram import send_transaction_message
# Create a mock context with the telegram config
ctx = click.Context(click.Command("notifications"))
telegram_config = self.notifications_config.get("telegram", {})
ctx.obj = {
"notifications": {
"telegram": {
"api-key": telegram_config.get("token"),
"chat-id": telegram_config.get("chat_id"),
}
}
}
# Send transaction notifications using the actual implementation
send_transaction_message(ctx, transactions)
logger.info(
f"Sent {len(transactions)} transaction notifications to Telegram"
)
except Exception as e:
logger.error(f"Failed to send Telegram transaction notifications: {e}")
raise
async def _send_discord_test(self, message: str) -> None: async def _send_discord_test(self, message: str) -> None:
"""Send Discord test notification""" """Send Discord test notification"""
try: try:
from leggen.notifications.discord import send_expire_notification
import click import click
from leggen.notifications.discord import send_expire_notification
# Create a mock context with the webhook # Create a mock context with the webhook
ctx = click.Context(click.Command("test")) ctx = click.Context(click.Command("test"))
ctx.obj = { ctx.obj = {
@@ -164,17 +209,18 @@ class NotificationService:
async def _send_telegram_test(self, message: str) -> None: async def _send_telegram_test(self, message: str) -> None:
"""Send Telegram test notification""" """Send Telegram test notification"""
try: try:
from leggen.notifications.telegram import send_expire_notification
import click import click
from leggen.notifications.telegram import send_expire_notification
# Create a mock context with the telegram config # Create a mock context with the telegram config
ctx = click.Context(click.Command("test")) ctx = click.Context(click.Command("test"))
telegram_config = self.notifications_config.get("telegram", {}) telegram_config = self.notifications_config.get("telegram", {})
ctx.obj = { ctx.obj = {
"notifications": { "notifications": {
"telegram": { "telegram": {
"api-key": telegram_config.get("api-key"), "api-key": telegram_config.get("token"),
"chat-id": telegram_config.get("chat-id"), "chat-id": telegram_config.get("chat_id"),
} }
} }
} }
@@ -194,8 +240,52 @@ class NotificationService:
async def _send_discord_expiry(self, notification_data: Dict[str, Any]) -> None: async def _send_discord_expiry(self, notification_data: Dict[str, Any]) -> None:
"""Send Discord expiry notification""" """Send Discord expiry notification"""
logger.info(f"Sending Discord expiry notification: {notification_data}") try:
import click
from leggen.notifications.discord import send_expire_notification
# Create a mock context with the webhook
ctx = click.Context(click.Command("expiry"))
ctx.obj = {
"notifications": {
"discord": {
"webhook": self.notifications_config.get("discord", {}).get(
"webhook"
)
}
}
}
# Send expiry notification using the actual implementation
send_expire_notification(ctx, notification_data)
logger.info(f"Sent Discord expiry notification: {notification_data}")
except Exception as e:
logger.error(f"Failed to send Discord expiry notification: {e}")
raise
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None: async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
"""Send Telegram expiry notification""" """Send Telegram expiry notification"""
logger.info(f"Sending Telegram expiry notification: {notification_data}") try:
import click
from leggen.notifications.telegram import send_expire_notification
# Create a mock context with the telegram config
ctx = click.Context(click.Command("expiry"))
telegram_config = self.notifications_config.get("telegram", {})
ctx.obj = {
"notifications": {
"telegram": {
"api-key": telegram_config.get("token"),
"chat-id": telegram_config.get("chat_id"),
}
}
}
# Send expiry notification using the actual implementation
send_expire_notification(ctx, notification_data)
logger.info(f"Sent Telegram expiry notification: {notification_data}")
except Exception as e:
logger.error(f"Failed to send Telegram expiry notification: {e}")
raise

View File

@@ -4,8 +4,8 @@ from typing import List
from loguru import logger from loguru import logger
from leggen.api.models.sync import SyncResult, SyncStatus from leggen.api.models.sync import SyncResult, SyncStatus
from leggen.services.gocardless_service import GoCardlessService
from leggen.services.database_service import DatabaseService from leggen.services.database_service import DatabaseService
from leggen.services.gocardless_service import GoCardlessService
from leggen.services.notification_service import NotificationService from leggen.services.notification_service import NotificationService

View File

@@ -0,0 +1,87 @@
from datetime import datetime
from typing import Any, Dict, List
class TransactionProcessor:
"""Handles processing and transformation of raw transaction data"""
def process_transactions(
self,
account_id: str,
account_info: Dict[str, Any],
transaction_data: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Process raw transaction data into standardized format"""
transactions = []
# Process booked transactions
for transaction in transaction_data.get("transactions", {}).get("booked", []):
processed = self._process_single_transaction(
account_id, account_info, transaction, "booked"
)
transactions.append(processed)
# Process pending transactions
for transaction in transaction_data.get("transactions", {}).get("pending", []):
processed = self._process_single_transaction(
account_id, account_info, transaction, "pending"
)
transactions.append(processed)
return transactions
def _process_single_transaction(
self,
account_id: str,
account_info: Dict[str, Any],
transaction: Dict[str, Any],
status: str,
) -> Dict[str, Any]:
"""Process a single transaction into standardized format"""
# Extract dates
booked_date = transaction.get("bookingDateTime") or transaction.get(
"bookingDate"
)
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
if booked_date and value_date:
min_date = min(
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
)
else:
date_str = booked_date or value_date
if not date_str:
raise ValueError("No valid date found in transaction")
min_date = datetime.fromisoformat(date_str)
# Extract amount and currency
transaction_amount = transaction.get("transactionAmount", {})
amount = float(transaction_amount.get("amount", 0))
currency = transaction_amount.get("currency", "")
# Extract description
description = transaction.get(
"remittanceInformationUnstructured",
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
)
# Extract transaction IDs - transactionId is now primary, internalTransactionId is reference
transaction_id = transaction.get("transactionId")
internal_transaction_id = transaction.get("internalTransactionId")
if not transaction_id:
raise ValueError("Transaction missing required transactionId field")
return {
"accountId": account_id,
"transactionId": transaction_id,
"internalTransactionId": internal_transaction_id,
"institutionId": account_info["institution_id"],
"iban": account_info.get("iban", "N/A"),
"transactionDate": min_date,
"description": description,
"transactionValue": amount,
"transactionCurrency": currency,
"transactionStatus": status,
"rawTransaction": transaction,
}

View File

@@ -1,20 +1,23 @@
import os import os
import sys import sys
import tomllib import tomllib
import tomli_w
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional from typing import Any, Dict, Optional
import click import click
import tomli_w
from loguru import logger from loguru import logger
from pydantic import ValidationError
from leggen.utils.text import error from leggen.models.config import Config as ConfigModel
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager
from leggen.utils.text import error
class Config: class Config:
_instance = None _instance = None
_config = None _config = None
_config_model = None
_config_path = None _config_path = None
def __new__(cls): def __new__(cls):
@@ -35,8 +38,18 @@ class Config:
try: try:
with open(config_path, "rb") as f: with open(config_path, "rb") as f:
self._config = tomllib.load(f) raw_config = tomllib.load(f)
logger.info(f"Configuration loaded from {config_path}") logger.info(f"Configuration loaded from {config_path}")
# Validate configuration using Pydantic
try:
self._config_model = ConfigModel(**raw_config)
self._config = self._config_model.dict(by_alias=True, exclude_none=True)
logger.info("Configuration validation successful")
except ValidationError as e:
logger.error(f"Configuration validation failed: {e}")
raise ValueError(f"Invalid configuration: {e}") from e
except FileNotFoundError: except FileNotFoundError:
logger.error(f"Configuration file not found: {config_path}") logger.error(f"Configuration file not found: {config_path}")
raise raise
@@ -65,15 +78,24 @@ class Config:
if config_data is None: if config_data is None:
raise ValueError("No config data to save") raise ValueError("No config data to save")
# Validate the configuration before saving
try:
validated_model = ConfigModel(**config_data)
validated_config = validated_model.dict(by_alias=True, exclude_none=True)
except ValidationError as e:
logger.error(f"Configuration validation failed before save: {e}")
raise ValueError(f"Invalid configuration: {e}") from e
# Ensure directory exists # Ensure directory exists
Path(config_path).parent.mkdir(parents=True, exist_ok=True) Path(config_path).parent.mkdir(parents=True, exist_ok=True)
try: try:
with open(config_path, "wb") as f: with open(config_path, "wb") as f:
tomli_w.dump(config_data, f) tomli_w.dump(validated_config, f)
# Update in-memory config # Update in-memory config
self._config = config_data self._config = validated_config
self._config_model = validated_model
self._config_path = config_path self._config_path = config_path
logger.info(f"Configuration saved to {config_path}") logger.info(f"Configuration saved to {config_path}")
except Exception as e: except Exception as e:
@@ -146,8 +168,16 @@ class Config:
def load_config(ctx: click.Context, _, filename): def load_config(ctx: click.Context, _, filename):
try: try:
with click.open_file(str(filename), "rb") as f: with click.open_file(str(filename), "rb") as f:
# TODO: Implement configuration file validation (use pydantic?) raw_config = tomllib.load(f)
ctx.obj = tomllib.load(f)
# Validate configuration using Pydantic
try:
validated_model = ConfigModel(**raw_config)
ctx.obj = validated_model.dict(by_alias=True, exclude_none=True)
except ValidationError as e:
error(f"Configuration validation failed: {e}")
sys.exit(1)
except FileNotFoundError: except FileNotFoundError:
error( error(
"Configuration file not found. Provide a valid configuration file path with leggen --config <path> or LEGGEN_CONFIG=<path> environment variable." "Configuration file not found. Provide a valid configuration file path with leggen --config <path> or LEGGEN_CONFIG=<path> environment variable."

View File

@@ -1,132 +0,0 @@
from datetime import datetime
import click
import leggen.database.sqlite as sqlite_engine
from leggen.utils.text import info, warning
def persist_balance(ctx: click.Context, account: str, balance: dict) -> None:
sqlite = ctx.obj.get("database", {}).get("sqlite", True)
if not sqlite:
warning("SQLite database is disabled, skipping balance saving")
return
info(f"[{account}] Fetched balances, saving to SQLite")
sqlite_engine.persist_balances(ctx, balance)
def persist_transactions(ctx: click.Context, account: str, transactions: list) -> list:
sqlite = ctx.obj.get("database", {}).get("sqlite", True)
if not sqlite:
warning("SQLite database is disabled, skipping transaction saving")
# WARNING: This will return the transactions list as is, without saving it to any database
# Possible duplicate notifications will be sent if the filters are enabled
return transactions
info(f"[{account}] Fetched {len(transactions)} transactions, saving to SQLite")
return sqlite_engine.persist_transactions(ctx, account, transactions)
def save_transactions(ctx: click.Context, account: str) -> list:
import requests
api_url = ctx.obj.get("api_url", "http://localhost:8000")
info(f"[{account}] Getting account details")
res = requests.get(f"{api_url}/accounts/{account}")
res.raise_for_status()
account_info = res.json()
info(f"[{account}] Getting transactions")
transactions = []
res = requests.get(f"{api_url}/accounts/{account}/transactions/")
res.raise_for_status()
account_transactions = res.json().get("transactions", [])
for transaction in account_transactions.get("booked", []):
booked_date = transaction.get("bookingDateTime") or transaction.get(
"bookingDate"
)
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
if booked_date and value_date:
min_date = min(
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
)
else:
min_date = datetime.fromisoformat(booked_date or value_date)
transactionValue = float(
transaction.get("transactionAmount", {}).get("amount", 0)
)
currency = transaction.get("transactionAmount", {}).get("currency", "")
description = transaction.get(
"remittanceInformationUnstructured",
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
)
# Extract transaction ID, using transactionId as fallback when internalTransactionId is missing
transaction_id = transaction.get("internalTransactionId") or transaction.get(
"transactionId"
)
t = {
"internalTransactionId": transaction_id,
"institutionId": account_info["institution_id"],
"iban": account_info.get("iban", "N/A"),
"transactionDate": min_date,
"description": description,
"transactionValue": transactionValue,
"transactionCurrency": currency,
"transactionStatus": "booked",
"accountId": account,
"rawTransaction": transaction,
}
transactions.append(t)
for transaction in account_transactions.get("pending", []):
booked_date = transaction.get("bookingDateTime") or transaction.get(
"bookingDate"
)
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
if booked_date and value_date:
min_date = min(
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
)
else:
min_date = datetime.fromisoformat(booked_date or value_date)
transactionValue = float(
transaction.get("transactionAmount", {}).get("amount", 0)
)
currency = transaction.get("transactionAmount", {}).get("currency", "")
description = transaction.get(
"remittanceInformationUnstructured",
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
)
# Extract transaction ID, using transactionId as fallback when internalTransactionId is missing
transaction_id = transaction.get("internalTransactionId") or transaction.get(
"transactionId"
)
t = {
"internalTransactionId": transaction_id,
"institutionId": account_info["institution_id"],
"iban": account_info.get("iban", "N/A"),
"transactionDate": min_date,
"description": description,
"transactionValue": transactionValue,
"transactionCurrency": currency,
"transactionStatus": "pending",
"accountId": account,
"rawTransaction": transaction,
}
transactions.append(t)
return persist_transactions(ctx, account, transactions)

View File

@@ -34,6 +34,7 @@ dependencies = [
"apscheduler>=3.10.0,<4", "apscheduler>=3.10.0,<4",
"tomli-w>=1.0.0,<2", "tomli-w>=1.0.0,<2",
"httpx>=0.28.1", "httpx>=0.28.1",
"pydantic>=2.0.0,<3",
] ]
[project.urls] [project.urls]
@@ -68,7 +69,7 @@ build-backend = "hatchling.build"
[tool.ruff] [tool.ruff]
lint.ignore = ["E501", "B008", "B006"] lint.ignore = ["E501", "B008", "B006"]
lint.extend-select = ["B", "C4", "PIE", "T20", "SIM", "TCH"] lint.extend-select = ["B", "C4", "I", "PIE", "T20", "SIM", "TCH"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]

View File

@@ -7,7 +7,7 @@ import sqlite3
import sys import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any from typing import Any, Dict, List
import click import click

View File

@@ -1,10 +1,11 @@
"""Pytest configuration and shared fixtures.""" """Pytest configuration and shared fixtures."""
import pytest
import tempfile
import json import json
import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from leggen.commands.server import create_app from leggen.commands.server import create_app

View File

@@ -1,8 +1,9 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics.""" """Tests for analytics fixes to ensure all transactions are used in statistics."""
import pytest
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from leggen.commands.server import create_app from leggen.commands.server import create_app

View File

@@ -1,8 +1,9 @@
"""Tests for accounts API endpoints.""" """Tests for accounts API endpoints."""
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
@pytest.mark.api @pytest.mark.api
class TestAccountsAPI: class TestAccountsAPI:

View File

@@ -1,9 +1,10 @@
"""Tests for banks API endpoints.""" """Tests for banks API endpoints."""
from unittest.mock import patch
import httpx
import pytest import pytest
import respx import respx
import httpx
from unittest.mock import patch
@pytest.mark.api @pytest.mark.api

View File

@@ -1,9 +1,10 @@
"""Tests for CLI API client.""" """Tests for CLI API client."""
from unittest.mock import patch
import pytest import pytest
import requests import requests
import requests_mock import requests_mock
from unittest.mock import patch
from leggen.api_client import LeggenAPIClient from leggen.api_client import LeggenAPIClient

View File

@@ -1,8 +1,9 @@
"""Tests for transactions API endpoints.""" """Tests for transactions API endpoints."""
import pytest
from unittest.mock import patch
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
@pytest.mark.api @pytest.mark.api

View File

@@ -1,8 +1,9 @@
"""Tests for configuration management.""" """Tests for configuration management."""
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
from leggen.utils.config import Config from leggen.utils.config import Config
@@ -37,10 +38,14 @@ class TestConfig:
# Reset singleton state for testing # Reset singleton state for testing
config._config = None config._config = None
config._config_path = None config._config_path = None
config._config_model = None
result = config.load_config(str(config_file)) result = config.load_config(str(config_file))
assert result == config_data # Result should contain validated config data
assert result["gocardless"]["key"] == "test-key"
assert result["gocardless"]["secret"] == "test-secret"
assert result["database"]["sqlite"] is True
assert config.gocardless_config["key"] == "test-key" assert config.gocardless_config["key"] == "test-key"
assert config.database_config["sqlite"] is True assert config.database_config["sqlite"] is True
@@ -54,11 +59,19 @@ class TestConfig:
def test_save_config_success(self, temp_config_dir): def test_save_config_success(self, temp_config_dir):
"""Test successful configuration saving.""" """Test successful configuration saving."""
config_data = {"gocardless": {"key": "new-key", "secret": "new-secret"}} config_data = {
"gocardless": {
"key": "new-key",
"secret": "new-secret",
"url": "https://bankaccountdata.gocardless.com/api/v2",
},
"database": {"sqlite": True},
}
config_file = temp_config_dir / "new_config.toml" config_file = temp_config_dir / "new_config.toml"
config = Config() config = Config()
config._config = None config._config = None
config._config_model = None
config.save_config(config_data, str(config_file)) config.save_config(config_data, str(config_file))
@@ -70,12 +83,18 @@ class TestConfig:
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
saved_data = tomllib.load(f) saved_data = tomllib.load(f)
assert saved_data == config_data assert saved_data["gocardless"]["key"] == "new-key"
assert saved_data["gocardless"]["secret"] == "new-secret"
assert saved_data["database"]["sqlite"] is True
def test_update_config_success(self, temp_config_dir): def test_update_config_success(self, temp_config_dir):
"""Test updating configuration values.""" """Test updating configuration values."""
initial_config = { initial_config = {
"gocardless": {"key": "old-key"}, "gocardless": {
"key": "old-key",
"secret": "old-secret",
"url": "https://bankaccountdata.gocardless.com/api/v2",
},
"database": {"sqlite": True}, "database": {"sqlite": True},
} }
@@ -87,6 +106,7 @@ class TestConfig:
config = Config() config = Config()
config._config = None config._config = None
config._config_model = None
config.load_config(str(config_file)) config.load_config(str(config_file))
config.update_config("gocardless", "key", "new-key") config.update_config("gocardless", "key", "new-key")
@@ -102,7 +122,14 @@ class TestConfig:
def test_update_section_success(self, temp_config_dir): def test_update_section_success(self, temp_config_dir):
"""Test updating entire configuration section.""" """Test updating entire configuration section."""
initial_config = {"database": {"sqlite": True}} initial_config = {
"gocardless": {
"key": "test-key",
"secret": "test-secret",
"url": "https://bankaccountdata.gocardless.com/api/v2",
},
"database": {"sqlite": True},
}
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
@@ -112,12 +139,13 @@ class TestConfig:
config = Config() config = Config()
config._config = None config._config = None
config._config_model = None
config.load_config(str(config_file)) config.load_config(str(config_file))
new_db_config = {"sqlite": False, "path": "./custom.db"} new_db_config = {"sqlite": False}
config.update_section("database", new_db_config) config.update_section("database", new_db_config)
assert config.database_config == new_db_config assert config.database_config["sqlite"] is False
def test_scheduler_config_defaults(self): def test_scheduler_config_defaults(self):
"""Test scheduler configuration with defaults.""" """Test scheduler configuration with defaults."""

View File

@@ -1,17 +1,14 @@
"""Integration tests for configurable paths.""" """Integration tests for configurable paths."""
import pytest
import tempfile
import os import os
import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest
from leggen.services.database_service import DatabaseService
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager
from leggen.database.sqlite import persist_balances, get_balances
class MockContext:
"""Mock context for testing."""
@pytest.mark.unit @pytest.mark.unit
@@ -109,24 +106,31 @@ class TestConfigurablePaths:
# Set custom database path # Set custom database path
path_manager.set_database_path(test_db_path) path_manager.set_database_path(test_db_path)
# Test database operations # Test database operations using DatabaseService
ctx = MockContext() database_service = DatabaseService()
balance = { balance_data = {
"account_id": "test-account", "balances": [
"bank": "TEST_BANK", {
"status": "active", "balanceAmount": {"amount": "1000.0", "currency": "EUR"},
"balanceType": "available",
}
],
"institution_id": "TEST_BANK",
"account_status": "active",
"iban": "TEST_IBAN", "iban": "TEST_IBAN",
"amount": 1000.0,
"currency": "EUR",
"type": "available",
"timestamp": "2023-01-01T00:00:00",
} }
# Persist balance # Use the internal balance persistence method since the test needs direct database access
persist_balances(ctx, balance) import asyncio
asyncio.run(
database_service._persist_balance_sqlite("test-account", balance_data)
)
# Retrieve balances # Retrieve balances
balances = get_balances() balances = asyncio.run(
database_service.get_balances_from_db("test-account")
)
assert len(balances) == 1 assert len(balances) == 1
assert balances[0]["account_id"] == "test-account" assert balances[0]["account_id"] == "test-account"

View File

@@ -1,8 +1,9 @@
"""Tests for database service.""" """Tests for database service."""
import pytest
from unittest.mock import patch
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
from leggen.services.database_service import DatabaseService from leggen.services.database_service import DatabaseService
@@ -83,7 +84,9 @@ class TestDatabaseService:
self, database_service, sample_transactions_db_format self, database_service, sample_transactions_db_format
): ):
"""Test successful retrieval of transactions from database.""" """Test successful retrieval of transactions from database."""
with patch("leggen.database.sqlite.get_transactions") as mock_get_transactions: with patch.object(
database_service, "_get_transactions"
) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format mock_get_transactions.return_value = sample_transactions_db_format
result = await database_service.get_transactions_from_db( result = await database_service.get_transactions_from_db(
@@ -107,7 +110,9 @@ class TestDatabaseService:
self, database_service, sample_transactions_db_format self, database_service, sample_transactions_db_format
): ):
"""Test retrieving transactions with filters.""" """Test retrieving transactions with filters."""
with patch("leggen.database.sqlite.get_transactions") as mock_get_transactions: with patch.object(
database_service, "_get_transactions"
) as mock_get_transactions:
mock_get_transactions.return_value = sample_transactions_db_format mock_get_transactions.return_value = sample_transactions_db_format
result = await database_service.get_transactions_from_db( result = await database_service.get_transactions_from_db(
@@ -143,7 +148,9 @@ class TestDatabaseService:
async def test_get_transactions_from_db_error(self, database_service): async def test_get_transactions_from_db_error(self, database_service):
"""Test handling error when getting transactions.""" """Test handling error when getting transactions."""
with patch("leggen.database.sqlite.get_transactions") as mock_get_transactions: with patch.object(
database_service, "_get_transactions"
) as mock_get_transactions:
mock_get_transactions.side_effect = Exception("Database error") mock_get_transactions.side_effect = Exception("Database error")
result = await database_service.get_transactions_from_db() result = await database_service.get_transactions_from_db()
@@ -152,7 +159,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_success(self, database_service): async def test_get_transaction_count_from_db_success(self, database_service):
"""Test successful retrieval of transaction count.""" """Test successful retrieval of transaction count."""
with patch("leggen.database.sqlite.get_transaction_count") as mock_get_count: with patch.object(database_service, "_get_transaction_count") as mock_get_count:
mock_get_count.return_value = 42 mock_get_count.return_value = 42
result = await database_service.get_transaction_count_from_db( result = await database_service.get_transaction_count_from_db(
@@ -164,7 +171,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_with_filters(self, database_service): async def test_get_transaction_count_from_db_with_filters(self, database_service):
"""Test getting transaction count with filters.""" """Test getting transaction count with filters."""
with patch("leggen.database.sqlite.get_transaction_count") as mock_get_count: with patch.object(database_service, "_get_transaction_count") as mock_get_count:
mock_get_count.return_value = 15 mock_get_count.return_value = 15
result = await database_service.get_transaction_count_from_db( result = await database_service.get_transaction_count_from_db(
@@ -194,7 +201,7 @@ class TestDatabaseService:
async def test_get_transaction_count_from_db_error(self, database_service): async def test_get_transaction_count_from_db_error(self, database_service):
"""Test handling error when getting count.""" """Test handling error when getting count."""
with patch("leggen.database.sqlite.get_transaction_count") as mock_get_count: with patch.object(database_service, "_get_transaction_count") as mock_get_count:
mock_get_count.side_effect = Exception("Database error") mock_get_count.side_effect = Exception("Database error")
result = await database_service.get_transaction_count_from_db() result = await database_service.get_transaction_count_from_db()
@@ -205,7 +212,7 @@ class TestDatabaseService:
self, database_service, sample_balances_db_format self, database_service, sample_balances_db_format
): ):
"""Test successful retrieval of balances from database.""" """Test successful retrieval of balances from database."""
with patch("leggen.database.sqlite.get_balances") as mock_get_balances: with patch.object(database_service, "_get_balances") as mock_get_balances:
mock_get_balances.return_value = sample_balances_db_format mock_get_balances.return_value = sample_balances_db_format
result = await database_service.get_balances_from_db( result = await database_service.get_balances_from_db(
@@ -227,7 +234,7 @@ class TestDatabaseService:
async def test_get_balances_from_db_error(self, database_service): async def test_get_balances_from_db_error(self, database_service):
"""Test handling error when getting balances.""" """Test handling error when getting balances."""
with patch("leggen.database.sqlite.get_balances") as mock_get_balances: with patch.object(database_service, "_get_balances") as mock_get_balances:
mock_get_balances.side_effect = Exception("Database error") mock_get_balances.side_effect = Exception("Database error")
result = await database_service.get_balances_from_db() result = await database_service.get_balances_from_db()
@@ -242,7 +249,7 @@ class TestDatabaseService:
"iban": "LT313250081177977789", "iban": "LT313250081177977789",
} }
with patch("leggen.database.sqlite.get_account_summary") as mock_get_summary: with patch.object(database_service, "_get_account_summary") as mock_get_summary:
mock_get_summary.return_value = mock_summary mock_get_summary.return_value = mock_summary
result = await database_service.get_account_summary_from_db( result = await database_service.get_account_summary_from_db(
@@ -262,7 +269,7 @@ class TestDatabaseService:
async def test_get_account_summary_from_db_error(self, database_service): async def test_get_account_summary_from_db_error(self, database_service):
"""Test handling error when getting summary.""" """Test handling error when getting summary."""
with patch("leggen.database.sqlite.get_account_summary") as mock_get_summary: with patch.object(database_service, "_get_account_summary") as mock_get_summary:
mock_get_summary.side_effect = Exception("Database error") mock_get_summary.side_effect = Exception("Database error")
result = await database_service.get_account_summary_from_db( result = await database_service.get_account_summary_from_db(

View File

@@ -1,8 +1,9 @@
"""Tests for background scheduler.""" """Tests for background scheduler."""
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from leggen.background.scheduler import BackgroundScheduler from leggen.background.scheduler import BackgroundScheduler

View File

@@ -1,364 +0,0 @@
"""Tests for SQLite database functions."""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch
from datetime import datetime
import leggen.database.sqlite as sqlite_db
@pytest.fixture
def temp_db_path():
"""Create a temporary database file for testing."""
import uuid
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / f"test_{uuid.uuid4().hex}.db"
yield db_path
@pytest.fixture
def mock_home_db_path(temp_db_path):
"""Mock the database path to use temp file."""
from leggen.utils.paths import path_manager
# Set the path manager to use the temporary database
original_database_path = path_manager._database_path
path_manager.set_database_path(temp_db_path)
try:
yield temp_db_path
finally:
# Restore original path
path_manager._database_path = original_database_path
@pytest.fixture
def sample_transactions():
"""Sample transaction data for testing."""
return [
{
"transactionId": "bank-txn-001", # NEW: stable bank-provided ID
"internalTransactionId": "txn-001",
"institutionId": "REVOLUT_REVOLT21",
"iban": "LT313250081177977789",
"transactionDate": datetime(2025, 9, 1, 9, 30),
"description": "Coffee Shop Payment",
"transactionValue": -10.50,
"transactionCurrency": "EUR",
"transactionStatus": "booked",
"accountId": "test-account-123",
"rawTransaction": {"transactionId": "bank-txn-001", "some": "data"},
},
{
"transactionId": "bank-txn-002", # NEW: stable bank-provided ID
"internalTransactionId": "txn-002",
"institutionId": "REVOLUT_REVOLT21",
"iban": "LT313250081177977789",
"transactionDate": datetime(2025, 9, 2, 14, 15),
"description": "Grocery Store",
"transactionValue": -45.30,
"transactionCurrency": "EUR",
"transactionStatus": "booked",
"accountId": "test-account-123",
"rawTransaction": {"transactionId": "bank-txn-002", "other": "data"},
},
]
@pytest.fixture
def sample_balance():
"""Sample balance data for testing."""
return {
"account_id": "test-account-123",
"bank": "REVOLUT_REVOLT21",
"status": "active",
"iban": "LT313250081177977789",
"amount": 1000.00,
"currency": "EUR",
"type": "interimAvailable",
"timestamp": datetime.now(),
}
class MockContext:
"""Mock context for testing."""
class TestSQLiteDatabase:
"""Test SQLite database operations."""
def test_persist_transactions(self, mock_home_db_path, sample_transactions):
"""Test persisting transactions to database."""
ctx = MockContext()
# Persist transactions
new_transactions = sqlite_db.persist_transactions(
ctx, "test-account-123", sample_transactions
)
# Should return all transactions as new
assert len(new_transactions) == 2
assert new_transactions[0]["internalTransactionId"] == "txn-001"
def test_persist_transactions_duplicates(
self, mock_home_db_path, sample_transactions
):
"""Test handling duplicate transactions."""
ctx = MockContext()
# Insert transactions twice
new_transactions_1 = sqlite_db.persist_transactions(
ctx, "test-account-123", sample_transactions
)
new_transactions_2 = sqlite_db.persist_transactions(
ctx, "test-account-123", sample_transactions
)
# First time should return all as new
assert len(new_transactions_1) == 2
# Second time should also return all (INSERT OR REPLACE behavior with composite key)
assert len(new_transactions_2) == 2
def test_get_transactions_all(self, mock_home_db_path, sample_transactions):
"""Test retrieving all transactions."""
ctx = MockContext()
# Insert test data
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Get all transactions
transactions = sqlite_db.get_transactions()
assert len(transactions) == 2
assert (
transactions[0]["internalTransactionId"] == "txn-002"
) # Ordered by date DESC
assert transactions[1]["internalTransactionId"] == "txn-001"
def test_get_transactions_filtered_by_account(
self, mock_home_db_path, sample_transactions
):
"""Test filtering transactions by account ID."""
ctx = MockContext()
# Add transaction for different account
other_account_transaction = sample_transactions[0].copy()
other_account_transaction["internalTransactionId"] = "txn-003"
other_account_transaction["accountId"] = "other-account"
all_transactions = sample_transactions + [other_account_transaction]
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", all_transactions)
# Filter by account
transactions = sqlite_db.get_transactions(account_id="test-account-123")
assert len(transactions) == 2
for txn in transactions:
assert txn["accountId"] == "test-account-123"
def test_get_transactions_with_pagination(
self, mock_home_db_path, sample_transactions
):
"""Test transaction pagination."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Get first page
transactions_page1 = sqlite_db.get_transactions(limit=1, offset=0)
assert len(transactions_page1) == 1
# Get second page
transactions_page2 = sqlite_db.get_transactions(limit=1, offset=1)
assert len(transactions_page2) == 1
# Should be different transactions
assert (
transactions_page1[0]["internalTransactionId"]
!= transactions_page2[0]["internalTransactionId"]
)
def test_get_transactions_with_amount_filter(
self, mock_home_db_path, sample_transactions
):
"""Test filtering transactions by amount."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Filter by minimum amount (should exclude coffee shop payment)
transactions = sqlite_db.get_transactions(min_amount=-20.0)
assert len(transactions) == 1
assert transactions[0]["transactionValue"] == -10.50
def test_get_transactions_with_search(self, mock_home_db_path, sample_transactions):
"""Test searching transactions by description."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Search for "Coffee"
transactions = sqlite_db.get_transactions(search="Coffee")
assert len(transactions) == 1
assert "Coffee" in transactions[0]["description"]
def test_get_transactions_empty_database(self, mock_home_db_path):
"""Test getting transactions from empty database."""
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
transactions = sqlite_db.get_transactions()
assert transactions == []
def test_get_transactions_nonexistent_database(self):
"""Test getting transactions when database doesn't exist."""
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = Path("/nonexistent")
transactions = sqlite_db.get_transactions()
assert transactions == []
def test_persist_balances(self, mock_home_db_path, sample_balance):
"""Test persisting balance data."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
result = sqlite_db.persist_balances(ctx, sample_balance)
# Should return the balance data
assert result["account_id"] == "test-account-123"
def test_get_balances(self, mock_home_db_path, sample_balance):
"""Test retrieving balances."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
# Insert test balance
sqlite_db.persist_balances(ctx, sample_balance)
# Get balances
balances = sqlite_db.get_balances()
assert len(balances) == 1
assert balances[0]["account_id"] == "test-account-123"
assert balances[0]["amount"] == 1000.00
def test_get_balances_filtered_by_account(self, mock_home_db_path, sample_balance):
"""Test filtering balances by account ID."""
ctx = MockContext()
# Create balance for different account
other_balance = sample_balance.copy()
other_balance["account_id"] = "other-account"
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_balances(ctx, sample_balance)
sqlite_db.persist_balances(ctx, other_balance)
# Filter by account
balances = sqlite_db.get_balances(account_id="test-account-123")
assert len(balances) == 1
assert balances[0]["account_id"] == "test-account-123"
def test_get_account_summary(self, mock_home_db_path, sample_transactions):
"""Test getting account summary from transactions."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
summary = sqlite_db.get_account_summary("test-account-123")
assert summary is not None
assert summary["accountId"] == "test-account-123"
assert summary["institutionId"] == "REVOLUT_REVOLT21"
assert summary["iban"] == "LT313250081177977789"
def test_get_account_summary_nonexistent(self, mock_home_db_path):
"""Test getting summary for nonexistent account."""
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
summary = sqlite_db.get_account_summary("nonexistent")
assert summary is None
def test_get_transaction_count(self, mock_home_db_path, sample_transactions):
"""Test getting transaction count."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Get total count
count = sqlite_db.get_transaction_count()
assert count == 2
# Get count for specific account
count_filtered = sqlite_db.get_transaction_count(
account_id="test-account-123"
)
assert count_filtered == 2
# Get count for nonexistent account
count_none = sqlite_db.get_transaction_count(account_id="nonexistent")
assert count_none == 0
def test_get_transaction_count_with_filters(
self, mock_home_db_path, sample_transactions
):
"""Test getting transaction count with filters."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Filter by search
count = sqlite_db.get_transaction_count(search="Coffee")
assert count == 1
# Filter by amount
count = sqlite_db.get_transaction_count(min_amount=-20.0)
assert count == 1
def test_database_indexes_created(self, mock_home_db_path, sample_transactions):
"""Test that database indexes are created properly."""
ctx = MockContext()
with patch("pathlib.Path.home") as mock_home:
mock_home.return_value = mock_home_db_path.parent / ".."
# Persist transactions to create tables and indexes
sqlite_db.persist_transactions(ctx, "test-account-123", sample_transactions)
# Get transactions to ensure we can query the table (indexes working)
transactions = sqlite_db.get_transactions(account_id="test-account-123")
assert len(transactions) == 2

2
uv.lock generated
View File

@@ -229,6 +229,7 @@ dependencies = [
{ name = "fastapi" }, { name = "fastapi" },
{ name = "httpx" }, { name = "httpx" },
{ name = "loguru" }, { name = "loguru" },
{ name = "pydantic" },
{ name = "requests" }, { name = "requests" },
{ name = "tabulate" }, { name = "tabulate" },
{ name = "tomli-w" }, { name = "tomli-w" },
@@ -257,6 +258,7 @@ requires-dist = [
{ name = "fastapi", specifier = ">=0.104.0,<1" }, { name = "fastapi", specifier = ">=0.104.0,<1" },
{ name = "httpx", specifier = ">=0.28.1" }, { name = "httpx", specifier = ">=0.28.1" },
{ name = "loguru", specifier = ">=0.7.2,<0.8" }, { name = "loguru", specifier = ">=0.7.2,<0.8" },
{ name = "pydantic", specifier = ">=2.0.0,<3" },
{ name = "requests", specifier = ">=2.31.0,<3" }, { name = "requests", specifier = ">=2.31.0,<3" },
{ name = "tabulate", specifier = ">=0.9.0,<0.10" }, { name = "tabulate", specifier = ">=0.9.0,<0.10" },
{ name = "tomli-w", specifier = ">=1.0.0,<2" }, { name = "tomli-w", specifier = ">=1.0.0,<2" },